diff --git a/csrc/search/device_mesh.cpp b/csrc/search/device_mesh.cpp deleted file mode 100644 index dd260d9..0000000 --- a/csrc/search/device_mesh.cpp +++ /dev/null @@ -1,87 +0,0 @@ -#include -#include -#include - -// DeviceMesh::DeviceMesh() -// : device_mesh_name(""), n_nodes(0), n_gpus(0), node_names({}), gpu_ids({}) { -// }; - -DeviceMesh::DeviceMesh(int n_nodes, int n_gpus_per_node, std::vector> mapping, - std::string global_mesh_name, std::string name) - : n_nodes(n_nodes), - n_gpus_per_node(n_gpus_per_node), - mapping(mapping), - global_mesh_name(global_mesh_name), - name(name) { - assert(n_nodes == static_cast(mapping.size())); - for (int i = 0; i < n_nodes; i++) { - assert(n_gpus_per_node == static_cast(mapping[i].size())); - } -}; - -bool is_all_overlap(std::vector device_meshes, DeviceMesh device_mesh) { - for (DeviceMesh *other : device_meshes) { - if (!device_mesh.overlap(*other)) return false; - } - return true; -}; - -bool is_all_overlap(std::unordered_set device_meshes, DeviceMesh device_mesh) { - for (DeviceMesh *other : device_meshes) { - if (!device_mesh.overlap(*other)) return false; - } - return true; -}; - -bool DeviceMesh::contain(const DeviceMesh &other) { - // check whether one device mapping is contained by another by - // checking 1. whether global_mesh_name is identical - // 2. whether mapping of one device mesh is contained by the other one - if (global_mesh_name != other.global_mesh_name) return false; - for (int i = 0; i < n_nodes; i++) { - for (int j = 0; j < n_gpus_per_node; j++) { - if (mapping[i][j] == 0 && other.mapping[i][j] == 1) return false; - } - } - return true; -}; - -bool DeviceMesh::contained_by(const DeviceMesh &other) { - if (global_mesh_name != other.global_mesh_name) return false; - for (int i = 0; i < n_nodes; i++) { - for (int j = 0; j < n_gpus_per_node; j++) { - if (mapping[i][j] == 1 && other.mapping[i][j] == 0) return false; - } - } - return true; -}; - -bool DeviceMesh::overlap(const DeviceMesh &other) { - if (global_mesh_name != other.global_mesh_name) return false; - for (int i = 0; i < n_nodes; i++) { - for (int j = 0; j < n_gpus_per_node; j++) { - if (mapping[i][j] == 1 && other.mapping[i][j] == 1) return true; - } - } - return false; -}; - -ModelParallelStrategy::ModelParallelStrategy(int num_pp, int num_dp, int num_mp) - : num_pp(num_pp), num_dp(num_dp), num_mp(num_mp) {}; - -bool ModelParallelStrategy::operator==(const ModelParallelStrategy &other) const { - return num_pp == other.num_pp && num_dp == other.num_dp && num_mp == other.num_mp; -}; - -bool DeviceMesh::operator==(const DeviceMesh &other) const { - return name == other.name && global_mesh_name == other.global_mesh_name; -}; - -std::string ModelParallelStrategy::to_string() { - return "num_pp:" + std::to_string(num_pp) + ";" + "num_dp:" + std::to_string(num_dp) + ";" - + "num_mp:" + std::to_string(num_mp); -}; - -std::string ModelParallelStrategy::to_key() { - return std::to_string(num_pp) + "," + std::to_string(num_mp) + "," + std::to_string(num_dp); -} \ No newline at end of file diff --git a/csrc/search/device_mesh.hpp b/csrc/search/device_mesh.hpp deleted file mode 100644 index 0dc1994..0000000 --- a/csrc/search/device_mesh.hpp +++ /dev/null @@ -1,49 +0,0 @@ -#ifndef DEVICE_MESH_HPP -#define DEVICE_MESH_HPP - -#include -#include -#include -#include -// #include - -class RPCInstance; - -class DeviceMesh { - public: - int n_nodes; - int n_gpus_per_node; - std::vector> mapping; - std::string global_mesh_name; - std::string name; - RPCInstance *pre_task = nullptr; - - 
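Note on the semantics being removed above: `contain`, `contained_by`, and `overlap` are element-wise checks over the `n_nodes x n_gpus_per_node` 0/1 `mapping`, and two meshes are only comparable when they share a `global_mesh_name`. A minimal Python sketch of the same logic (illustrative only, not part of this diff; names mirror the deleted C++ class):

```python
# Illustrative sketch of the DeviceMesh set semantics from device_mesh.cpp above.
from dataclasses import dataclass
from typing import List

@dataclass
class DeviceMeshSketch:
    global_mesh_name: str
    mapping: List[List[int]]  # n_nodes x n_gpus_per_node, entries are 0 or 1

    def overlap(self, other: "DeviceMeshSketch") -> bool:
        # Overlap: same global mesh and at least one shared (node, gpu) slot.
        if self.global_mesh_name != other.global_mesh_name:
            return False
        return any(a == 1 and b == 1
                   for ra, rb in zip(self.mapping, other.mapping)
                   for a, b in zip(ra, rb))

    def contained_by(self, other: "DeviceMeshSketch") -> bool:
        # Containment: every slot used by `self` is also used by `other`.
        if self.global_mesh_name != other.global_mesh_name:
            return False
        return all(not (a == 1 and b == 0)
                   for ra, rb in zip(self.mapping, other.mapping)
                   for a, b in zip(ra, rb))
```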
// DeviceMesh(); - DeviceMesh(int n_nodes, int n_gpus_per_node, std::vector> mapping, - std::string global_mesh_name, std::string name); - - bool overlap(const DeviceMesh &other); - bool contain(const DeviceMesh &other); - bool contained_by(const DeviceMesh &other); - - bool operator==(const DeviceMesh &other) const; -}; - -bool is_all_overlap(std::vector device_meshes, DeviceMesh device_mesh); -bool is_all_overlap(std::unordered_set device_meshes, DeviceMesh device_mesh); - -class ModelParallelStrategy { - public: - int num_pp, num_dp, num_mp; - - ModelParallelStrategy(int num_pp, int num_dp, int num_mp); - - bool operator==(const ModelParallelStrategy &other) const; - - std::string to_string(); - std::string to_key(); -}; - -class ModelDeviceMapping {}; - -#endif // DEVICE_MESH_HPP \ No newline at end of file diff --git a/csrc/search/rpc.cpp b/csrc/search/rpc.cpp deleted file mode 100644 index 2d3ad88..0000000 --- a/csrc/search/rpc.cpp +++ /dev/null @@ -1,233 +0,0 @@ -#include -#include -#include -#include -#include -#include - -RPC::RPC(std::string model_name, std::string rpc_name, std::string interface_type) - : model_name(model_name), rpc_name(rpc_name), interface_type(interface_type) {}; - -RPCExecution::RPCExecution(RPC *rpc_ptr, DeviceMesh &device_mesh, - ModelParallelStrategy &model_parallel_strategy, uint64_t time_cost, - uint64_t mem, uint64_t static_mem) - : rpc_ptr(rpc_ptr), - device_mesh(device_mesh), - model_parallel_strategy(model_parallel_strategy), - time_cost(time_cost), - mem(mem), - static_mem(static_mem) {}; - -bool OverlapGroup::maybe_add(RPCExecution *rpc_exe) { - if (rpc_executions.empty()) { - rpc_executions.insert(rpc_exe); - device_meshes.insert(&rpc_exe->device_mesh); - mem_static = rpc_exe->static_mem; - mem_active = rpc_exe->mem - rpc_exe->static_mem; - return true; - } - if (is_all_overlap(device_meshes, rpc_exe->device_mesh)) { - rpc_executions.insert(rpc_exe); - // bool dm_in_group = device_meshes.find(&rpc_exe -> device_mesh) != device_meshes.end(); - device_meshes.insert(&rpc_exe->device_mesh); - mem_static += rpc_exe->static_mem; - mem_active = std::max(mem_active, rpc_exe->mem - rpc_exe->static_mem); - return true; - } - return false; -}; - -void DeviceMeshGroup::add_to_groups(RPCExecution *rpc_exe) { - if (overlap_groups.empty()) { - OverlapGroup *og = new OverlapGroup(); - og->maybe_add(rpc_exe); - overlap_groups.push_back(og); - return; - } - - std::vector tmp_new_ogs; - for (OverlapGroup *og : overlap_groups) { - // OverlapGroup og_copy = *og; - bool update = og->maybe_add(rpc_exe); - if (!update) { - tmp_new_ogs.push_back(new OverlapGroup()); - tmp_new_ogs.back()->maybe_add(rpc_exe); - for (RPCExecution *rpc_exe : og->rpc_executions) { tmp_new_ogs.back()->maybe_add(rpc_exe); } - } - tmp_new_ogs.push_back(og); - } - overlap_groups.clear(); - for (OverlapGroup *og : tmp_new_ogs) { overlap_groups.push_back(og); } -}; - -void GroupedRPCExecutions::resolve(RPCExecution *rpc_exe) {} - -void GroupedRPCExecutions::add(RPCExecution *rpc_exe) { group.add_to_groups(rpc_exe); }; - -void GroupedRPCExecutions::offload(std::string model_name) {}; - -uint64_t GroupedRPCExecutions::total_mem_cost() { - uint64_t max_mem = 0; - for (auto &og : group.overlap_groups) { - // double og_ma = (og -> mem_active/(1024*1024))/1024.0; - // double og_ms = (og -> mem_static/(1024*1024))/1024.0; - // std::cout << "og size " << og -> rpc_executions.size() - // << " mem active " << og_ma << " GB" - // << " mem static " << og_ms << " GB" << std::endl; - if (og->mem_active + 
og->mem_static > max_mem) { max_mem = og->mem_active + og->mem_static; } - } - return max_mem; -}; - -RPCInstance::RPCInstance(RPC *rpc_ptr, int id, std::string name) - : rpc_ptr(rpc_ptr), id(id), name(name) {}; - -void RPCInstance::remove_parent(RPCInstance *parent) { - auto it = std::find(parents.begin(), parents.end(), parent); - if (it != parents.end()) { parents.erase(it); } -}; - -void RPCInstance::remove_child(RPCInstance *child) { - auto it = std::find(children.begin(), children.end(), child); - if (it != children.end()) { children.erase(it); } -}; - -void RPCInstance::add_parent(RPCInstance *parent) { parents.push_back(parent); }; - -void RPCInstance::add_child(RPCInstance *child) { children.push_back(child); }; - -void RPCInstance::remove_tmp_child(RPCInstance *child) { - auto it = std::find(tmp_children.begin(), tmp_children.end(), child); - if (it != tmp_children.end()) { tmp_children.erase(it); } -}; - -void RPCInstance::remove_tmp_parent(RPCInstance *parent) { - auto it = std::find(tmp_parents.begin(), tmp_parents.end(), parent); - if (it != tmp_parents.end()) { tmp_parents.erase(it); } -}; - -void RPCInstance::add_tmp_parent(RPCInstance *parent) { tmp_parents.push_back(parent); }; - -void RPCInstance::add_tmp_child(RPCInstance *child) { tmp_children.push_back(child); }; - -uint64_t parameter_sync_cost(uint64_t model_size, RPCExecution *src, RPCExecution *dst, - std::unordered_map &cost_table) { - // 7b size 13738442752 Bytes - // double size_multiplier = double(model_size) / 13738442752.0; - std::string model_key = std::to_string(model_size); - std::string src_key = src->model_parallel_strategy.to_key(); - std::string dst_key = dst->model_parallel_strategy.to_key(); - if (src_key == dst_key) return 0; - std::string key = model_key + "," + src_key + "," + dst_key; - if (cost_table.find(key) == cost_table.end()) { - // std::cout << "key " << key << " not found" << std::endl; - return 0; - } - return cost_table[key]; -} - -void RPCInstance::resolve_parameter_sync(std::vector tmp_graph, - std::unordered_map &cost_table) { - // add parameter synchronization edges - if (!param_sync) return; - - // dst to train - uint64_t from_cost = - parameter_sync_cost(param_sync_size, param_sync_rpc_exe_ptr, rpc_exe_ptr, cost_table); - uint64_t to_cost = - parameter_sync_cost(param_sync_size, rpc_exe_ptr, param_sync_rpc_exe_ptr, cost_table); - // if (param_sync_cost > 0) - // std::cout << "Param sync cost " << param_sync_cost << " from " - // << param_sync_rpc_exe_ptr -> rpc_ptr -> rpc_name << " to " - // << rpc_exe_ptr -> rpc_ptr -> rpc_name << std::endl; - - // add param sync from src to dst - RPCExecution *from_src_exe = - new RPCExecution(rpc_ptr, param_sync_rpc_exe_ptr->device_mesh, - param_sync_rpc_exe_ptr->model_parallel_strategy, from_cost, 0, 0); - RPCInstance *from_src = new RPCInstance(rpc_ptr, id, name + ":from_src"); - from_src->rpc_exe_ptr = from_src_exe; - - RPCExecution *from_dst_exe = new RPCExecution( - rpc_ptr, rpc_exe_ptr->device_mesh, rpc_exe_ptr->model_parallel_strategy, from_cost, 0, 0); - RPCInstance *from_dst = new RPCInstance(rpc_ptr, id, name + ":from_dst"); - from_dst->rpc_exe_ptr = from_dst_exe; - - // bool overlap = src -> rpc_exe_ptr -> device_mesh.overlap( - // dst -> rpc_exe_ptr -> device_mesh); - - for (RPCInstance *parent : parents) { - parent->remove_tmp_child(this); - from_src->add_tmp_parent(parent); - from_dst->add_tmp_parent(parent); - parent->add_tmp_child(from_src); - parent->add_tmp_child(from_dst); - } - this->tmp_parents.clear(); - - 
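As `parameter_sync_cost` above shows, the cost table is a flat string-keyed map: the key concatenates the parameter size in bytes with the source and destination `ModelParallelStrategy::to_key()` strings ("pp,mp,dp"); identical strategies cost zero, and a missing key silently falls back to zero. A hedged Python rendering of that lookup (illustrative, not part of the diff):

```python
# Illustrative sketch of the cost-table lookup in parameter_sync_cost above.
# Key layout: "<param_size_bytes>,<src pp,mp,dp>,<dst pp,mp,dp>".
def parameter_sync_cost(model_size: int, src_key: str, dst_key: str,
                        cost_table: dict) -> int:
    if src_key == dst_key:       # same parallel strategy: no resharding needed
        return 0
    key = f"{model_size},{src_key},{dst_key}"
    # A missing entry falls back to zero cost, mirroring the C++ behaviour.
    return cost_table.get(key, 0)

# Example: 7B model (13738442752 bytes), pp=1,mp=4,dp=2 -> pp=2,mp=2,dp=2.
cost_table = {"13738442752,1,4,2,2,2,2": 1_250_000}
print(parameter_sync_cost(13738442752, "1,4,2", "2,2,2", cost_table))
```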
from_src->add_tmp_child(this); - from_dst->add_tmp_child(this); - this->add_tmp_parent(from_src); - this->add_tmp_parent(from_dst); - - tmp_graph.push_back(from_src); - tmp_graph.push_back(from_dst); - - tmp_ris.push_back(from_src); - tmp_ris.push_back(from_dst); - tmp_exes.push_back(from_src_exe); - tmp_exes.push_back(from_dst_exe); - - // add param sync from dst to src - RPCExecution *to_src_exe = new RPCExecution(rpc_ptr, rpc_exe_ptr->device_mesh, - rpc_exe_ptr->model_parallel_strategy, to_cost, 0, 0); - RPCInstance *to_src = new RPCInstance(rpc_ptr, id, name + ":to_src"); - to_src->rpc_exe_ptr = to_src_exe; - - RPCExecution *to_dst_exe = - new RPCExecution(rpc_ptr, param_sync_rpc_exe_ptr->device_mesh, - param_sync_rpc_exe_ptr->model_parallel_strategy, to_cost, 0, 0); - RPCInstance *to_dst = new RPCInstance(rpc_ptr, id, name + ":to_dst"); - to_dst->rpc_exe_ptr = to_dst_exe; - - for (RPCInstance *child : children) { - child->remove_tmp_parent(this); - to_src->add_tmp_child(child); - to_dst->add_tmp_child(child); - child->add_tmp_parent(to_src); - child->add_tmp_parent(to_dst); - } - this->tmp_children.clear(); - - to_src->add_tmp_parent(this); - to_dst->add_tmp_parent(this); - this->add_tmp_child(to_src); - this->add_tmp_child(to_dst); - - tmp_graph.push_back(to_src); - tmp_graph.push_back(to_dst); - - tmp_ris.push_back(to_src); - tmp_ris.push_back(to_dst); - tmp_exes.push_back(to_src_exe); - tmp_exes.push_back(to_dst_exe); -} - -CommStats::CommStats(uint64_t local_send, uint64_t local_recv, uint64_t remote_send, - uint64_t remote_recv, uint64_t offload_store, uint64_t offload_load) - : local_send(local_send), - local_recv(local_recv), - remote_send(remote_send), - remote_recv(remote_recv), - offload_store(offload_store), - offload_load(offload_load) {}; - -// ModelConfig::ModelConfig(std::string model_name, -// uint64_t param_size_bytes) -// : model_name(model_name), param_size_bytes(param_size_bytes) { -// }; - -std::string RPCExecution::to_string() { - return rpc_ptr->rpc_name + " on " + device_mesh.name - + ", parallel strategy: " + model_parallel_strategy.to_string(); -}; \ No newline at end of file diff --git a/csrc/search/rpc.hpp b/csrc/search/rpc.hpp deleted file mode 100644 index f9b9792..0000000 --- a/csrc/search/rpc.hpp +++ /dev/null @@ -1,121 +0,0 @@ -#ifndef RPC_HPP -#define RPC_HPP - -#include -#include -#include -#include - -class CommStats { - public: - uint64_t local_send, local_recv, remote_send, remote_recv, offload_store, offload_load; - - CommStats(uint64_t local_send, uint64_t local_recv, uint64_t remote_send, uint64_t remote_recv, - uint64_t offload_store, uint64_t offload_load); -}; - -class RPC { - public: - std::string model_name; - std::string rpc_name; - // interface_type: 0=generate, 1=train_step, 2=inference - std::string interface_type; - - RPC(std::string model_name, std::string rpc_name, std::string interface_type); -}; - -class RPCExecution { - public: - RPC *rpc_ptr; - DeviceMesh &device_mesh; - ModelParallelStrategy &model_parallel_strategy; - uint64_t time_cost, mem, static_mem; - - RPCExecution(RPC *rpc_ptr, DeviceMesh &device_mesh, - ModelParallelStrategy &model_parallel_strategy, uint64_t time_cost, uint64_t mem, - uint64_t static_mem); - - std::string to_string(); -}; - -class OverlapGroup { - public: - std::unordered_set rpc_executions; - std::unordered_set device_meshes; - uint64_t mem_static; - uint64_t mem_active; - - bool maybe_add(RPCExecution *rpc_exe); -}; - -class DeviceMeshGroup { - public: - // std::string device_mesh_name; - 
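A note on the memory model encoded by `OverlapGroup` and `GroupedRPCExecutions` (implemented in rpc.cpp above): within one overlap group, static memory accumulates across co-located RPC executions while active memory takes the maximum, and `total_mem_cost()` reports the worst group. An illustrative Python rendering of that accounting (not part of the diff):

```python
# Illustrative sketch of the OverlapGroup memory accounting from rpc.cpp above.
def overlap_group_mem(execs):
    """execs: iterable of (mem, static_mem) pairs for co-located RPC executions."""
    mem_static = sum(static for _, static in execs)                        # static memory accumulates
    mem_active = max((mem - static for mem, static in execs), default=0)   # peak activation memory
    return mem_static + mem_active

def total_mem_cost(groups):
    # The search only cares about the worst group, mirroring
    # GroupedRPCExecutions::total_mem_cost above.
    return max((overlap_group_mem(g) for g in groups), default=0)
```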
std::vector overlap_groups; - - void add_to_groups(RPCExecution *rpc_exe); -}; - -class GroupedRPCExecutions { - public: - // std::unordered_map dn_to_group; - DeviceMeshGroup group; - - void add(RPCExecution *rpc_exe); - void resolve(RPCExecution *rpc_exe); - void offload(std::string model_name); - uint64_t total_mem_cost(); -}; - -class RPCInstance { - public: - RPC *rpc_ptr; - int id; - std::string name; - std::vector children; - std::vector parents; - std::vector tmp_children; - std::vector tmp_parents; - std::vector tmp_ris; // pointers to tmp rpc instances - std::vector tmp_exes; // pointers to tmp rpc executions - - RPCExecution *rpc_exe_ptr = nullptr; - RPCExecution *param_sync_rpc_exe_ptr = nullptr; - bool param_sync = false; - uint64_t param_sync_size = 0; - bool offload = false; - uint64_t offload_size = 0; - - RPCInstance(RPC *rpc_ptr, int id, std::string name); - - uint64_t ready_time = 0, start_time = 0, end_time = 0; - - void remove_parent(RPCInstance *parent); - void remove_child(RPCInstance *child); - void add_parent(RPCInstance *parent); - void add_child(RPCInstance *child); - - void add_tmp_parent(RPCInstance *parent); - void add_tmp_child(RPCInstance *child); - void remove_tmp_parent(RPCInstance *parent); - void remove_tmp_child(RPCInstance *child); - - void resolve_parameter_sync(std::vector tmp_graph, - std::unordered_map &cost_table); - // void resolve_offload(std::vector tmp_graph, - // CommStats& comm_stats); -}; - -uint64_t parameter_sync_cost(uint64_t param_size_bytes, RPCExecution *src, RPCExecution *dst, - std::unordered_map &cost_table); - -uint64_t remote_param_sync_size(uint64_t size, RPCExecution *src, RPCExecution *dst); - -// class ModelConfig { -// std::string model_name; -// uint64_t param_size_bytes; - -// ModelConfig(std::string model_name, uint64_t param_size_bytes); -// }; - -#endif \ No newline at end of file diff --git a/csrc/search/search.cpp b/csrc/search/search.cpp deleted file mode 100644 index 61efe5f..0000000 --- a/csrc/search/search.cpp +++ /dev/null @@ -1,827 +0,0 @@ -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include - -namespace py = pybind11; - -uint64_t VALID_COUNT_CAP = 25000000; // 25000000 -size_t MAX_EXE_PER_RPC = 1000; -// std::unordered_map device_mesh_map; - -void print_int_vector(std::vector &vec) { - std::cout << "["; - for (int i = 0; i < static_cast(vec.size()); i++) { std::cout << vec[i] << ", "; } - std::cout << "] "; -}; - -std::size_t vector_hash(const std::vector &vec) { - std::size_t seed = vec.size(); - for (const auto &i : vec) { - seed ^= std::hash{}(i) + 0x9e3779b9 + (seed << 6) + (seed >> 2); - } - return seed; -} - -long check_memory_bytes() { - std::ifstream statm_file("/proc/self/statm"); - long memory_usage_bytes = -1; - if (!statm_file) { - std::cerr << "Failed to open /proc/self/statm\n"; - } else { - long size, resident, share, text, lib, data, dt; - statm_file >> size >> resident >> share >> text >> lib >> data >> dt; - - // size is in pages, to convert to bytes, multiply by the page size - long page_size = sysconf(_SC_PAGESIZE); - memory_usage_bytes = size * page_size; - } - return memory_usage_bytes; -}; - -void make_rpc_exe_table(std::unordered_map> &rpc_exe_table, - std::vector &rpc_exes) { - for (auto &rpc_exe : rpc_exes) { - if (rpc_exe_table[rpc_exe->rpc_ptr->rpc_name].size() < MAX_EXE_PER_RPC) - rpc_exe_table[rpc_exe->rpc_ptr->rpc_name].push_back(rpc_exe); - } - - for (auto &x : 
rpc_exe_table) { - std::string rpc_name = x.first; - std::vector &rpc_exe_list = x.second; - // sort first - std::sort(rpc_exe_list.begin(), rpc_exe_list.end(), - [](const RPCExecution *a, const RPCExecution *b) { - if (a->time_cost == b->time_cost) - return a->device_mesh.name < b->device_mesh.name; - else - return a->time_cost < b->time_cost; - }); - } -} - -void make_sorted_rpc_names( - std::vector &sorted_rpc_names, - std::unordered_map> &rpc_exe_table) { - std::vector> average_time_cost; - for (auto &x : rpc_exe_table) { - std::string rpc_name = x.first; - std::vector &rpc_exe_list = x.second; - uint64_t total_time_cost = 0; - int c = 0; - for (auto &rpc_exe : rpc_exe_list) { - total_time_cost += rpc_exe->time_cost; - c += 1; - if (c > 10) break; - } - average_time_cost.push_back(std::make_pair(rpc_name, total_time_cost / 10)); - } - - std::sort(average_time_cost.begin(), average_time_cost.end(), - [](const std::pair &a, const std::pair &b) { - return a.second > b.second; - }); - - for (auto &x : average_time_cost) { sorted_rpc_names.push_back(x.first); } -} - -void prepare(std::unordered_map> &rpc_exe_table, - std::unordered_map &rpc_table, - std::vector &sorted_rpc_names, - std::unordered_map> &ri_table, - std::unordered_map> &model_name_ri_table, - std::vector rpcs, std::vector rpc_exes, - std::vector graph) { - std::vector> average_time_cost; - for (auto &rpc : rpcs) { rpc_table[rpc->rpc_name] = rpc; } - - make_rpc_exe_table(rpc_exe_table, rpc_exes); - make_sorted_rpc_names(sorted_rpc_names, rpc_exe_table); - - for (auto &rpc_instance : graph) { - ri_table[rpc_instance->rpc_ptr->rpc_name].push_back(rpc_instance); - } - - for (auto &rpc_instance : graph) { - model_name_ri_table[rpc_instance->rpc_ptr->model_name].push_back(rpc_instance); - } -}; - -std::vector mcmc_search(std::vector rpcs, - std::vector rpc_exes, - std::vector graph, - std::unordered_map &cost_table, - std::unordered_map model_sizes, - double beta, double time_limit, - MinEndTimeQueue &top_k_queue) { - std::unordered_map> rpc_exe_table; - std::unordered_map rpc_table; - std::vector sorted_rpc_names; - std::unordered_map> ri_table; - std::unordered_map> model_name_ri_table; - std::chrono::duration time_limit_duration(time_limit); - std::vector time_cost_cache; - - prepare(rpc_exe_table, rpc_table, sorted_rpc_names, ri_table, model_name_ri_table, rpcs, rpc_exes, - graph); - - std::vector index; - std::vector min_index; - std::vector max_index; - uint64_t min_index_mem = 0; - uint64_t valid_count = 0; - uint64_t oom_count = 0; - int num_rpcs = static_cast(sorted_rpc_names.size()); - uint64_t min_time_cost = std::numeric_limits::max(); - uint64_t max_time_cost = 0; - double avg = 0; - - // [6, 2, 3, 2, 12, 15, ] - // [0, 0, 0, 7, 8, 13, ] - // index = { 23, 7, 30, 154, 173, 265 }; - // SimulateResult sr1 = simulate(graph, cost_table, model_sizes, - // rpc_table, rpc_exe_table, ri_table, - // model_name_ri_table, sorted_rpc_names, - // index); - // std::cout << "index 1 end time " << sr1.end_time << " mem cost " << sr1.mem_cost << std::endl; - // std::cout << "************************" << std::endl; - // index = { 1, 0, 0, 11, 20, 3 }; - // SimulateResult sr2 = simulate(graph, cost_table, model_sizes, - // rpc_table, rpc_exe_table, ri_table, - // model_name_ri_table, sorted_rpc_names, - // index); - // std::cout << "index 2 end time " << sr2.end_time << " mem cost " << sr2.mem_cost << std::endl; - // exit(0); - - // initial value - index.resize(sorted_rpc_names.size(), 0); - - // index 1 - SimulateResult first_sr = 
simulate(graph, cost_table, model_sizes, rpc_table, rpc_exe_table, - ri_table, model_name_ri_table, sorted_rpc_names, index); - SimulateResult final_sr = first_sr; - uint64_t current_cost = first_sr.end_time; - time_cost_cache.push_back(first_sr); - - if (first_sr.oom) - oom_count += 1; - else - top_k_queue.insert(first_sr); - // std::cout << "initial cost " << current_cost << " oom " - // << first_sr.oom << std::endl; - // index = {0, 1, 3, 34, 4, 10}; - // std::unordered_map time_cost_cache; - - auto start = std::chrono::high_resolution_clock::now(); - // bool outer_loop_break_flag = false; - while (valid_count < VALID_COUNT_CAP) { - // only change one model execution in each iteration - // std::vector sr_vector; - std::unordered_map> flatten_to_pair; - std::vector weight; - // double beta = 0.0075; - int max_step_range = 10000; - int current = 0; - - std::vector new_index(index); - for (int i = 0; i < num_rpcs; i++) { - std::string rpc_name = sorted_rpc_names[i]; - int c_i = index[i]; - int min_i = std::max(0, c_i - max_step_range); - int max_i = - std::min(static_cast(rpc_exe_table[rpc_name].size()), c_i + max_step_range + 1); - - for (int j = min_i; j < max_i; j++) { - if (j == c_i) continue; - // int tmp = new_index[i]; - // new_index[i] = j; - // SimulateResult sr = simulate(graph, cost_table, model_sizes, - // rpc_table, rpc_exe_table, ri_table, - // model_name_ri_table, sorted_rpc_names, - // new_index); - // sr_vector.push_back(sr); - // new_index[i] = tmp; - flatten_to_pair[current] = std::make_pair(i, j); - current++; - } - } - - // if (time_cost_cache.size() > 10000000) { - // time_cost_cache.clear(); - // } - - // std::cout << "sr vector size " << sr_vector.size() << std::endl; - // for (int i = 0; i < static_cast(sr_vector.size()); i++) { - // weight.push_back(std::exp(-beta * (sr_vector[i].end_time/100000))); - // } - // exit(0); - - std::random_device rd; - std::mt19937 gen(rd()); - // std::discrete_distribution d(weight.begin(), weight.end()); - std::uniform_int_distribution d(0, static_cast(flatten_to_pair.size() - 1)); - int selected = d(gen); - - // assign new index - int selected_i = flatten_to_pair[selected].first; - int selected_j = flatten_to_pair[selected].second; - new_index[selected_i] = selected_j; - - SimulateResult selected_sr = - simulate(graph, cost_table, model_sizes, rpc_table, rpc_exe_table, ri_table, - model_name_ri_table, sorted_rpc_names, new_index); - uint64_t selected_cost = selected_sr.end_time; - // if (selected_sr.oom) { - // std::cout << "oom max end time " << selected_cost << std::endl; - // } else { - // std::cout << "max end time " << selected_cost << std::endl; - // } - - if (selected_cost < std::numeric_limits::max()) { - bool accepted = true; - if (current_cost < selected_cost) { - double accept_prob = std::exp(-beta * ((selected_cost - current_cost) / 100000)); - // double accept_prob = ap * ap; - std::bernoulli_distribution accept_dist(accept_prob); - accepted = accept_dist(gen); - } - - if (accepted) { - // std::cout << "accepted" << std::endl; - index = new_index; - current_cost = selected_cost; - valid_count++; - if (!selected_sr.oom) { - // min_time_cost = std::min(min_time_cost, selected_cost); - // max_time_cost = std::max(max_time_cost, selected_cost); - avg = (selected_cost + avg * (valid_count - 1)) / valid_count; - if (min_time_cost > selected_cost) { - min_time_cost = selected_cost; - min_index = index; - min_index_mem = selected_sr.mem_cost; - // final_sr = selected_sr; - auto now = std::chrono::high_resolution_clock::now(); 
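The accept/reject step in the loop above is a standard Metropolis rule: a plan that is no worse is always taken, while a worse one is accepted with probability exp(-beta * delta / 1e5), so larger beta makes the walk greedier. A compact sketch of that rule (illustrative; the 1e5 scaling mirrors the constant used above):

```python
# Illustrative Metropolis acceptance rule used by mcmc_search above.
import math
import random

def accept(current_cost: int, proposed_cost: int, beta: float) -> bool:
    if proposed_cost <= current_cost:
        return True                                    # downhill moves are always accepted
    prob = math.exp(-beta * (proposed_cost - current_cost) / 100000)
    return random.random() < prob                      # uphill moves accepted stochastically
```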
- double diff = - std::chrono::duration_cast(now - start).count(); - selected_sr.used_time = diff; - time_cost_cache.push_back(selected_sr); - } - if (min_time_cost == selected_cost) { - if (min_index_mem > selected_sr.mem_cost) { - min_index = index; - min_index_mem = selected_sr.mem_cost; - // final_sr = selected_sr; - } - } - top_k_queue.insert(selected_sr); - if (max_time_cost < selected_cost) { - max_time_cost = selected_cost; - max_index = index; - } - } else { - oom_count += 1; - } - // if (min_time_cost <= 100000000) break; // DEBUG - - auto end = std::chrono::high_resolution_clock::now(); - std::chrono::duration diff = end - start; - - if (valid_count % 1000 == 0) { - std::cout << " valid_count " << valid_count << " oom count " << oom_count << " time " - << diff.count() << " min time cost " << min_time_cost << " min index mem cost " - << min_index_mem << " max time cost " << max_time_cost << " current cost " - << current_cost << " avg " << avg << " mem usage " << check_memory_bytes() - << std::endl; - - std::cout << "min index : "; - print_int_vector(min_index); - std::cout << std::endl; - std::cout << "max index : "; - print_int_vector(max_index); - std::cout << std::endl; - std::cout << "current index : "; - print_int_vector(index); - std::cout << std::endl; - } - if (diff > time_limit_duration) break; - } - } - } - - // std::cout << "break" << std::endl; - - // int rpc_index = 0; - // for (int index : final_sr.index) { - // // std::cout << "index " << index << " rpc_index " << rpc_index << std::endl; - // // std::cout << "rpc name " << sorted_rpc_names[rpc_index] << std::endl; - // RPCExecution* re_ptr = rpc_exe_table[sorted_rpc_names[rpc_index]][index]; - // final_sr.rpc_exe_list.push_back(re_ptr); - // rpc_index ++; - // } - // std::cout << "final_sr.rpc_exe_list size " << final_sr.rpc_exe_list.size() << std::endl; - - auto end = std::chrono::high_resolution_clock::now(); - std::chrono::duration diff = end - start; - std::cout << "MCMC Search finished, beta " << beta << " time limit " << time_limit << " seconds" - << " valid_count " << valid_count << " oom_count " << oom_count << " time " - << diff.count() << " min time cost " << min_time_cost << " max time cost " - << max_time_cost << " avg " << avg << std::endl; - std::cout << "RPCExecutions: " << std::endl; - for (auto &re_ptr : final_sr.rpc_exe_list) { std::cout << re_ptr->to_string() << std::endl; } - return time_cost_cache; - // return final_sr; -} - -void multi_mcmc_search(std::vector rpcs, std::vector rpc_exes, - std::vector graph, - std::unordered_map &cost_table, - std::unordered_map model_sizes, double beta_min, - double beta_max, double beta_step, MinEndTimeQueue &res_queue, - int top_k = 10, double time_limit = 60.0, - int repeat = 1 // Remove the trailing comma here -) { - SimulateResult sr; - std::vector queues; - // std::vector ts; - for (int i = 0; i < repeat; i++) { - for (double beta = beta_min; beta < beta_max; beta += beta_step) { - MinEndTimeQueue *q = new MinEndTimeQueue(10); - queues.push_back(q); - - // Create a new thread to run mcmc_search - std::vector r = - mcmc_search(rpcs, rpc_exes, graph, cost_table, model_sizes, beta, time_limit, *q); - } - } - - for (auto &q : queues) { - mergeMinEndTimeQueues(res_queue, *q); - // delete q; - } - - // std::cout << "Best result: " << sr.end_time << std::endl; - // for (auto& re_ptr : sr.rpc_exe_list) { - // std::cout << re_ptr -> to_string() << std::endl; - // } - // std::cout << "Index: "; - // for (int i : sr.index) { - // std::cout << i << " "; - // } - // 
std::cout << std::endl; - // std::cout << "Time cost: " << sr.end_time - // << ", mem cost: " << sr.mem_cost << std::endl; - // std::cout << "sr.rpc_exe_list size " << sr.rpc_exe_list.size() << std::endl; - // return sr; -} - -void input_check(std::vector rpcs, std::vector rpc_exes, - std::vector graph, CommStats &comm_stats) { - std::cout << "rpcs: " << rpcs.size() << std::endl; - std::cout << "rpc_exes: " << rpc_exes.size() << std::endl; - std::cout << "graph: " << graph.size() << std::endl; - - for (auto &rpc_instance : graph) { - std::cout << "==================" << std::endl; - std::cout << "rpc instance: " << rpc_instance->name << std::endl; - std::cout << "parents" << std::endl; - for (auto &parent : rpc_instance->parents) { std::cout << parent->name << " "; } - std::cout << std::endl; - std::cout << "children" << std::endl; - for (auto &child : rpc_instance->children) { std::cout << child->name << " "; } - std::cout << std::endl; - // std::cout << "parents: " << rpc_instance -> parents.size() - // << " children: " << rpc_instance -> children.size() << std::endl; - } - - std::cout << "comm_stats: " << std::endl; - std::cout << "local_send: " << comm_stats.local_send << std::endl; - std::cout << "local_recv: " << comm_stats.local_recv << std::endl; - std::cout << "remote_send: " << comm_stats.remote_send << std::endl; - std::cout << "remote_recv: " << comm_stats.remote_recv << std::endl; - std::cout << "offload_store: " << comm_stats.offload_store << std::endl; - std::cout << "offload_load: " << comm_stats.offload_load << std::endl; -} - -RPC *cast_rpc(py::handle rpc_py) { - return new RPC(py::str(rpc_py.attr("model_name")).cast(), - rpc_py.attr("name").cast(), - py::str(rpc_py.attr("interface_type")).cast()); -} - -DeviceMesh *cast_device_mesh(py::handle device_mesh_py, - std::unordered_map &device_mesh_map) { - std::string name = device_mesh_py.attr("name").cast(); - if (device_mesh_map.find(name) == device_mesh_map.end()) { - py::array_t mapping_array = - device_mesh_py.attr("mapping").cast>(); - py::buffer_info buf_info = mapping_array.request(); - - auto rows = buf_info.shape[0]; - auto cols = buf_info.shape[1]; - - std::vector> mapping(rows, std::vector(cols)); - - // Get a pointer to the data - int32_t *data = static_cast(buf_info.ptr); - - // Fill the 2D vector with data from the numpy array - for (size_t i = 0; i < static_cast(rows); ++i) { - for (size_t j = 0; j < static_cast(cols); ++j) { mapping[i][j] = data[i * cols + j]; } - } - - DeviceMesh *device_mesh = - new DeviceMesh(device_mesh_py.attr("n_nodes").cast(), - device_mesh_py.attr("n_gpus_per_node").cast(), mapping, - device_mesh_py.attr("global_mesh_name").cast(), - device_mesh_py.attr("name").cast()); - - device_mesh_map[name] = device_mesh; - return device_mesh; - } else { - return device_mesh_map[name]; - } -} - -ModelParallelStrategy *cast_model_parallel_strategy(py::handle model_parallel_strategy_py) { - return new ModelParallelStrategy( - model_parallel_strategy_py.attr("pipeline_parallel_size").cast(), - model_parallel_strategy_py.attr("data_parallel_size").cast(), - model_parallel_strategy_py.attr("model_parallel_size").cast()); -} - -RPCExecution *cast_rpc_execution(py::handle rpc_exe_py, std::unordered_map &tmp, - std::unordered_map &device_mesh_map) { - DeviceMesh *device_mesh = cast_device_mesh(rpc_exe_py.attr("device_mesh"), device_mesh_map); - ModelParallelStrategy *model_parallel_strategy = - cast_model_parallel_strategy(rpc_exe_py.attr("parallel_strategy")); - // RPC* rpc = 
cast_rpc(rpc_exe_py.attr("rpc")); - - return new RPCExecution( - tmp[rpc_exe_py.attr("rpc").attr("name").cast()], *device_mesh, - *model_parallel_strategy, rpc_exe_py.attr("time_cost").cast(), - rpc_exe_py.attr("mem").cast(), rpc_exe_py.attr("static_mem").cast()); -} - -RPCInstance *cast_rpc_instance_wo_dependency(py::handle rpc_instance_py, - std::unordered_map &tmp) { - return new RPCInstance(tmp[rpc_instance_py.attr("rpc").attr("name").cast()], - rpc_instance_py.attr("iteration_id").cast(), - rpc_instance_py.attr("name").cast()); -} - -void cast_rpc_instance_dependency(py::handle rpc_instance_py, RPCInstance *ri_ptr, - std::unordered_map &tmp_graph) { - for (py::handle parent_py : rpc_instance_py.attr("parents")) - ri_ptr->parents.push_back(tmp_graph[parent_py.attr("name").cast()]); - for (py::handle child_py : rpc_instance_py.attr("children")) - ri_ptr->children.push_back(tmp_graph[child_py.attr("name").cast()]); -} - -py::list py_single_mcmc_search_time_profile(py::list rpcs_py, py::list rpc_exes_py, - py::list graph_py, py::dict cost_table_py, - py::dict model_sizes_py, py::object beta, - py::object time_limit) { - std::vector rpcs; - std::unordered_map tmp; - for (py::handle rpc_py : rpcs_py) { - RPC *rpc_ptr = cast_rpc(rpc_py); - rpcs.push_back(rpc_ptr); - tmp[rpc_ptr->rpc_name] = rpc_ptr; - } - - std::vector rpc_exes; - std::unordered_map tmp_device_mesh; - for (py::handle rpc_exe_py : rpc_exes_py) { - RPCExecution *rpc_exe_ptr = cast_rpc_execution(rpc_exe_py, tmp, tmp_device_mesh); - rpc_exes.push_back(rpc_exe_ptr); - } - - std::vector graph; - std::unordered_map tmp_graph; - for (py::handle ri_py : graph_py) { - RPCInstance *ri_ptr = cast_rpc_instance_wo_dependency(ri_py, tmp); - // std::cout << "cast " << ri_ptr -> name << std::endl; - tmp_graph[ri_ptr->name] = ri_ptr; - } - // build dependecny - for (py::handle ri_py : graph_py) { - std::string ri_name = ri_py.attr("name").cast(); - cast_rpc_instance_dependency(ri_py, tmp_graph[ri_name], tmp_graph); - graph.push_back(tmp_graph[ri_name]); - } - - std::unordered_map cost_table = - cost_table_py.cast>(); - - std::unordered_map model_sizes = - model_sizes_py.cast>(); - MinEndTimeQueue res_queue(10); - std::vector rlist = - mcmc_search(rpcs, rpc_exes, graph, cost_table, model_sizes, beta.cast(), - time_limit.cast(), res_queue); - - std::unordered_map> rpc_exe_table; - std::vector sorted_rpc_names; - make_rpc_exe_table(rpc_exe_table, rpc_exes); - make_sorted_rpc_names(sorted_rpc_names, rpc_exe_table); - py::list result; - std::cout << "rlist.size " << rlist.size() << std::endl; - for (auto &r : rlist) { - // SimulateResult r = res_queue.getQueue().top(); - // res_queue.getQueue().pop(); - - std::cout << "End time: " << r.end_time << std::endl; - for (auto &re_ptr : r.rpc_exe_list) { std::cout << re_ptr->to_string() << std::endl; } - std::cout << "Index: "; - for (int i : r.index) { std::cout << i << " "; } - std::cout << std::endl; - std::cout << "Time cost: " << r.end_time << ", mem cost: " << r.mem_cost << std::endl; - - int rpc_index = 0; - for (int index : r.index) { - // std::cout << "index " << index << " rpc_index " << rpc_index << std::endl; - // std::cout << "rpc name " << sorted_rpc_names[rpc_index] << std::endl; - RPCExecution *re_ptr = rpc_exe_table[sorted_rpc_names[rpc_index]][index]; - r.rpc_exe_list.push_back(re_ptr); - rpc_index++; - } - - py::dict rdict; - for (auto &re_ptr : r.rpc_exe_list) { - py::dict rpc_exe_info; - std::string rpc_name = re_ptr->rpc_ptr->rpc_name; - py::object rpc_name_obj = py::str(rpc_name); 
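For context on what this binding hands back to Python: each `SimulateResult` is flattened into a dict keyed by RPC name, whose values carry the chosen device mesh and the dp/mp/pp degrees, plus top-level `end_time`, `mem_cost`, and `used_time` entries. A hedged example of consuming that list (the module name `mdm_search` and the function name come from the PYBIND11_MODULE declaration below; the input objects are assumed to be prepared by the Python-side search code that fed this now-removed extension):

```python
# Illustrative consumer of the list returned by mdm_search.mcmc_search_time_profile.
# `rpcs`, `rpc_exes`, `graph`, `cost_table`, `model_sizes` are assumed to be built elsewhere.
import mdm_search

results = mdm_search.mcmc_search_time_profile(
    rpcs, rpc_exes, graph, cost_table, model_sizes, 0.0075, 60.0)
best = min(results, key=lambda r: r["end_time"])
for rpc_name, info in best.items():
    if rpc_name in ("end_time", "mem_cost", "used_time"):
        continue  # scalar summary entries, not per-RPC placements
    print(rpc_name, info["device_mesh"],
          info["num_pp"], info["num_mp"], info["num_dp"])
```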
- // rpc_exe_info.append(re_ptr -> device_mesh.device_mesh_name); - // rpc_exe_info.append(re_ptr -> model_parallel_strategy.num_dp); - // rpc_exe_info.append(re_ptr -> model_parallel_strategy.num_mp); - // rpc_exe_info.append(re_ptr -> model_parallel_strategy.num_pp); - rpc_exe_info["device_mesh"] = re_ptr->device_mesh.name; - rpc_exe_info["num_dp"] = re_ptr->model_parallel_strategy.num_dp; - rpc_exe_info["num_mp"] = re_ptr->model_parallel_strategy.num_mp; - rpc_exe_info["num_pp"] = re_ptr->model_parallel_strategy.num_pp; - rdict[rpc_name_obj] = rpc_exe_info; - // std::cout << "append key " << rpc_name_obj << std::endl; - } - rdict["end_time"] = r.end_time; - rdict["mem_cost"] = r.mem_cost; - rdict["used_time"] = r.used_time; - result.append(rdict); - } - return result; -}; - -py::list py_multi_mcmc_search(py::list rpcs_py, py::list rpc_exes_py, py::list graph_py, - py::dict cost_table_py, py::dict model_sizes_py, - py::object beta_min_py, py::object beta_max_py, - py::object beta_step_py, py::object time_limit_py, - py::object repeat) { - std::vector rpcs; - std::unordered_map tmp; - for (py::handle rpc_py : rpcs_py) { - RPC *rpc_ptr = cast_rpc(rpc_py); - rpcs.push_back(rpc_ptr); - tmp[rpc_ptr->rpc_name] = rpc_ptr; - } - - std::vector rpc_exes; - std::unordered_map tmp_device_mesh; - for (py::handle rpc_exe_py : rpc_exes_py) { - RPCExecution *rpc_exe_ptr = cast_rpc_execution(rpc_exe_py, tmp, tmp_device_mesh); - rpc_exes.push_back(rpc_exe_ptr); - } - - std::vector graph; - std::unordered_map tmp_graph; - for (py::handle ri_py : graph_py) { - RPCInstance *ri_ptr = cast_rpc_instance_wo_dependency(ri_py, tmp); - // std::cout << "cast " << ri_ptr -> name << std::endl; - tmp_graph[ri_ptr->name] = ri_ptr; - } - // build dependecny - for (py::handle ri_py : graph_py) { - std::string ri_name = ri_py.attr("name").cast(); - cast_rpc_instance_dependency(ri_py, tmp_graph[ri_name], tmp_graph); - graph.push_back(tmp_graph[ri_name]); - } - - std::unordered_map cost_table = - cost_table_py.cast>(); - - std::unordered_map model_sizes = - model_sizes_py.cast>(); - - double beta_min = beta_min_py.cast(); - double beta_max = beta_max_py.cast(); - double beta_step = beta_step_py.cast(); - double time_limit = time_limit_py.cast(); - int rp = repeat.cast(); - - MinEndTimeQueue res_queue(10); - multi_mcmc_search(rpcs, rpc_exes, graph, cost_table, model_sizes, beta_min, beta_max, beta_step, - res_queue, 10, time_limit, rp); - - std::unordered_map> rpc_exe_table; - std::vector sorted_rpc_names; - make_rpc_exe_table(rpc_exe_table, rpc_exes); - make_sorted_rpc_names(sorted_rpc_names, rpc_exe_table); - - // std::cout << "r.rpc_exe_list size " << r.rpc_exe_list.size() << std::endl; - // for (int rpc_index = 0; rpc_index < 6; rpc_index++) { - // std::cout << "rpc name " << sorted_rpc_names[rpc_index] << std::endl; - // int index = 0; - // for (auto& re_ptr : rpc_exe_table[sorted_rpc_names[rpc_index]]) { - // std::cout << "index " << index << " " << re_ptr -> to_string() << std::endl; - // index ++; - // } - // } - - py::list result; - std::cout << "res_queue.getQueue().size " << res_queue.getQueue().size() << std::endl; - while (!res_queue.getQueue().empty()) { - SimulateResult r = res_queue.getQueue().top(); - res_queue.getQueue().pop(); - - std::cout << "End time: " << r.end_time << std::endl; - for (auto &re_ptr : r.rpc_exe_list) { std::cout << re_ptr->to_string() << std::endl; } - std::cout << "Index: "; - for (int i : r.index) { std::cout << i << " "; } - std::cout << std::endl; - std::cout << "Time cost: " << 
r.end_time << ", mem cost: " << r.mem_cost << std::endl; - - int rpc_index = 0; - for (int index : r.index) { - // std::cout << "index " << index << " rpc_index " << rpc_index << std::endl; - // std::cout << "rpc name " << sorted_rpc_names[rpc_index] << std::endl; - RPCExecution *re_ptr = rpc_exe_table[sorted_rpc_names[rpc_index]][index]; - r.rpc_exe_list.push_back(re_ptr); - rpc_index++; - } - - py::dict rdict; - for (auto &re_ptr : r.rpc_exe_list) { - py::dict rpc_exe_info; - std::string rpc_name = re_ptr->rpc_ptr->rpc_name; - py::object rpc_name_obj = py::str(rpc_name); - // convert device mesh mapping into py::array_t - std::vector> mapping = re_ptr->device_mesh.mapping; - int rows = mapping.size(); - int cols = mapping[0].size(); - - py::array_t numpy_array({rows, cols}); - for (int i = 0; i < rows; ++i) { - for (int j = 0; j < cols; ++j) { - *numpy_array.mutable_data(i, j) = mapping[i][j]; - // std::cout << i << j << mapping[i][j] << std::endl; - } - } - // store in py::dict - rpc_exe_info["device_mesh_mapping"] = numpy_array; - rpc_exe_info["device_mesh_name"] = re_ptr->device_mesh.name; - rpc_exe_info["num_dp"] = re_ptr->model_parallel_strategy.num_dp; - rpc_exe_info["num_mp"] = re_ptr->model_parallel_strategy.num_mp; - rpc_exe_info["num_pp"] = re_ptr->model_parallel_strategy.num_pp; - rdict[rpc_name_obj] = rpc_exe_info; - // std::cout << "append key " << rpc_name_obj << std::endl; - } - rdict["end_time"] = r.end_time; - rdict["mem_cost"] = r.mem_cost; - result.append(rdict); - } - - return result; -}; - -PYBIND11_MODULE(mdm_search, m) { - m.doc() = "model device mapping search module"; - - // for debug - // m.def("mcmc_search", [](py::list rpcs_py, py::list rpc_exes_py, - // py::list graph_py, py::object comm_stats_py, - // py::dict model_sizes_py, py::object beta_py, - // py::object time_limit_py) { - // std::vector rpcs; - // std::unordered_map tmp; - // for (py::handle rpc_py : rpcs_py) { - // RPC* rpc_ptr = cast_rpc(rpc_py); - // rpcs.push_back(rpc_ptr); - // tmp[rpc_ptr -> rpc_name] = rpc_ptr; - // } - - // std::vector rpc_exes; - // for (py::handle rpc_exe_py : rpc_exes_py) { - // RPCExecution* rpc_exe_ptr = cast_rpc_execution(rpc_exe_py, tmp); - // rpc_exes.push_back(rpc_exe_ptr); - // } - - // std::vector graph; - // std::unordered_map tmp_graph; - // for (py::handle ri_py : graph_py) { - // RPCInstance* ri_ptr = cast_rpc_instance_wo_dependency(ri_py, tmp); - // std::cout << "cast " << ri_ptr -> name << std::endl; - // tmp_graph[ri_ptr -> name] = ri_ptr; - // } - // // build dependecny - // for (py::handle ri_py : graph_py) { - // std::string ri_name = ri_py.attr("name").cast(); - // cast_rpc_instance_dependency(ri_py, tmp_graph[ri_name], tmp_graph); - // graph.push_back(tmp_graph[ri_name]); - // } - - // CommStats comm_stats( - // comm_stats_py.attr("local_send").cast(), - // comm_stats_py.attr("local_recv").cast(), - // comm_stats_py.attr("remote_send").cast(), - // comm_stats_py.attr("remote_recv").cast(), - // comm_stats_py.attr("offload_load").cast(), - // comm_stats_py.attr("offload_store").cast() - // ); - - // std::unordered_map model_sizes - // = model_sizes_py.cast>(); - - // double beta = beta_py.cast(); - // double time_limit = time_limit_py.cast(); - - // mcmc_search(rpcs, rpc_exes, graph, - // comm_stats, model_sizes, beta, - // time_limit); - // }); - - m.def("input_check", - [](py::list rpcs_py, py::list rpc_exes_py, py::list graph_py, py::object comm_stats_py) { - std::vector rpcs; - std::unordered_map tmp; - for (py::handle rpc_py : rpcs_py) { - RPC 
*rpc_ptr = cast_rpc(rpc_py); - rpcs.push_back(rpc_ptr); - tmp[rpc_ptr->rpc_name] = rpc_ptr; - } - - std::vector rpc_exes; - std::unordered_map tmp_device_mesh; - for (py::handle rpc_exe_py : rpc_exes_py) { - RPCExecution *rpc_exe_ptr = cast_rpc_execution(rpc_exe_py, tmp, tmp_device_mesh); - rpc_exes.push_back(rpc_exe_ptr); - } - - std::vector graph; - std::unordered_map tmp_graph; - for (py::handle ri_py : graph_py) { - RPCInstance *ri_ptr = cast_rpc_instance_wo_dependency(ri_py, tmp); - std::cout << "cast " << ri_ptr->name << std::endl; - tmp_graph[ri_ptr->name] = ri_ptr; - } - // build dependecny - for (py::handle ri_py : graph_py) { - std::string ri_name = ri_py.attr("name").cast(); - cast_rpc_instance_dependency(ri_py, tmp_graph[ri_name], tmp_graph); - graph.push_back(tmp_graph[ri_name]); - } - - CommStats comm_stats(comm_stats_py.attr("local_send").cast(), - comm_stats_py.attr("local_recv").cast(), - comm_stats_py.attr("remote_send").cast(), - comm_stats_py.attr("remote_recv").cast(), - comm_stats_py.attr("offload_load").cast(), - comm_stats_py.attr("offload_store").cast()); - - input_check(rpcs, rpc_exes, graph, comm_stats); - }); - - // mcmc search to py result - m.def("multi_mcmc_search", &py_multi_mcmc_search); - - m.def("parameter_sync_cost", [](py::object rpcs_py, py::object param_size_bytes_py, - py::dict cost_table_py, py::object src_py, py::object dst_py) { - uint64_t param_size_bytes = param_size_bytes_py.cast(); - std::unordered_map cost_table = - cost_table_py.cast>(); - std::vector rpcs; - std::unordered_map tmp; - for (py::handle rpc_py : rpcs_py) { - RPC *rpc_ptr = cast_rpc(rpc_py); - rpcs.push_back(rpc_ptr); - tmp[rpc_ptr->rpc_name] = rpc_ptr; - } - - std::unordered_map tmp_device_mesh; - RPCExecution *src = cast_rpc_execution(src_py, tmp, tmp_device_mesh); - RPCExecution *dst = cast_rpc_execution(dst_py, tmp, tmp_device_mesh); - - return parameter_sync_cost(param_size_bytes, src, dst, cost_table); - }); - - m.def("mcmc_search_time_profile", &py_single_mcmc_search_time_profile); -}; diff --git a/csrc/search/simulate.cpp b/csrc/search/simulate.cpp deleted file mode 100644 index 223e33c..0000000 --- a/csrc/search/simulate.cpp +++ /dev/null @@ -1,244 +0,0 @@ -#include -#include -#include -#include -#include -#include -#include -#include -#include - -#define MAX(a, b) ((a) > (b) ? 
(a) : (b)) - -uint64_t SOFT_GPU_MEM_CAP = 85899345920; // 80G - -SimulateResult::SimulateResult() - : end_time(std::numeric_limits::max()), oom(true), mem_cost(0) {} - -SimulateResult::SimulateResult(uint64_t end_time, bool oom, uint64_t mem_cost, - std::vector &index) - : end_time(end_time), oom(oom), mem_cost(mem_cost), index(index) {} - -SimulateResult simulate( - std::vector &graph, std::unordered_map &cost_table, - std::unordered_map &model_sizes, - std::unordered_map &rpc_table, - std::unordered_map> &rpc_exe_table, - std::unordered_map> &ri_table, - std::unordered_map> &model_name_ri_table, - std::vector &sorted_rpc_names, std::vector &index) { - auto start = std::chrono::high_resolution_clock::now(); - GroupedRPCExecutions grouped_rpc_exe; - // std::unordered_map param_dst; // model_name -> rpc_exe_ptr - std::unordered_set offloaded; - uint64_t oom_penalty = 3; - int num_rpcs = static_cast(sorted_rpc_names.size()); - - for (int i = 0; i < num_rpcs; i++) { - std::string rpc_name = sorted_rpc_names[i]; - RPC *rpc = rpc_table[rpc_name]; - RPCExecution *rpc_exe = rpc_exe_table[rpc_name][index[i]]; - - for (auto &ri : ri_table[rpc_name]) { - ri->rpc_exe_ptr = rpc_exe; // assign rpc_exe to rpc_instance - } - - // dirty implementation, check whether instance is hooked with param syncs - for (auto &rpc_instance : model_name_ri_table[rpc->model_name]) { - if (rpc_instance->rpc_ptr->interface_type == "ModelInterfaceType.TRAIN_STEP" - && rpc->interface_type != "ModelInterfaceType.TRAIN_STEP") { - // param_dst[rpc -> model_name] = rpc_exe; - rpc_instance->param_sync = true; - rpc_instance->param_sync_size = model_sizes[rpc_instance->rpc_ptr->model_name]; - rpc_instance->param_sync_rpc_exe_ptr = rpc_exe; - } - } - grouped_rpc_exe.add(rpc_exe); - std::string model_name = rpc->model_name; - } - - // rpc_instances: list of rpc instances, graph - std::priority_queue, CompareReadyTime> ready_queue; - // std::vector executed; // for debug, remove later - std::unordered_map parent_executed; - std::unordered_set device_meshes; - - // for offload and parameter sync RPC instances - std::vector tmp_graph; - - // resolve parameter sync - for (RPCInstance *node : graph) { - tmp_graph.push_back(node); - node->tmp_children = node->children; - node->tmp_parents = node->parents; - // std::cout << "Resolve parameter sync: " << node -> name - // << " " << node -> param_sync << std::endl; - node->resolve_parameter_sync(tmp_graph, cost_table); - // node -> resolve_offload(tmp_graph, comm_stats); - } - - uint64_t max_end_time = 0; - for (RPCInstance *node : tmp_graph) { - // std::cout << "Node: " << node -> name << " parents: " - // << node -> parents.size() << std::endl; - if (node->parents.size() == 0) ready_queue.push(node); - - // init device meshes - RPCExecution *rpc_exe = node->rpc_exe_ptr; - device_meshes.insert(&rpc_exe->device_mesh); - } - - std::vector executed; - - // simulate - while (!ready_queue.empty()) { - RPCInstance *t = ready_queue.top(); - RPCExecution *rpc_exe = t->rpc_exe_ptr; - uint64_t exec_time = rpc_exe->time_cost; - DeviceMesh *device_mesh = &rpc_exe->device_mesh; - ready_queue.pop(); - - if (device_mesh->pre_task == nullptr) { - t->start_time = t->ready_time; - } else { - t->start_time = MAX(t->ready_time, device_mesh->pre_task->end_time); - } - t->end_time = t->start_time + exec_time; - max_end_time = MAX(t->end_time, max_end_time); - - for (DeviceMesh *mesh : device_meshes) { - if (device_mesh->overlap(*mesh)) { - if (mesh->pre_task == nullptr || mesh->pre_task->end_time <= 
t->end_time) { - mesh->pre_task = t; - } - // mesh -> pre_task = t; - } - } - executed.push_back(t); - - for (RPCInstance *child : t->tmp_children) { - child->ready_time = MAX(t->end_time, child->ready_time); - // std::cout << "parent: " << t -> name - // << " child: " << child -> name << std::endl; - parent_executed[child->name] += 1; - // child -> remove_parent(t); - if (child->tmp_parents.size() == parent_executed[child->name]) { - ready_queue.push(child); - // std::cout << "Ready: " << child -> name - // << " ready time " << child -> ready_time << std::endl; - } - } - // std::cout << "ready_queue size " << ready_queue.size() - // << " executed size " << executed.size() << std::endl; - } - - // 110045999 - // if (max_end_time < 100000000) { // DEBUG - // std::cout << "INDEX: ["; - // for (int i : index) { - // std::cout << i << ", "; - // } - // std::cout << "]" << std::endl; - - // for (auto& x : rpc_exe_table){ - // std::string rpc_name = x.first; - // std::vector& rpc_exe_list = x.second; - // int count = 0; - // for (RPCExecution* rpc_exe : rpc_exe_list) { - // std::cout << "RPC: " << rpc_name - // << " device mesh " << rpc_exe -> device_mesh.device_mesh_name - // << " time cost " << rpc_exe -> time_cost << std::endl; - // count ++; - // if (count > 10) break; - // } - // } - - // for (RPCInstance* ri : executed) { - // for (RPCInstance* parent : ri -> tmp_parents) { - // std::cout << "Parent: " << parent -> name << " of " << ri -> name - // << " start time " << parent -> start_time - // << " end time " << parent -> end_time << std::endl; - // } - - // std::cout << "Executed: " << ri -> name - // << " start time " << ri -> start_time - // << " end time " << ri -> end_time - // << " rpc name " << ri -> rpc_ptr -> rpc_name - // << " device mesh " - // << ri -> rpc_exe_ptr -> device_mesh.device_mesh_name - // << " rpc exe time cost " << ri -> rpc_exe_ptr -> time_cost - // << std::endl; - // } - // } - - // clear device mesh pre tasks - for (DeviceMesh *mesh : device_meshes) { mesh->pre_task = nullptr; } - // clear rpc instance times - for (RPCInstance *node : graph) { - node->tmp_children.clear(); - node->tmp_parents.clear(); - node->ready_time = 0; - node->start_time = 0; - node->end_time = 0; - tmp_graph.clear(); - - for (RPCInstance *ptr : node->tmp_ris) { delete ptr; } - node->tmp_ris.clear(); - - for (RPCExecution *ptr : node->tmp_exes) { delete ptr; } - node->tmp_exes.clear(); - } - - uint64_t current_mem = grouped_rpc_exe.total_mem_cost(); - if (current_mem > SOFT_GPU_MEM_CAP) { max_end_time *= oom_penalty; } - // std::cout << "Max end time: " << max_end_time - // << " executed size " << executed.size() << std::endl; - std::chrono::duration elapsed = std::chrono::high_resolution_clock::now() - start; - // std::cout << "Elapsed time (micro seconds): " - // << std::chrono::duration_cast(elapsed).count() - // << std::endl; - - for (OverlapGroup *ptr : grouped_rpc_exe.group.overlap_groups) { delete ptr; } - grouped_rpc_exe.group.overlap_groups.clear(); - - return SimulateResult(max_end_time, current_mem > SOFT_GPU_MEM_CAP, current_mem, index); -}; - -SimulateResult &SimulateResult::operator=(const SimulateResult &other) { - if (this != &other) { - end_time = other.end_time; - oom = other.oom; - mem_cost = other.mem_cost; - index = other.index; - rpc_exe_list = other.rpc_exe_list; - } - return *this; -}; - -bool isPresent(MinEndTimeQueue &q, SimulateResult element) { - std::priority_queue, CompareEndTime> pq = - q.getQueue(); - std::queue tmp; - while (!pq.empty()) { - if 
(pq.top().end_time == element.end_time) { return true; } - tmp.push(pq.top()); - pq.pop(); - } - while (!tmp.empty()) { - pq.push(tmp.front()); - tmp.pop(); - } - return false; -} - -void mergeMinEndTimeQueues(MinEndTimeQueue &target, MinEndTimeQueue &q1) { - // Get the underlying priority queues - std::priority_queue, CompareEndTime> pq1 = - q1.getQueue(); - - // Insert all elements from q1 into the merged queue - while (!pq1.empty()) { - if (!isPresent(target, pq1.top())) target.insert(pq1.top()); - pq1.pop(); - } -} \ No newline at end of file diff --git a/csrc/search/simulate.hpp b/csrc/search/simulate.hpp deleted file mode 100644 index 3a35c1e..0000000 --- a/csrc/search/simulate.hpp +++ /dev/null @@ -1,75 +0,0 @@ -#ifndef SIMULATE_HPP -#define SIMULATE_HPP - -#include -#include -#include -#include - -class SimulateResult { - public: - uint64_t end_time; - bool oom; - uint64_t mem_cost; - std::vector index; - std::vector rpc_exe_list; - double used_time = 0; - - SimulateResult(); - - SimulateResult(uint64_t end_time, bool oom, uint64_t mem_cost, std::vector &index); - - SimulateResult &operator=(const SimulateResult &other); -}; - -SimulateResult simulate( - std::vector &graph, std::unordered_map &cost_table, - std::unordered_map &model_sizes, - std::unordered_map &rpc_table, - std::unordered_map> &rpc_exe_table, - std::unordered_map> &ri_table, - std::unordered_map> &model_name_ri_table, - std::vector &sorted_rpc_names, std::vector &index); - -// Comparator for priority queue -struct CompareEndTime { - bool operator()(SimulateResult const &r1, SimulateResult const &r2) { - // We want largest end_time at the top of the queue, so we reverse the comparison - return r1.end_time < r2.end_time; - } -}; - -class MinEndTimeQueue { - public: - MinEndTimeQueue(int capacity) : k(capacity) {} - - void insert(SimulateResult r) { - if (queue.size() < k) { - // std::cout << "push " << "end_time: " << r.end_time << " qsize " << queue.size() << - // std::endl; - queue.push(r); - } else if (r.end_time < queue.top().end_time) { - // std::cout << "push " << "end_time: " << r.end_time << " qsize " << queue.size() << - // std::endl; - queue.pop(); - queue.push(r); - } - } - - std::priority_queue, CompareEndTime> &getQueue() { - return queue; - } - - private: - std::priority_queue, CompareEndTime> queue; - int k; -}; - -void mergeMinEndTimeQueues(MinEndTimeQueue &target, MinEndTimeQueue &q1); - -class CompareReadyTime { - public: - bool operator()(RPCInstance *r1, RPCInstance *r2) { return r1->ready_time > r2->ready_time; } -}; - -#endif // SIMULATE_HPP \ No newline at end of file diff --git a/functioncall/base/call.py b/functioncall/base/call.py index 3f4bb73..977de18 100644 --- a/functioncall/base/call.py +++ b/functioncall/base/call.py @@ -83,7 +83,7 @@ async def async_invoke_function( url: str, timeout: aiohttp.ClientTimeout, payload: Dict[str, Any] = None, - max_retries: int = 100, + max_retries: int = 2, initial_retry_interval: float = 0.5, max_retry_interval: float = 10.0, ): @@ -137,7 +137,7 @@ async def async_invoke_function( ) retries += 1 - if retries > max_retries: + if retries >= max_retries: return { "uid": payload.get("uid", ""), "success": False, @@ -189,12 +189,13 @@ async def batch_function_call_async(payload_list, url, timeout, concurrency=1500 data_list.append(data) elapsed_times.append(elapsed) - p50 = median(elapsed_times) - p90 = calculate_percentile(elapsed_times, 90) - p99 = calculate_percentile(elapsed_times, 99) - logger.info( - f"Longest functioncall took {max_elapsed:.4f} 
seconds, timeout: {timeout}, uid: {max_elapsed_uid}, Active connections: {len(connector._conns)}, p50: {p50}, p90: {p90}, p99: {p99}" - ) + if len(elapsed_times) > 0: + p50 = median(elapsed_times) + p90 = calculate_percentile(elapsed_times, 90) + p99 = calculate_percentile(elapsed_times, 99) + logger.info( + f"Longest functioncall took {max_elapsed:.4f} seconds, timeout: {timeout}, uid: {max_elapsed_uid}, Active connections: {len(connector._conns)}, p50: {p50}, p90: {p90}, p99: {p99}" + ) return data_list diff --git a/functioncall/code/function/testing_util.py b/functioncall/code/function/testing_util.py index a526ff3..c7c8985 100644 --- a/functioncall/code/function/testing_util.py +++ b/functioncall/code/function/testing_util.py @@ -1,6 +1,7 @@ import ast import faulthandler import json +import os import platform # to run the solution files we're using a timing based approach @@ -38,6 +39,14 @@ def truncatefn(s, length=300): return s[: length // 2] + "...(truncated) ..." + s[-length // 2 :] +def load_from_path(l): + outputs = [] + for x in l: + with open(x) as f: + outputs.append(f.read()) + return outputs + + class CODE_TYPE(Enum): call_based = 0 standard_input = 1 @@ -105,6 +114,13 @@ def run_test(sample, test=None, debug=False, timeout=6): which_type = CODE_TYPE.call_based if in_outs: + assert "inputs" in in_outs + assert "outputs" in in_outs + if os.path.isfile(in_outs["inputs"][0]): + assert os.path.isfile(in_outs["outputs"][0]) + in_outs["inputs"] = load_from_path(in_outs["inputs"]) + in_outs["outputs"] = load_from_path(in_outs["outputs"]) + if in_outs.get("fn_name", "") == "": which_type = CODE_TYPE.standard_input # Standard input method_name = None diff --git a/functioncall/code/local_verify.py b/functioncall/code/local_verify.py index 321edc2..be32608 100644 --- a/functioncall/code/local_verify.py +++ b/functioncall/code/local_verify.py @@ -59,9 +59,10 @@ def call_verify(problem, generation, debug, timeout=SINGLE_CASE_EXEC_TIMEOUT): shell=True, preexec_fn=os.setsid, stdout=subprocess.DEVNULL, + stderr=subprocess.DEVNULL, ) try: - pro.wait(600) + pro.wait(200) except Exception as e: pass try: diff --git a/functioncall/code/verify.py b/functioncall/code/verify.py index 63b5a06..cbfad34 100644 --- a/functioncall/code/verify.py +++ b/functioncall/code/verify.py @@ -9,7 +9,7 @@ from functioncall.base.utils import construct_uid, load_jsonl, logger SINGLE_CASE_EXEC_TIMEOUT = 6 TEST_CASE_BATCH_SIZE = 1 -FUNCTIONCALL_TIMEOUT = 1000 +FUNCTIONCALL_TIMEOUT = 100 def round_up_memory(memory): diff --git a/patch/sglang/v0.4.6.post2.patch b/patch/sglang/v0.4.6.post2.patch new file mode 100644 index 0000000..6bf47bf --- /dev/null +++ b/patch/sglang/v0.4.6.post2.patch @@ -0,0 +1,144 @@ +diff --git a/python/sglang/srt/managers/io_struct.py b/python/sglang/srt/managers/io_struct.py +index 174656b2..33fe0a5f 100644 +--- a/python/sglang/srt/managers/io_struct.py ++++ b/python/sglang/srt/managers/io_struct.py +@@ -687,10 +687,21 @@ class FlushCacheReqOutput: + success: bool + + ++@dataclass ++class InterruptAllReqInput: ++ pass ++ ++ ++@dataclass ++class InterruptAllReqOutput: ++ num_interrupted_requests: int ++ ++ + @dataclass + class UpdateWeightFromDiskReqInput: + # The model path with the new weights + model_path: str ++ allow_interrupt: bool = False + # The format to load the weights + load_format: Optional[str] = None + +diff --git a/python/sglang/srt/managers/scheduler.py b/python/sglang/srt/managers/scheduler.py +index 8891115c..843a8a82 100644 +--- a/python/sglang/srt/managers/scheduler.py 
++++ b/python/sglang/srt/managers/scheduler.py +@@ -70,6 +70,8 @@ from sglang.srt.managers.io_struct import ( + HealthCheckOutput, + InitWeightsUpdateGroupReqInput, + InitWeightsUpdateGroupReqOutput, ++ InterruptAllReqInput, ++ InterruptAllReqOutput, + OpenSessionReqInput, + OpenSessionReqOutput, + ProfileReq, +@@ -419,6 +421,7 @@ class Scheduler( + # Init request dispatcher + self._request_dispatcher = TypeBasedDispatcher( + [ ++ (InterruptAllReqInput, self.interrupt_all_requests), + (TokenizedGenerateReqInput, self.handle_generate_request), + (TokenizedEmbeddingReqInput, self.handle_embedding_request), + (FlushCacheReqInput, self.flush_cache_wrapped), +@@ -1938,6 +1941,15 @@ class Scheduler( + def _pause_engine(self) -> Tuple[List[Req], int]: + raise NotImplementedError() + ++ def interrupt_all_requests(self, recv_req: InterruptAllReqInput): ++ num = len(self.waiting_queue) + len(self.running_batch.reqs) ++ for req in self.waiting_queue: ++ req.sampling_params.max_new_tokens = 0 ++ for req in self.running_batch.reqs: ++ req.sampling_params.max_new_tokens = len(req.output_ids) ++ logger.info(f"Interrupt {num} requests.") ++ return InterruptAllReqOutput(num) ++ + def update_weights_from_disk(self, recv_req: UpdateWeightFromDiskReqInput): + """In-place update of the weights from disk.""" + success, message = self.tp_worker.update_weights_from_disk(recv_req) +diff --git a/python/sglang/srt/managers/tokenizer_manager.py b/python/sglang/srt/managers/tokenizer_manager.py +index 82709b09..bfab3ce7 100644 +--- a/python/sglang/srt/managers/tokenizer_manager.py ++++ b/python/sglang/srt/managers/tokenizer_manager.py +@@ -76,6 +76,8 @@ from sglang.srt.managers.io_struct import ( + HealthCheckOutput, + InitWeightsUpdateGroupReqInput, + InitWeightsUpdateGroupReqOutput, ++ InterruptAllReqInput, ++ InterruptAllReqOutput, + OpenSessionReqInput, + OpenSessionReqOutput, + ProfileReq, +@@ -265,6 +267,9 @@ class TokenizerManager: + self.resume_memory_occupation_communicator = _Communicator( + self.send_to_scheduler, server_args.dp_size + ) ++ self.interrupt_requests_communicator = _Communicator( ++ self.send_to_scheduler, server_args.dp_size ++ ) + self.flush_cache_communicator = _Communicator( + self.send_to_scheduler, server_args.dp_size + ) +@@ -294,6 +299,10 @@ class TokenizerManager: + UpdateWeightFromDiskReqOutput, + self._handle_update_weights_from_disk_req_output, + ), ++ ( ++ InterruptAllReqOutput, ++ self.interrupt_requests_communicator.handle_recv, ++ ), + ( + InitWeightsUpdateGroupReqOutput, + self.init_weights_update_group_communicator.handle_recv, +@@ -767,6 +776,13 @@ class TokenizerManager: + ) -> Tuple[bool, str]: + self.auto_create_handle_loop() + ++ if obj.allow_interrupt: ++ num_interrupted_requests = await self.interrupt_all_requests( ++ InterruptAllReqInput() ++ ) ++ # Set a break point to wait for the interrupt to finish ++ await asyncio.sleep(0.1) ++ + # default the load format to the server_args + if obj.load_format is None: + obj.load_format = self.server_args.load_format +@@ -776,7 +792,12 @@ class TokenizerManager: + # Hold the lock if it is not async. This means that weight sync + # cannot run while requests are in progress. 
+ async with self.model_update_lock.writer_lock: +- return await self._wait_for_model_update_from_disk(obj) ++ success, message, n_paused = ( ++ await self._wait_for_model_update_from_disk(obj) ++ ) ++ if obj.allow_interrupt: ++ return success, message, num_interrupted_requests ++ return success, message, n_paused + + async def _wait_for_model_update_from_disk( + self, obj: UpdateWeightFromDiskReqInput +@@ -849,6 +870,18 @@ class TokenizerManager: + result = (await self.update_weights_from_tensor_communicator(obj))[0] + return result.success, result.message + ++ async def interrupt_all_requests( ++ self, ++ obj: InterruptAllReqInput, ++ request: Optional[fastapi.Request] = None, ++ ) -> Tuple[bool, str]: ++ self.auto_create_handle_loop() ++ result = await self.interrupt_requests_communicator(obj) ++ if self.server_args.dp_size == 1: ++ return result[0].num_interrupted_requests ++ else: ++ return [r.num_interrupted_requests for r in result] ++ + async def get_weights_by_name( + self, obj: GetWeightsByNameReqInput, request: Optional[fastapi.Request] = None + ): diff --git a/patch/sglang/v0.4.6.post4.patch b/patch/sglang/v0.4.6.post4.patch new file mode 100644 index 0000000..b7dbd09 --- /dev/null +++ b/patch/sglang/v0.4.6.post4.patch @@ -0,0 +1,144 @@ +diff --git a/python/sglang/srt/managers/io_struct.py b/python/sglang/srt/managers/io_struct.py +index 5390668c..db370d19 100644 +--- a/python/sglang/srt/managers/io_struct.py ++++ b/python/sglang/srt/managers/io_struct.py +@@ -687,10 +687,21 @@ class FlushCacheReqOutput: + success: bool + + ++@dataclass ++class InterruptAllReqInput: ++ pass ++ ++ ++@dataclass ++class InterruptAllReqOutput: ++ num_interrupted_requests: int ++ ++ + @dataclass + class UpdateWeightFromDiskReqInput: + # The model path with the new weights + model_path: str ++ allow_interrupt: bool = False + # The format to load the weights + load_format: Optional[str] = None + +diff --git a/python/sglang/srt/managers/scheduler.py b/python/sglang/srt/managers/scheduler.py +index 1178eec5..318dee33 100644 +--- a/python/sglang/srt/managers/scheduler.py ++++ b/python/sglang/srt/managers/scheduler.py +@@ -73,6 +73,8 @@ from sglang.srt.managers.io_struct import ( + HealthCheckOutput, + InitWeightsUpdateGroupReqInput, + InitWeightsUpdateGroupReqOutput, ++ InterruptAllReqInput, ++ InterruptAllReqOutput, + OpenSessionReqInput, + OpenSessionReqOutput, + ProfileReq, +@@ -427,6 +429,7 @@ class Scheduler( + # Init request dispatcher + self._request_dispatcher = TypeBasedDispatcher( + [ ++ (InterruptAllReqInput, self.interrupt_all_requests), + (TokenizedGenerateReqInput, self.handle_generate_request), + (TokenizedEmbeddingReqInput, self.handle_embedding_request), + (FlushCacheReqInput, self.flush_cache_wrapped), +@@ -1971,6 +1974,15 @@ class Scheduler( + def _pause_engine(self) -> Tuple[List[Req], int]: + raise NotImplementedError() + ++ def interrupt_all_requests(self, recv_req: InterruptAllReqInput): ++ num = len(self.waiting_queue) + len(self.running_batch.reqs) ++ for req in self.waiting_queue: ++ req.sampling_params.max_new_tokens = 0 ++ for req in self.running_batch.reqs: ++ req.sampling_params.max_new_tokens = len(req.output_ids) ++ logger.info(f"Interrupt {num} requests.") ++ return InterruptAllReqOutput(num) ++ + def update_weights_from_disk(self, recv_req: UpdateWeightFromDiskReqInput): + """In-place update of the weights from disk.""" + success, message = self.tp_worker.update_weights_from_disk(recv_req) +diff --git a/python/sglang/srt/managers/tokenizer_manager.py 
b/python/sglang/srt/managers/tokenizer_manager.py +index b646fae1..c668728b 100644 +--- a/python/sglang/srt/managers/tokenizer_manager.py ++++ b/python/sglang/srt/managers/tokenizer_manager.py +@@ -80,6 +80,8 @@ from sglang.srt.managers.io_struct import ( + HealthCheckOutput, + InitWeightsUpdateGroupReqInput, + InitWeightsUpdateGroupReqOutput, ++ InterruptAllReqInput, ++ InterruptAllReqOutput, + OpenSessionReqInput, + OpenSessionReqOutput, + ProfileReq, +@@ -279,6 +281,9 @@ class TokenizerManager: + self.slow_down_communicator = _Communicator( + self.send_to_scheduler, server_args.dp_size + ) ++ self.interrupt_requests_communicator = _Communicator( ++ self.send_to_scheduler, server_args.dp_size ++ ) + self.flush_cache_communicator = _Communicator( + self.send_to_scheduler, server_args.dp_size + ) +@@ -309,6 +314,10 @@ class TokenizerManager: + UpdateWeightFromDiskReqOutput, + self._handle_update_weights_from_disk_req_output, + ), ++ ( ++ InterruptAllReqOutput, ++ self.interrupt_requests_communicator.handle_recv, ++ ), + ( + InitWeightsUpdateGroupReqOutput, + self.init_weights_update_group_communicator.handle_recv, +@@ -799,6 +808,13 @@ class TokenizerManager: + ) -> Tuple[bool, str]: + self.auto_create_handle_loop() + ++ if obj.allow_interrupt: ++ num_interrupted_requests = await self.interrupt_all_requests( ++ InterruptAllReqInput() ++ ) ++ # Set a break point to wait for the interrupt to finish ++ await asyncio.sleep(0.1) ++ + # default the load format to the server_args + if obj.load_format is None: + obj.load_format = self.server_args.load_format +@@ -808,7 +824,12 @@ class TokenizerManager: + # Hold the lock if it is not async. This means that weight sync + # cannot run while requests are in progress. + async with self.model_update_lock.writer_lock: +- return await self._wait_for_model_update_from_disk(obj) ++ success, message, n_paused = ( ++ await self._wait_for_model_update_from_disk(obj) ++ ) ++ if obj.allow_interrupt: ++ return success, message, num_interrupted_requests ++ return success, message, n_paused + + async def _wait_for_model_update_from_disk( + self, obj: UpdateWeightFromDiskReqInput +@@ -881,6 +902,18 @@ class TokenizerManager: + result = (await self.update_weights_from_tensor_communicator(obj))[0] + return result.success, result.message + ++ async def interrupt_all_requests( ++ self, ++ obj: InterruptAllReqInput, ++ request: Optional[fastapi.Request] = None, ++ ) -> Tuple[bool, str]: ++ self.auto_create_handle_loop() ++ result = await self.interrupt_requests_communicator(obj) ++ if self.server_args.dp_size == 1: ++ return result[0].num_interrupted_requests ++ else: ++ return [r.num_interrupted_requests for r in result] ++ + async def get_weights_by_name( + self, obj: GetWeightsByNameReqInput, request: Optional[fastapi.Request] = None + ): diff --git a/realhf/api/cli_args.py b/realhf/api/cli_args.py index 20a5a86..06b2b3c 100644 --- a/realhf/api/cli_args.py +++ b/realhf/api/cli_args.py @@ -97,8 +97,7 @@ class PromptOnlyDatasetConfig: class ModelFamily: """Identifier for HuggingFace model types (e.g., llama, gpt2). - Used for model registration and allocation. The size parameter is specifically - relevant for the 'search' allocation mode. + Used for model registration and allocation. """ _class: str = field( @@ -107,12 +106,6 @@ class ModelFamily: "`realhf/api/from_hf` for supported models.", } ) - size: int = field( - default=0, - metadata={ - "help": "Model size parameter. 
Only used by 'search' allocation mode, ignored otherwise", - }, - ) is_critic: bool = field( default=False, metadata={ @@ -121,8 +114,8 @@ class ModelFamily: ) def __repr__(self): - """Returns formatted string representation: '{class}-{size}[-critic]'.""" - s = f"{self._class}-{self.size}" + """Returns formatted string representation: '{class}[-critic]'.""" + s = f"{self._class}" if self.is_critic: s += "-critic" return s @@ -136,7 +129,7 @@ class ParallelismConfig: Sequence parallelism is only used in combination with tensor-model parallelism. """ - model_parallel_size: int = field( + tensor_parallel_size: int = field( default=1, metadata={"help": "Size of tensor-model parallelism"} ) pipeline_parallel_size: int = field( @@ -155,7 +148,7 @@ class ParallelismConfig: def __str__(self): """Returns compact string representation: 'Parallel(mp=X,pp=Y,dp=Z)'.""" return ( - f"Parallel(mp={self.model_parallel_size}," + f"Parallel(mp={self.tensor_parallel_size}," f"pp={self.pipeline_parallel_size}," f"dp={self.data_parallel_size})" ) @@ -168,7 +161,7 @@ class ParallelismConfig: Implemented as static method to avoid OmegaConf compatibility issues. """ return ( - (this.model_parallel_size == other.model_parallel_size) + (this.tensor_parallel_size == other.tensor_parallel_size) and (this.pipeline_parallel_size == other.pipeline_parallel_size) and (this.data_parallel_size == other.data_parallel_size) ) @@ -186,7 +179,7 @@ class OptimizerConfig: default="adam", metadata={"help": "Optimizer type", "choices": ["adam", "empty"]}, ) - lr: float = field(default=1e-5, metadata={"help": "Learning rate"}) + lr: float = field(default=2e-5, metadata={"help": "Learning rate"}) weight_decay: float = field(default=0.05, metadata={"help": "Weight decay"}) beta1: float = field(default=0.9, metadata={"help": "Adam beta1 parameter"}) beta2: float = field(default=0.95, metadata={"help": "Adam beta2 parameter"}) @@ -198,14 +191,14 @@ class OptimizerConfig: }, ) lr_scheduler_type: str = field( - default="cosine", + default="constant", metadata={ "help": "Learning rate scheduler type", "choices": ["linear", "cosine", "constant"], }, ) warmup_steps_proportion: float = field( - default=0.02, + default=0.001, metadata={ "help": "Proportion of training steps for warmup", }, @@ -237,6 +230,7 @@ class vLLMConfig: """ max_num_seqs: int = 256 + dtype: str = "float16" kv_cache_type: str = "auto" num_scheduler_steps: int = 1 multi_step_stream_outputs: bool = True @@ -278,7 +272,6 @@ class SGLangConfig: enable_nccl_nvls: bool = False disable_outlines_disk_cache: bool = False disable_custom_all_reduce: bool = False - disable_mla: bool = False disable_overlap_schedule: bool = False enable_mixed_chunk: bool = False enable_dp_attention: bool = False @@ -296,7 +289,7 @@ class SGLangConfig: enable_memory_saver: bool = False allow_auto_truncate: bool = False # NOTE: to avoid the illegal memory access error - attention_backend: Optional[str] = "triton" + attention_backend: Optional[str] = "flashinfer" sampling_backend: Optional[str] = None context_length: Optional[int] = 32768 mem_fraction_static: Optional[float] = 0.9 @@ -309,15 +302,19 @@ class SGLangConfig: schedule_conservativeness: float = 1.0 cpu_offload_gb: int = 0 hybrid_train: bool = False + dtype: str = "float16" + kv_cache_dtype: str = "auto" # logging - log_level: str = "info" + log_level: str = "warning" log_level_http: Optional[str] = "warning" log_requests: bool = False log_requests_level: int = 0 show_time_cost: bool = False enable_metrics: bool = True # Exports 
Prometheus-like metrics - decode_log_interval: int = 1000 # How often (in tokens) to log decode progress. + # The interval (in decoding iterations) to log throughput + # and update prometheus metrics + decode_log_interval: int = 1 # Use staticmethod to make OmegaConf happy. @staticmethod @@ -327,6 +324,7 @@ class SGLangConfig: tp_size, server_index, base_gpu_id, + dist_init_addr: Optional[str] = None, ): from realhf.base import constants, network, pkg_version, seeding from realhf.experiments.common.utils import asdict as conf_as_dict @@ -345,7 +343,6 @@ class SGLangConfig: tokenizer_mode="auto", load_format="auto", trust_remote_code=True, - kv_cache_dtype="auto", device="cuda", served_model_name=f"{constants.experiment_name()}/{constants.trial_name()}/{model_path}", is_embedding=False, @@ -365,6 +362,7 @@ class SGLangConfig: ep_size=1, # TODO: check nnodes=1, node_rank=0, + dist_init_addr=dist_init_addr, **args, ) @@ -385,6 +383,10 @@ class SGLangConfig: if v is True: flags.append(f"--{k.replace('_','-')} ") continue + if isinstance(v, list): + values = " ".join(map(str, v)) + flags.append(f"--{k.replace('_','-')} {values}") + continue flags.append(f"--{k.replace('_','-')} {v}") flags = " ".join(flags) return f"python3 -m sglang.launch_server {flags}" @@ -444,7 +446,7 @@ class ModelTrainEvalConfig: # Model Architecture Configuration type: ModelFamily = field( - default=ModelFamily("llama", 7, False), + default=ModelFamily("llama", False), metadata={"help": "Model family specification"}, ) path: str = field(default="", metadata={"help": "Path to HuggingFace checkpoint"}) @@ -679,13 +681,13 @@ class PPOHyperparameters: value_norm_eps: float = field( default=1e-5, metadata={"help": "Epsilon term for numerical stability"} ) - - # Experimental Features recompute_logprob: bool = field( default=False, - metadata={ - "help": "Recompute log probabilities after generation. Used mainly for debugging purposes" - }, + metadata={"help": "Recompute logp and replace the logp returned by inference."}, + ) + use_decoupled_loss: bool = field( + default=False, + metadata={"help": "Use the decoupled loss. recompute_logprob must be True."}, ) @@ -772,6 +774,13 @@ class ExperimentSaveEvalControl: "For benchmarking purposes only. None indicates normal training." }, ) + benchmark_n_seqs: Optional[int] = field( + default=None, + metadata={ + "help": "Terminate training after consuming this number of samples. " + "For benchmarking purposes only. None indicates normal training." + }, + ) @dataclass @@ -847,7 +856,7 @@ class BaseExperimentConfig: Note: - Recovery modes: auto, fault, resume, disabled - - Allocation modes: manual, search, heuristic, or pattern-based + - Allocation modes: manual, heuristic, or pattern-based """ experiment_name: str = field( @@ -919,13 +928,9 @@ class BaseExperimentConfig: default="", metadata={ "help": "GPU parallel strategy allocation mode. " - "Options: manual/search/heuristic or pattern-based." + "Options: manual/heuristic or pattern-based." }, ) - allocation_use_cache: bool = field( - default=False, - metadata={"help": "Use allocation search cache (search mode only)."}, - ) n_nodes: int = field( default=1, metadata={"help": "Number of nodes for experiment."} ) @@ -998,9 +1003,17 @@ class BaseExperimentConfig: @dataclass class AsyncRLOptions: + schedule_policy: str = field( + default="round_robin", + metadata={ + "help": "The request schedule policy during generation. Available options: [round_robin]." 
+ }, + ) new_tokens_per_chunk: int = field( - default=1024, - metadata={"help": "The lenght of chunked generation."}, + default=int(1e10), + metadata={ + "help": "The length of chunked generation. Only valid if inference can't be interrupted." + }, ) max_head_offpolicyness: int = field( default=0, @@ -1013,9 +1026,11 @@ class AsyncRLOptions: "help": "Number of rollout workers. None defaults to train world size." }, ) - max_concurrent_rollouts: int = field( - default=1024, - metadata={"help": "Max concurrent rollout jobs in each worker."}, + max_concurrent_rollouts: Optional[int] = field( + default=None, + metadata={ + "help": "Max concurrent rollouts globally. Defaults to train batch size." + }, ) flush_request_timeout: int = field( default=120, @@ -1225,6 +1240,12 @@ class PPOMATHExperimentOptions: }, ) + # testing only + no_training: bool = field( + default=False, + metadata={"help": "Run without training. Test-only."}, + ) + @dataclass class MathCodeEvalOptions: diff --git a/realhf/api/core/config.py b/realhf/api/core/config.py index e6a45f8..b516d2d 100644 --- a/realhf/api/core/config.py +++ b/realhf/api/core/config.py @@ -100,8 +100,8 @@ class ModelShardID: :type model_name: ModelName :param dp_rank: The data parallel rank. :type dp_rank: int - :param mp_rank: The tensor-model parallel rank. - :type mp_rank: int + :param tp_rank: The tensor-model parallel rank. + :type tp_rank: int :param pp_rank: The pipeline-model parallel rank. :type pp_rank: int :param topo: The 3D parallelism topology of this model. @@ -110,22 +110,22 @@ class ModelShardID: model_name: ModelName dp_rank: int - mp_rank: int + tp_rank: int pp_rank: int topo: topology.ProcessTopology def __post_init__(self): - assert self.dp_rank >= 0 and self.mp_rank >= 0 and self.pp_rank >= 0 + assert self.dp_rank >= 0 and self.tp_rank >= 0 and self.pp_rank >= 0 if "@" in self.model_name.role: raise ValueError("model_name cannot contain @") assert self.dp_rank < self.topo.get_dim("data") - assert self.mp_rank < self.topo.get_dim("model") + assert self.tp_rank < self.topo.get_dim("tensor") assert self.pp_rank < self.topo.get_dim("pipe") @property def parallelism_rank(self): return self.topo.get_rank( - data=self.dp_rank, model=self.mp_rank, pipe=self.pp_rank + data=self.dp_rank, tensor=self.tp_rank, pipe=self.pp_rank ) @classmethod @@ -134,14 +134,14 @@ class ModelShardID: return cls( model_name=model_name, dp_rank=c.data, - mp_rank=c.model, + tp_rank=c.tensor, pp_rank=c.pipe, topo=topo, ) def __repr__(self): n = cluster.spec.suffix_n_digits - return f"{self.model_name}@pp{self.pp_rank:0{n}d}@mp{self.mp_rank:0{n}d}@dp{self.dp_rank:0{n}d}" + return f"{self.model_name}@pp{self.pp_rank:0{n}d}@tp{self.tp_rank:0{n}d}@dp{self.dp_rank:0{n}d}" def __hash__(self): return hash(str(self)) @@ -152,7 +152,7 @@ class ModelShardID: return ( self.model_name == other.model_name and self.dp_rank == other.dp_rank - and self.mp_rank == other.mp_rank + and self.tp_rank == other.tp_rank and self.pp_rank == other.pp_rank ) return False diff --git a/realhf/api/core/data_api.py b/realhf/api/core/data_api.py index 372de9a..379474f 100644 --- a/realhf/api/core/data_api.py +++ b/realhf/api/core/data_api.py @@ -547,6 +547,7 @@ class SequenceSample: return [[seqlen] for seqlen in seqlens] elif key in [ "packed_logprobs", + "prox_logp", "logprobs", "packed_ref_logprobs", "ref_logprobs", diff --git a/realhf/api/core/dfg.py b/realhf/api/core/dfg.py index 16b0c0d..a5543ac 100644 --- a/realhf/api/core/dfg.py +++ b/realhf/api/core/dfg.py @@ -10,7 +10,6 @@ import 
matplotlib.pyplot as plt import networkx as nx import realhf.base.logging as logging -from realhf.api.cli_args import ModelFamily from realhf.api.core.config import ( ModelInterfaceAbstraction, ModelInterfaceType, @@ -94,13 +93,6 @@ class MFCDef: :type min_n_seqs_per_pass: int :param log_return_value: Whether to log the return value of the interface implementation. :type log_return_value: bool - :param model_type: The specification of the LLM, e.g., LLaMA-7B. Used by the profiler and - search engine to produce an optimal execution plan. Can be omitted if the search engine - is not used. - :type model_type: Optional[ModelFamily] - :param model_path: The path to the model file. Used to get the config for the search engine. - Can be omitted if the search engine is not used. - :type model_path: Optional[str] """ # The unique identifier of this model function call. @@ -126,10 +118,6 @@ class MFCDef: min_n_seqs_per_pass: int | float = 1 log_return_value: bool = False - # Only used by search. - model_type: Optional[Any | ModelFamily] = None - model_path: Optional[str] = None - # Reserved dataclasses.fields. Should not be set by the user. _G: nx.DiGraph = None _pre_hooks: List[RPCHook] = dataclasses.field(default_factory=lambda: []) diff --git a/realhf/api/core/model_api.py b/realhf/api/core/model_api.py index 64223c8..809e7df 100644 --- a/realhf/api/core/model_api.py +++ b/realhf/api/core/model_api.py @@ -11,6 +11,7 @@ from typing import Any, Callable, Dict, Hashable, List, Literal, Optional, Tuple import aiohttp import numpy as np import torch +import torch.distributed as dist import torch.utils.data import transformers @@ -24,6 +25,7 @@ from realhf.api.core.config import ( ModelWrapperAbstraction, ) from realhf.api.core.data_api import MicroBatchSpec, SequenceSample, load_hf_tokenizer +from realhf.base.datapack import flat2d from realhf.base.recover import StepInfo logger = logging.getLogger("model_api") @@ -37,15 +39,19 @@ class ZeroTotalLossWeightException(Exception): class GenRespMeta: qid: str accepted: bool + n_tokens: int @dataclasses.dataclass class GenReqMeta: ## Meta info used to schedule the request. ## + qid: Hashable prompt_len: int group_size: int new_token_budget: int predicted_new_tokens: int | None + previous_server_url: str = "" + previous_version: int = -1 @dataclasses.dataclass @@ -120,6 +126,7 @@ class APIGenerateOutput: @staticmethod def concat(outputs: List["APIGenerateOutput"]): + assert len(set([o.qid for o in outputs])) == 1 return APIGenerateOutput( qid=outputs[0].qid, prompt_ids=outputs[0].prompt_ids, @@ -436,6 +443,8 @@ class ReaLModelConfig: rotary_special_impl: Optional[str] = None # for gemma normalize_embed: bool = False + # for qwen3 + qk_layernorm: bool = False # for opt, it's 2 abs_position_embedding_offset: int = 0 do_layernorm_before: bool = True @@ -798,7 +807,7 @@ class ModelInterface(abc.ABC): model: Model, data: SequenceSample, mb_spec: MicroBatchSpec, - ) -> Dict: + ) -> Dict | List[Dict]: raise NotImplementedError() # Mock methods for creating data and profiling an individual MFC. 
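
For illustration only (not part of the patch above): the new `qid`, `previous_server_url`, and `previous_version` fields on `GenReqMeta` give a request scheduler enough information to keep all rollouts of one query group on the server that already holds their state. A minimal sketch of such a policy, assuming a hypothetical `RoundRobinRouter` and server URLs that are not defined anywhere in this repository:

```python
import itertools

from realhf.api.core.model_api import GenReqMeta


class RoundRobinRouter:
    """Hypothetical dispatch policy: round-robin for new query groups,
    sticky routing for groups that already have a previous server."""

    def __init__(self, server_urls):
        self._cycle = itertools.cycle(server_urls)

    def choose(self, meta: GenReqMeta) -> str:
        # Reuse the server that already processed this qid, if any.
        if meta.previous_server_url:
            return meta.previous_server_url
        return next(self._cycle)


router = RoundRobinRouter(["http://gen-0:30000", "http://gen-1:30000"])
meta = GenReqMeta(
    qid="q-0",
    prompt_len=512,
    group_size=8,
    new_token_budget=1024,
    predicted_new_tokens=None,
)
print(router.choose(meta))  # picks a server; later requests with the same
                            # previous_server_url stay on that server
```
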
@@ -860,7 +869,17 @@ class NullInterface(ModelInterface): def train_step( self, model: Model, data: SequenceSample, mb_spec: MicroBatchSpec - ) -> Dict: + ) -> Dict | List[Dict]: + from realhf.base import constants + + n_tokens = sum(flat2d(data.seqlens[data._get_split_key()])) + n_tokens = torch.tensor( + n_tokens, dtype=torch.long, device=constants.current_device() + ) + dist.all_reduce(n_tokens, group=constants.data_parallel_group()) + if constants.parallelism_rank() == 0: + logger.info(f"Number of tokens in NullInterface training: {int(n_tokens)}") + model.inc_version() return {} def save(self, model: Model, save_dir: str): diff --git a/realhf/api/core/system_api.py b/realhf/api/core/system_api.py index 9ec77be..a0a6746 100644 --- a/realhf/api/core/system_api.py +++ b/realhf/api/core/system_api.py @@ -462,8 +462,8 @@ class ExperimentConfig: ) self_topo = model_topos[rpc.model_name] if ( - self_topo.get_dim("model") % other_topo.get_dim("model") != 0 - and other_topo.get_dim("model") % self_topo.get_dim("model") != 0 + self_topo.get_dim("tensor") % other_topo.get_dim("tensor") != 0 + and other_topo.get_dim("tensor") % self_topo.get_dim("tensor") != 0 ): raise ValueError( "To synchronize parameters between two models, " diff --git a/realhf/api/from_hf/qwen3.py b/realhf/api/from_hf/qwen3.py new file mode 100644 index 0000000..2c6a656 --- /dev/null +++ b/realhf/api/from_hf/qwen3.py @@ -0,0 +1,252 @@ +# Copyright 2025 Ant Group Inc. +# Copyright 2024 Wei Fu & Zhiyu Mei +# Licensed under the Apache License, Version 2.0 (the "License"). + +from typing import * + +from transformers.configuration_utils import PretrainedConfig + +from realhf.api.core.model_api import ReaLModelConfig, register_hf_family +from realhf.base.testing import ( + TESTING_MODEL_HEAD_DIM, + TESTING_MODEL_HIDDEN_SIZE, + TESTING_MODEL_INTERMEDIATE_SIZE, + TESTING_MODEL_N_HEADS, + TESTING_MODEL_N_LAYERS, + TESTING_MODEL_N_POSITIONS, + TESTING_MODEL_VOCAB_SIZE, +) + +from .llama import ( + convert_state_dict_llama, + llama_embedding_layer_names, + llama_output_head_param_name, + to_llama_state_dict, +) + + +class Qwen3Config(PretrainedConfig): + + model_type = "qwen3" + keys_to_ignore_at_inference = ["past_key_values"] + + # Default tensor parallel plan for base model `Qwen3` + base_model_tp_plan = { + "layers.*.self_attn.q_proj": "colwise", + "layers.*.self_attn.k_proj": "colwise", + "layers.*.self_attn.v_proj": "colwise", + "layers.*.self_attn.o_proj": "rowwise", + "layers.*.mlp.gate_proj": "colwise", + "layers.*.mlp.up_proj": "colwise", + "layers.*.mlp.down_proj": "rowwise", + } + base_model_pp_plan = { + "embed_tokens": (["input_ids"], ["inputs_embeds"]), + "layers": (["hidden_states", "attention_mask"], ["hidden_states"]), + "norm": (["hidden_states"], ["hidden_states"]), + } + + def __init__( + self, + vocab_size=151936, + hidden_size=4096, + intermediate_size=22016, + num_hidden_layers=32, + num_attention_heads=32, + num_key_value_heads=32, + head_dim=128, + hidden_act="silu", + max_position_embeddings=32768, + initializer_range=0.02, + rms_norm_eps=1e-6, + use_cache=True, + tie_word_embeddings=False, + rope_theta=10000.0, + rope_scaling=None, + attention_bias=False, + use_sliding_window=False, + sliding_window=4096, + max_window_layers=28, + attention_dropout=0.0, + **kwargs, + ): + from transformers.modeling_rope_utils import rope_config_validation + + self.vocab_size = vocab_size + self.max_position_embeddings = max_position_embeddings + self.hidden_size = hidden_size + self.intermediate_size = intermediate_size + 
self.num_hidden_layers = num_hidden_layers + self.num_attention_heads = num_attention_heads + self.use_sliding_window = use_sliding_window + self.sliding_window = ( + sliding_window # we check `use_sliding_window` in the modeling code + ) + self.max_window_layers = max_window_layers + + # for backward compatibility + if num_key_value_heads is None: + num_key_value_heads = num_attention_heads + + self.num_key_value_heads = num_key_value_heads + self.head_dim = head_dim + self.hidden_act = hidden_act + self.initializer_range = initializer_range + self.rms_norm_eps = rms_norm_eps + self.use_cache = use_cache + self.rope_theta = rope_theta + self.rope_scaling = rope_scaling + self.attention_bias = attention_bias + self.attention_dropout = attention_dropout + # Validate the correctness of rotary position embeddings parameters + # BC: if there is a 'type' field, move it to 'rope_type'. + if self.rope_scaling is not None and "type" in self.rope_scaling: + self.rope_scaling["rope_type"] = self.rope_scaling["type"] + rope_config_validation(self) + + super().__init__( + tie_word_embeddings=tie_word_embeddings, + **kwargs, + ) + + +def convert_config_qwen3( + hf_config: Qwen3Config, +) -> ReaLModelConfig: + return ReaLModelConfig( + n_layers=hf_config.num_hidden_layers, + n_kv_heads=hf_config.num_key_value_heads, + hidden_dim=hf_config.hidden_size, + n_q_heads=hf_config.num_attention_heads, + head_dim=getattr( + hf_config, + "head_dim", + hf_config.hidden_size // hf_config.num_attention_heads, + ), + intermediate_dim=hf_config.intermediate_size, + vocab_size=hf_config.vocab_size, + n_positions=hf_config.max_position_embeddings, + embd_pdrop=0.0, + attn_pdrop=( + hf_config.attention_dropout + if hasattr(hf_config, "attention_dropout") + else 0.1 + ), + layer_norm_epsilon=hf_config.rms_norm_eps, + activation_function=hf_config.hidden_act, + use_attention_bias=False, + use_attn_proj_bias=False, + scale_attn_by_inverse_layer_idx=False, + layer_norm_type="rms", + qk_layernorm=True, + mlp_type="llama", + apply_rotary=True, + rotary_base=hf_config.rope_theta, + rotary_interleaved=False, + tied_embedding=hf_config.tie_word_embeddings, + ) + + +def convert_config_back_qwen3( + config: ReaLModelConfig, +) -> Qwen3Config: + return Qwen3Config( + vocab_size=config.vocab_size, + hidden_size=config.hidden_dim, + intermediate_size=config.intermediate_dim, + num_hidden_layers=config.n_layers, + num_key_value_heads=config.n_kv_heads, + num_attention_heads=config.n_q_heads, + head_dim=config.head_dim, + max_position_embeddings=config.n_positions, + rms_norm_eps=config.layer_norm_epsilon, + hidden_act=config.activation_function, + attention_dropout=config.attn_pdrop, + rope_theta=config.rotary_base, + architectures=["Qwen3ForCausalLM"], # ["Qwen3ForCausalLM"], + tie_word_embeddings=config.tied_embedding, + ) + + +def qwen3_config_maker(): + hf_config = Qwen3Config( + vocab_size=TESTING_MODEL_VOCAB_SIZE, + max_position_embeddings=TESTING_MODEL_N_POSITIONS, + hidden_size=TESTING_MODEL_HIDDEN_SIZE, + intermediate_size=TESTING_MODEL_INTERMEDIATE_SIZE, + num_hidden_layers=TESTING_MODEL_N_LAYERS, + num_attention_heads=TESTING_MODEL_N_HEADS, + head_dim=TESTING_MODEL_HEAD_DIM, + num_key_value_heads=8, + hidden_act="silu", + rms_norm_eps=1e-5, + ) + return convert_config_qwen3(hf_config) + + +def convert_state_dict_qwen3(state_dict: Dict, config: ReaLModelConfig) -> Dict: + llama_state_dict = convert_state_dict_llama(state_dict, config) + # model.layers.0.self_attn.k_norm.weight -> 1.attn.k_ln.weight + new_state_dict = {} + for 
k, v in llama_state_dict.items(): + if "k_norm" in k: + k = k.replace("k_norm", "k_ln") + if "q_norm" in k: + k = k.replace("q_norm", "q_ln") + new_state_dict[k] = v + return new_state_dict + + +def convert_state_dict_back_qwen3(state_dict: Dict, config: ReaLModelConfig) -> Dict: + new_sd = to_llama_state_dict(state_dict, config) + layer_indices = list(set([int(k.split(".")[0]) for k in state_dict.keys()])) + for i in layer_indices: + if i == 0 or i == config.n_layers + 1: + continue + new_sd[f"model.layers.{i - 1}.self_attn.k_norm.weight"] = state_dict[ + f"{i}.attn.k_ln.weight" + ] + new_sd[f"model.layers.{i - 1}.self_attn.q_norm.weight"] = state_dict[ + f"{i}.attn.q_ln.weight" + ] + return new_sd + + +def qwen3_transformer_block_param_name(config: ReaLModelConfig, idx: int) -> List[str]: + names = [] + for k in ["weight", "bias"]: + names += [ + f"model.layers.{idx}.input_layernorm.{k}", + f"model.layers.{idx}.mlp.down_proj.{k}", + f"model.layers.{idx}.mlp.gate_proj.{k}", + f"model.layers.{idx}.mlp.up_proj.{k}", + f"model.layers.{idx}.post_attention_layernorm.{k}", + f"model.layers.{idx}.self_attn.k_proj.{k}", + f"model.layers.{idx}.self_attn.o_proj.{k}", + f"model.layers.{idx}.self_attn.q_proj.{k}", + # f"model.layers.{idx}.self_attn.rotary_emb.inv_freq", + f"model.layers.{idx}.self_attn.v_proj.{k}", + ] + if idx == config.n_layers - 1: + names += [f"model.norm.{k}"] + # Qwen3 + if config.qk_layernorm: + names += [ + f"model.layers.{idx}.self_attn.q_norm.weight", + f"model.layers.{idx}.self_attn.k_norm.weight", + ] + return names + + +register_hf_family( + name="qwen3", + hf_cls_name="Qwen3ForCausalLM", # "Qwen3ForCausalLM" + config_from_hf_converter=convert_config_qwen3, + config_to_hf_converter=convert_config_back_qwen3, + sd_from_hf_converter=convert_state_dict_qwen3, + sd_to_hf_converter=convert_state_dict_back_qwen3, + embedding_param_names=llama_embedding_layer_names, + tblock_param_names=qwen3_transformer_block_param_name, + head_param_names=llama_output_head_param_name, + real_config_maker=qwen3_config_maker, +) diff --git a/realhf/api/quickstart/device_mesh.py b/realhf/api/quickstart/device_mesh.py index be7bbd0..b3e2ea9 100644 --- a/realhf/api/quickstart/device_mesh.py +++ b/realhf/api/quickstart/device_mesh.py @@ -224,18 +224,18 @@ def find_parallel_strategies( ) -> List[ParallelismConfig]: n_gpus = np.sum(device_mesh.mapping) res = [] - for num_mp in [1, 2, 4, 8]: - if n_gpus >= num_mp: - assert n_gpus % num_mp == 0 - num_dp_pp = n_gpus // num_mp + for num_tp in [1, 2, 4, 8]: + if n_gpus >= num_tp: + assert n_gpus % num_tp == 0 + num_dp_pp = n_gpus // num_tp num_pp = 1 while num_pp <= num_dp_pp: - num_dp_mp = n_gpus // num_pp + num_dp_tp = n_gpus // num_pp valid = ( - num_dp_mp in [1, 2, 4, 8] or num_dp_mp % 8 == 0 + num_dp_tp in [1, 2, 4, 8] or num_dp_tp % 8 == 0 ) and num_dp_pp % num_pp == 0 if valid: - res.append(ParallelismConfig(num_pp, num_mp, num_dp_pp // num_pp)) + res.append(ParallelismConfig(num_pp, num_tp, num_dp_pp // num_pp)) num_pp += 1 return res @@ -248,7 +248,7 @@ class RPCAllocation: def __post_init__(self): world_size = ( - self.parallel.model_parallel_size + self.parallel.tensor_parallel_size * self.parallel.pipeline_parallel_size * self.parallel.data_parallel_size ) diff --git a/realhf/api/quickstart/entrypoint.py b/realhf/api/quickstart/entrypoint.py index 1315f5a..4b0968c 100644 --- a/realhf/api/quickstart/entrypoint.py +++ b/realhf/api/quickstart/entrypoint.py @@ -8,12 +8,10 @@ import functools import inspect import json import os -import pickle 
-import subprocess -from typing import Callable, Optional +from typing import Callable import hydra -import omegaconf +import yaml from hydra.core.config_store import ConfigStore from omegaconf import MISSING, OmegaConf @@ -29,6 +27,9 @@ def kind_reminder(config_name, logger, args): logger.info( f"Logs will be dumped to {os.path.join(LOG_ROOT, args.experiment_name, args.trial_name)}" ) + logger.info( + f"Experiment configs will be dumped to {os.path.join(LOG_ROOT, args.experiment_name, args.trial_name, 'config.yaml')}" + ) logger.info( f"Model checkpoints will be saved to {os.path.join(MODEL_SAVE_ROOT, args.experiment_name, args.trial_name)}" ) @@ -69,7 +70,7 @@ def register_quickstart_exp(config_name: str, exp_cls: Callable): logger = logging.getLogger("quickstart", "colored") - print_runtime_helper(OmegaConf.to_object(args)) + # print_runtime_helper(OmegaConf.to_object(args)) exp_name = args.experiment_name if args.trial_name == MISSING: @@ -80,6 +81,17 @@ def register_quickstart_exp(config_name: str, exp_cls: Callable): trial_name = args.trial_name from realhf.apps.main import main_start, main_stop + config_save_path = os.path.join( + LOG_ROOT, args.experiment_name, args.trial_name, "config.yaml" + ) + with open(config_save_path, "w") as f: + yaml.dump( + dataclasses.asdict(OmegaConf.to_object(args)), + f, + default_flow_style=False, + sort_keys=False, + ) + kind_reminder(config_name, logger, args) exp_fn = functools.partial(exp_cls, **args) diff --git a/realhf/apps/main.py b/realhf/apps/main.py index dcb07c9..699f6a4 100644 --- a/realhf/apps/main.py +++ b/realhf/apps/main.py @@ -87,7 +87,7 @@ def main_start(args, job_group_id: str = "", recover_count: int = 0): job_group_id = str(uuid.uuid4()) logger.info(f"AReaL Version: {get_full_version_with_dirty_description()}") logger.info(f"AReaL Job Group ID: {job_group_id}") - logger.info(f"AReaL Job Group Index: {recover_count}") + logger.info(f"AReaL Job Group Index (recover count): {recover_count}") if recover_count == 0: constants.set_experiment_trial_names(args.experiment_name, args.trial_name) experiment = config_package.make_experiment(args.experiment_name) @@ -110,10 +110,6 @@ def main_start(args, job_group_id: str = "", recover_count: int = 0): assert ( args.recover_mode == "disabled" ), "Recover mode is not supported for local runs!" 
- # Use search cache for recover runs - force_allocation_use_cache = ( - recover_count > 1 or args.recover_mode == "resume" - ) and args.allocation_mode == "search" # handle args args.ignore_worker_error = ( args.ignore_worker_error and args.recover_mode == "disabled" @@ -174,12 +170,8 @@ def main_start(args, job_group_id: str = "", recover_count: int = 0): ) for k, v in BASE_ENVIRONS.items(): os.environ[k] = v - os.environ["REAL_IS_REMOTE"] = "0" if not force_allocation_use_cache else "1" # setup experiments - if args.allocation_mode == "search": - experiment._search() - sched = sched_client.make( mode=scheduler_mode(args.mode), expr_name=expr_name, @@ -324,80 +316,6 @@ def main_find_config(args): print(exp_name) -def main_profile_layers(args): - from realhf.api.cli_args import ModelFamily - - _main_profile_layers( - ModelFamily(args.model_class, args.model_size, args.is_critic), - args.model_path, - ) - - -def _main_profile_layers(model_family, model_path): - from realhf.api.cli_args import ModelFamily - from realhf.base.slurm_utils import check_slurm_availability - from realhf.base.testing import clear_name_resolve - - expr_name = trial_name = "profile" - cmd = ( - f"python3 -m realhf.apps.profile_layers --expr_name {expr_name} --trial_name {trial_name} " - f"--model_path {model_path} --model_name {model_family} " - ) - - if check_slurm_availability(): - if not os.environ.get("CLUSTER_SPEC_PATH", ""): - raise ValueError( - "Environment variable CLUSTER_SPEC_PATH must be set for slurm mode! " - "See example/cluster_config.json for a template." - ) - BASE_ENVIRONS = constants.get_env_vars( - REAL_MODE="slurm", - CLUSTER_SPEC_PATH=os.environ.get("CLUSTER_SPEC_PATH", ""), - ) - clear_name_resolve(expr_name, trial_name) - sched = sched_client.make( - mode="slurm", expr_name=expr_name, trial_name=trial_name - ) - print( - f"Profiling {model_family} layers, model path {model_path}, " f"cmd {cmd}" - ) - sched.submit_array( - worker_type="profile_layer", - cmd=cmd, - count=1, - cpu=64, - gpu=8, - gpu_type="tesla", - mem=500000, - env_vars=BASE_ENVIRONS, - container_image=config_package._LLM_GPU_IMAGE, - ) - - try: - sched.wait(timeout=None) - except ( - KeyboardInterrupt, - sched_client.JobException, - TimeoutError, - ) as e: - sched.stop_all() - raise e - else: - try: - print( - f"Profiling {model_family} layers, model path {model_path}, " - f"cmd {cmd}" - ) - clear_name_resolve(expr_name, trial_name) - os.system(cmd) - except ( - KeyboardInterrupt, - sched_client.JobException, - TimeoutError, - ) as e: - raise e - - def main(): parser = argparse.ArgumentParser(prog="ReaLHF") subparsers = parser.add_subparsers(dest="cmd", help="sub-command help") @@ -482,7 +400,7 @@ def main(): type=str, required=False, default="pipe_model", - choices=["manual", "search", "heuristic", "pipe_model", "pipe_data"], + choices=["manual", "heuristic", "pipe_model", "pipe_data"], help="Mode of GPU resource/model parallel strategy allocation.", ) subparser.set_defaults(ignore_worker_error=False) @@ -514,15 +432,6 @@ def main(): subparser.add_argument("--regex", "-r", type=str, required=True) subparser.set_defaults(func=main_find_config) - subparser = subparsers.add_parser( - "profile_layers", help="profile layers of a model." 
- ) - subparser.add_argument("--model_class", type=str, required=True) - subparser.add_argument("--model_size", type=int, required=True) - subparser.add_argument("--is_critic", action="store_true") - subparser.add_argument("--model_path", type=str, required=True) - subparser.set_defaults(func=main_profile_layers) - args = parser.parse_args() args.func(args) diff --git a/realhf/apps/profile_layers.py b/realhf/apps/profile_layers.py deleted file mode 100644 index 726d6c8..0000000 --- a/realhf/apps/profile_layers.py +++ /dev/null @@ -1,91 +0,0 @@ -# Copyright 2025 Ant Group Inc. -# Copyright 2024 Wei Fu & Zhiyu Mei -# Licensed under the Apache License, Version 2.0 (the "License"). - -import argparse -import itertools -import time - -import realhf.base.testing as testing - -BATCH_SIZE_RANGE = [1, 2, 4, 8, 16, 32, 64, 128] -SEQ_LEN_RANGE = [128, 256, 512] - - -def profile_layer_func( - world_size, - model_path, - model_name, - warm_up_rounds, - profile_rounds, - batch_size_range, - seq_len_range, - use_sequence_parallel=False, - use_gradient_checkpointing=False, -): - # FIXME: use_sequence_parallel=True and use_gradient_checkpointing=True will cause bugs - import torch - - import realhf.base.constants as constants - - testing.init_global_constants( - 1, world_size, 1, sequence_parallel=False, gradient_checkpointing=False - ) - device = torch.device("cuda") - with constants.model_scope(testing.MODEL_NAME): - from realhf.search_engine.layers import make_profile_layers - - profile_layers = make_profile_layers(device, model_path, model_name) - - st = time.monotonic_ns() - for i in range(warm_up_rounds + profile_rounds): - for bs, seq_len in itertools.product(batch_size_range, seq_len_range): - profile_layers.fwd_gen(bs, seq_len) - profile_layers.fwd_bwd_opt(bs, seq_len) - - if i < warm_up_rounds: - profile_layers.reset_stats() - profile_layers.make_dataframe_and_print() - profile_layers.dump_stats(world_size) - t = (time.monotonic_ns() - st) / int(1e9) - print(f"profile world size {world_size} cost {t:4f} seconds") - - -if __name__ == "__main__": - st = time.monotonic_ns() - parser = argparse.ArgumentParser(prog="profile_layers") - parser.add_argument( - "--model_path", - type=str, - required=True, - ) - parser.add_argument("--expr_name", type=str, default="profile") - parser.add_argument("--trial_name", type=str, default="profile") - parser.add_argument("--model_name", type=str, default="Llama-2-70b") - parser.add_argument("--warm_up_rounds", type=int, default=1) - parser.add_argument("--profile_rounds", type=int, default=3) - # parser.add_argument("--use_sequence_parallel", action="store_true") - # parser.add_argument("--use_gradient_checkpointing", action="store_true") - args = parser.parse_args() - - world_sizes = [1, 2, 4, 8] - - for world_size in world_sizes: - testing.clear_name_resolve(args.expr_name, args.trial_name) - mp = testing.LocalMultiProcessTest( - world_size, - profile_layer_func, - world_size, - args.model_path, - args.model_name, - args.warm_up_rounds, - args.profile_rounds, - BATCH_SIZE_RANGE, - SEQ_LEN_RANGE, - expr_name=args.expr_name, - trial_name=args.trial_name, - ) - mp.launch() - - t = (time.monotonic_ns() - st) / int(1e9) - print(f"profile model {args.model_name} time cost {t:4f} seconds") diff --git a/realhf/base/constants.py b/realhf/base/constants.py index 687a36d..c46cd63 100644 --- a/realhf/base/constants.py +++ b/realhf/base/constants.py @@ -72,6 +72,7 @@ MODEL_SAVE_ROOT = f"{cluster_spec.fileroot}/checkpoints/{getpass.getuser()}" LOG_ROOT = 
f"{cluster_spec.fileroot}/logs/{getpass.getuser()}" RECOVER_ROOT = f"{cluster_spec.fileroot}/recover/{getpass.getuser()}" SLURM_LOCK_FILE_NAME = f"{cluster_spec.fileroot}/logs/slurm_scheduler.lock" +PORT_LOCK_FILE_ROOT = f"{cluster_spec.fileroot}/.cache/{getpass.getuser()}/ports" PYTORCH_KERNEL_CACHE_PATH = ( f"{LOCAL_CACHE_DIR}/.cache/{getpass.getuser()}/torch/kernels" ) @@ -120,6 +121,9 @@ BASE_ENVIRONS = { "REAL_GPU_MEMORY_KILL_THRESHOLD": os.getenv( "REAL_GPU_MEMORY_KILL_THRESHOLD", "0.95" ), + "LC_ALL": "C", + "LANG": "C", + "NCCL_DEBUG": "WARN", } # Set PPU-specific environment variables for stable training. @@ -146,7 +150,6 @@ elif cluster_spec.name == "na132": "NCCL_IB_SL": "5", "NCCL_IB_TC": "136", "NCCL_IB_HCA": "mlx5_bond", - "NCCL_DEBUG": "WARN", "NCCL_IB_QPS_PER_CONNECTION": "8", "NCCL_SET_THREAD_NAME": "1", "NCCL_DEBUG_SUBSYS": "INIT,TUNING,GRAPH", @@ -165,6 +168,7 @@ os.makedirs(DATASET_CACHE_PATH, exist_ok=True) os.makedirs(PROFILER_CACHE_PATH, exist_ok=True) os.makedirs(TORCH_EXTENSIONS_DIR, exist_ok=True) os.makedirs(QUICKSTART_EXPR_CACHE_PATH, exist_ok=True) +os.makedirs(PORT_LOCK_FILE_ROOT, exist_ok=True) os.makedirs(SGLANG_CACHE_PATH, exist_ok=True) # _model_name will be changed in the model_scope context manager @@ -186,16 +190,12 @@ _self_group = None _rank_mapping: Dict["ModelName", Dict["ModelShardID", int]] = {} _global_memory_buffer: GlobalMemoryBuffer = GlobalMemoryBuffer() -# used only in scripts and tests -_fake_mp_world_size = None -_fake_mp_rank = None - # TODO: As in Megatron, we can set NCCL group options. Is it necessary? def reset_run(): - global _model_name, _grids, _pgroups, _pgroup_ranks, _self_group, _rank_mapping, _global_memory_buffer, _fake_mp_world_size, _fake_mp_rank + global _model_name, _grids, _pgroups, _pgroup_ranks, _self_group, _rank_mapping, _global_memory_buffer _model_name = None _grids = {} _pgroups = {} @@ -203,8 +203,6 @@ def reset_run(): _self_group = None _rank_mapping = {} _global_memory_buffer = GlobalMemoryBuffer() - _fake_mp_world_size = None - _fake_mp_rank = None @contextlib.contextmanager @@ -284,7 +282,7 @@ def set_rank_mapping( else: msid2mwid = {k: v for k, v in msid2mwid.items() if k.model_name == model_name} _rank_mapping[model_name] = { - topo.get_rank(data=s.dp_rank, model=s.mp_rank, pipe=s.pp_rank): mw_id + topo.get_rank(data=s.dp_rank, tensor=s.tp_rank, pipe=s.pp_rank): mw_id for s, mw_id in msid2mwid.items() } @@ -408,7 +406,7 @@ def parallelism_group_ranks(): def parallelism_group_size() -> int: """The 3D parallelism group size of a specific model, normally dp_size * - pp_size * mp_size.""" + pp_size * tp_size.""" import torch.distributed as dist return dist.get_world_size(group=parallelism_group()) @@ -470,37 +468,25 @@ def prev_pipe_stage(): def is_dp_head(): - return is_last_pipe_stage() and model_parallel_rank() == 0 + return is_last_pipe_stage() and tensor_parallel_rank() == 0 -def model_parallel_rank() -> int: +def tensor_parallel_rank() -> int: """Return the rank inside the tensor parallelism group.""" - try: - return grid().get_tensor_model_parallel_rank() - except RuntimeError as e: # used only in scripts and tests - if _fake_mp_rank is not None: - return _fake_mp_rank - else: - raise e + return grid().get_tensor_model_parallel_rank() -def model_parallel_world_size() -> int: +def tensor_parallel_world_size() -> int: """Return the world size of the tensor parallelism group.""" - try: - return grid().get_tensor_model_parallel_world_size() - except RuntimeError as e: # used only in scripts and tests - if 
_fake_mp_world_size is not None: - return _fake_mp_world_size - else: - raise e + return grid().get_tensor_model_parallel_world_size() -def model_parallel_group(): +def tensor_parallel_group(): """Return the NCCL tensor parallelism process group.""" return grid().get_tensor_model_parallel_group() -def model_parallel_cpu_group(): +def tensor_parallel_cpu_group(): """Return the GLOO tensor parallelism process group.""" return grid().get_tensor_model_parallel_cpu_group() @@ -536,26 +522,6 @@ def data_parallel_group(): return grid().get_data_parallel_group() -def set_fake_mp_world_size(world_size): - # used only in scripts and tests - global _fake_mp_world_size - _fake_mp_world_size = world_size - - -def set_fake_mp_rank(rank): - # used only in scripts and tests - global _fake_mp_rank - _fake_mp_rank = rank - - -def set_fake_grid(model_name, rank, topo): - # used only in scripts and tests - from realhf.base.topology import FakeGrid - - global _grids - _grids[model_name] = FakeGrid(rank=rank, topo=topo) - - def get_global_memory_buffer(): global _global_memory_buffer assert _global_memory_buffer is not None, "global memory buffer is not set" diff --git a/realhf/base/gpu_utils.py b/realhf/base/gpu_utils.py index cc93f2f..53c804e 100644 --- a/realhf/base/gpu_utils.py +++ b/realhf/base/gpu_utils.py @@ -62,7 +62,7 @@ def reveal_pg_identity(expr_name, trial_name, worker_index): master_group_name = names.distributed_peer( expr_name, trial_name, GLOBAL_PROCESS_GROUP_NAME ) - name_resolve.add_subentry(master_group_name, str(worker_index), keepalive_ttl=300) + name_resolve.add_subentry(master_group_name, str(worker_index)) def isolate_cuda_device( @@ -100,12 +100,10 @@ def isolate_cuda_device( name_resolve_identifier, ), rank, - keepalive_ttl=60, ) name_resolve.add_subentry( names.distributed_peer(experiment_name, trial_name, name_resolve_identifier), rank, - keepalive_ttl=30, ) logger.debug( f"Worker type {worker_type} rank {rank} waiting for peers, world size {world_size}..." 
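
For context on the `add_subentry` calls above (the patch drops their `keepalive_ttl` arguments): workers announce themselves by writing a subentry under a shared name and then poll until all peers have appeared. A minimal sketch of that register-and-wait pattern, assuming `get_subtree` returns the currently registered values; this mirrors, but is not, the actual `reveal_pg_identity`/`isolate_cuda_device` implementation:

```python
import time

from realhf.base import name_resolve, names


def wait_for_peers(expr_name: str, trial_name: str, group: str,
                   my_rank: int, world_size: int, poll: float = 0.5):
    """Register this worker under a shared key and block until all peers show up."""
    key = names.distributed_peer(expr_name, trial_name, group)
    name_resolve.add_subentry(key, str(my_rank))
    while True:
        peers = name_resolve.get_subtree(key)
        if len(set(peers)) >= world_size:
            return sorted(map(int, peers))
        time.sleep(poll)
```
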
diff --git a/realhf/base/logging.py b/realhf/base/logging.py index fb7fce6..7b46ea9 100644 --- a/realhf/base/logging.py +++ b/realhf/base/logging.py @@ -141,9 +141,18 @@ def getLogger( return logging.getLogger(name) +_LATEST_WANDB_STEP = 0 + + def log_wandb_tensorboard(data, step=None, summary_writer=None): import wandb + global _LATEST_WANDB_STEP + if step is None: + step = _LATEST_WANDB_STEP + else: + _LATEST_WANDB_STEP = max(_LATEST_WANDB_STEP, step) + wandb.log(data, step=step) if summary_writer is not None: for key, val in data.items(): diff --git a/realhf/base/name_resolve.py b/realhf/base/name_resolve.py index b6f5406..95e8e6a 100644 --- a/realhf/base/name_resolve.py +++ b/realhf/base/name_resolve.py @@ -618,7 +618,7 @@ class Etcd3NameRecordRepository(NameRecordRepository): self._to_delete = set() - logger.info(f"Connected to etcd3 at {self._host}:{self._port}") + logger.debug(f"Connected to etcd3 at {self._host}:{self._port}") def __del__(self): """Clean up resources when the object is deleted.""" @@ -945,12 +945,13 @@ def make_repository(type_="nfs", **kwargs): # DEFAULT_REPOSITORY_TYPE = "redis" if socket.gethostname().startswith("frl") else "nfs" DEFAULT_REPOSITORY_TYPE = "nfs" -if ( - etcd3 is not None - and cluster.spec.name in ["wa180", "na132", "su18"] - and os.getenv("REAL_ETCD_ADDR", "") -): +if etcd3 is not None and os.getenv("REAL_ETCD_ADDR", ""): DEFAULT_REPOSITORY_TYPE = "etcd3" +if os.getenv("REAL_ETCD_ADDR", "") and etcd3 is None: + logger.warning( + f"Detected REAL_ETCD_ADDR but etcd3 client is not available. " + "Please run `pip install -r requirements.txt` if you want to use etcd name resolve." + ) DEFAULT_REPOSITORY = make_repository(DEFAULT_REPOSITORY_TYPE) add = DEFAULT_REPOSITORY.add add_subentry = DEFAULT_REPOSITORY.add_subentry diff --git a/realhf/base/names.py b/realhf/base/names.py index b829676..8153b0e 100644 --- a/realhf/base/names.py +++ b/realhf/base/names.py @@ -93,5 +93,9 @@ def gen_servers(experiment_name, trial_name): return f"{USER_NAMESPACE}/{experiment_name}/{trial_name}/gen_servers" +def used_ports(experiment_name, trial_name, host_name): + return f"{USER_NAMESPACE}/{experiment_name}/{trial_name}/{host_name}/" + + def gen_server_manager(experiment_name, trial_name): return f"{USER_NAMESPACE}/{experiment_name}/{trial_name}/gen_server_manager" diff --git a/realhf/base/network.py b/realhf/base/network.py index ec8b02e..f6d17a4 100644 --- a/realhf/base/network.py +++ b/realhf/base/network.py @@ -2,31 +2,16 @@ # Copyright 2024 Wei Fu & Zhiyu Mei # Licensed under the Apache License, Version 2.0 (the "License"). 
+import fcntl +import os import socket +import time from contextlib import closing +from functools import wraps +from realhf.base import constants, logging, name_resolve, names -def find_free_port(low=1, high=65536, exclude_ports=None): - """Find a free port within the specified range, excluding certain ports.""" - if exclude_ports is None: - exclude_ports = set() - - while True: - with closing(socket.socket(socket.AF_INET, socket.SOCK_STREAM)) as s: - s.bind(("", 0)) - s.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1) - port = s.getsockname()[1] - if low <= port <= high and port not in exclude_ports: - return port - - -def find_multiple_free_ports(count, low=1, high=65536): - """Find multiple mutually exclusive free ports.""" - free_ports = set() - for _ in range(count): - port = find_free_port(low, high, exclude_ports=free_ports) - free_ports.add(port) - return list(free_ports) +logger = logging.getLogger(__name__) def gethostname(): @@ -35,3 +20,54 @@ def gethostname(): def gethostip(): return socket.gethostbyname(socket.gethostname()) + + +def find_free_port( + low=1, high=65536, exclude_ports=None, experiment_name="port", trial_name="port" +): + """Find a free port within the specified range, excluding certain ports.""" + + ports_name = names.used_ports(experiment_name, trial_name, gethostip()) + used_ports = list(map(int, name_resolve.get_subtree(ports_name))) + if exclude_ports is None: + exclude_ports = set(used_ports) + else: + exclude_ports = exclude_ports.union(set(used_ports)) + + free_port = None + lockfile = os.path.join(constants.PORT_LOCK_FILE_ROOT, gethostip()) + while True: + with open(lockfile, "w") as fd: + # This will block until lock is acquired + fcntl.flock(fd, fcntl.LOCK_EX) + try: + with closing(socket.socket(socket.AF_INET, socket.SOCK_STREAM)) as s: + s.bind(("", 0)) + s.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1) + port = s.getsockname()[1] + if low <= port <= high and port not in exclude_ports: + name_resolve.add_subentry(ports_name, str(port)) + logger.info(f"Found free port {port}") + free_port = port + break + finally: + fcntl.flock(fd, fcntl.LOCK_UN) + time.sleep(0.05) + return free_port + + +def find_multiple_free_ports( + count, low=1, high=65536, experiment_name="port", trial_name="port" +): + """Find multiple mutually exclusive free ports.""" + free_ports = set() + for _ in range(count): + port = find_free_port( + low=low, + high=high, + exclude_ports=free_ports, + experiment_name=experiment_name, + trial_name=trial_name, + ) + free_ports.add(port) + return list(free_ports) diff --git a/realhf/base/stats_tracker.py b/realhf/base/stats_tracker.py index b4b904e..0ecc7af 100644 --- a/realhf/base/stats_tracker.py +++ b/realhf/base/stats_tracker.py @@ -171,6 +171,9 @@ class DistributedStatsTracker: else: raise ValueError(f"Unknown reduce type: {reduce_type}") + keys_to_pop = [k for k, v in result.items() if v is None] + for k in keys_to_pop: + result.pop(k) return result def _sum_of(self, key, reduce_group): @@ -209,7 +212,7 @@ class DistributedStatsTracker: dist.all_reduce(x, group=reduce_group) dist.all_reduce(d, group=reduce_group) if d == 0: - return 0 + return None return x / d def _min_of(self, key, reduce_group): @@ -224,7 +227,7 @@ class DistributedStatsTracker: if reduce_group is not None: dist.all_reduce(x, group=reduce_group, op=dist.ReduceOp.MIN) if torch.isinf(x): - return float("nan") + return None return float(x) def _max_of(self, key, reduce_group): @@ -239,7 +242,7 @@ class DistributedStatsTracker: if reduce_group is not 
None: dist.all_reduce(x, group=reduce_group, op=dist.ReduceOp.MAX) if torch.isinf(x): - return float("nan") + return None return float(x) diff --git a/realhf/base/testing.py b/realhf/base/testing.py index e35857a..b5fcbbc 100644 --- a/realhf/base/testing.py +++ b/realhf/base/testing.py @@ -22,9 +22,9 @@ import torch.utils.data from realhf.api.core.data_api import SequenceSample from realhf.base import constants, gpu_utils, logging, name_resolve, names, topology from realhf.base.topology import ( - DataPipeModelParallelTopology, + DataPipeTensorParallelTopology, ParallelGrid, - PipeDataModelParallelTopology, + PipeDataTensorParallelTopology, ) logger = logging.getLogger("testing") @@ -106,9 +106,6 @@ class StandaloneTestingProcess(mp.Process): self.expr_name, self.trial_name, self.rank, backend=self.dist_backend ) - # setup some useful constants - constants.set_experiment_trial_names(self.expr_name, self.trial_name) - # misc setup if constants.use_cuda(): pynvml.nvmlInit() @@ -207,7 +204,7 @@ class LocalMultiProcessTest: def init_global_constants( num_dp=1, - num_mp=1, + num_tp=1, num_pp=1, topo=None, model_name=None, @@ -227,9 +224,9 @@ def init_global_constants( if topo is None: if is_train: - topo = PipeDataModelParallelTopology( + topo = PipeDataTensorParallelTopology( num_dp=num_dp, - num_mp=num_mp, + num_tp=num_tp, num_pp=num_pp, sequence_parallel=sequence_parallel, gradient_checkpointing=gradient_checkpointing, @@ -237,13 +234,13 @@ def init_global_constants( max_prompt_len=max_prompt_len, ) else: - topo = DataPipeModelParallelTopology( + topo = DataPipeTensorParallelTopology( num_dp=num_dp, - num_mp=num_mp, + num_tp=num_tp, num_pp=num_pp, sequence_parallel=sequence_parallel, ) - ws = num_dp * num_mp * num_pp + ws = num_dp * num_tp * num_pp else: ws = topo.world_size() diff --git a/realhf/base/topology.py b/realhf/base/topology.py index b92d9a1..4faeaae 100644 --- a/realhf/base/topology.py +++ b/realhf/base/topology.py @@ -65,22 +65,22 @@ def decompose_to_three_factors(n: int) -> List[Tuple[int, int, int]]: return factors -class PipeDataModelrocessCoord(NamedTuple): +class PipeDataTensorProcessCoord(NamedTuple): pipe: int data: int - model: int + tensor: int -class DataPipeModelrocessCoord(NamedTuple): +class DataPipeTensorProcessCoord(NamedTuple): data: int pipe: int - model: int + tensor: int # Explicitly define these class to allow pickling. 
PROCESS_COORD_REGISTRY = { - "pipe#data#model": PipeDataModelrocessCoord, - "data#pipe#model": DataPipeModelrocessCoord, + "pipe#data#tensor": PipeDataTensorProcessCoord, + "data#pipe#tensor": DataPipeTensorProcessCoord, } @@ -327,20 +327,20 @@ def _prime_factors(N): return primes -class PipeDataModelParallelTopology(ProcessTopology): +class PipeDataTensorParallelTopology(ProcessTopology): """A topology for hybrid pipeline, model, and data parallelism.""" def __init__( self, num_pp: int, - num_mp: int, + num_tp: int, num_dp: int, sequence_parallel: bool, gradient_checkpointing: bool, gradient_accumulation_fusion: bool, max_prompt_len: Optional[int] = None, ): - super().__init__(axes=["pipe", "data", "model"], dims=[num_pp, num_dp, num_mp]) + super().__init__(axes=["pipe", "data", "tensor"], dims=[num_pp, num_dp, num_tp]) self.sequence_parallel = sequence_parallel self.gradient_checkpointing = gradient_checkpointing @@ -348,7 +348,7 @@ class PipeDataModelParallelTopology(ProcessTopology): self.gradient_accumulation_fusion = gradient_accumulation_fusion -class DataPipeModelParallelTopology(ProcessTopology): +class DataPipeTensorParallelTopology(ProcessTopology): """A topology for hybrid data, pipeline, and tensor parallelism. Note that DP is the most outer dimension. Used for inference only. @@ -357,12 +357,12 @@ class DataPipeModelParallelTopology(ProcessTopology): def __init__( self, num_pp: int, - num_mp: int, + num_tp: int, num_dp: int, sequence_parallel: bool, max_prompt_len: Optional[int] = None, ): - super().__init__(axes=["data", "pipe", "model"], dims=[num_dp, num_pp, num_mp]) + super().__init__(axes=["data", "pipe", "tensor"], dims=[num_dp, num_pp, num_tp]) self.sequence_parallel = sequence_parallel self.max_prompt_len = max_prompt_len @@ -414,7 +414,7 @@ class ParallelGrid: self.data_parallel_size = max(self._topo.get_dim("data"), 1) self.pipe_parallel_size = max(self._topo.get_dim("pipe"), 1) - self.model_parallel_size = max(self._topo.get_dim("model"), 1) + self.model_parallel_size = max(self._topo.get_dim("tensor"), 1) self.slice_parallel_size = self.model_parallel_size assert self._is_grid_valid(), ( "Invalid Grid", @@ -520,7 +520,7 @@ class ParallelGrid: self.slice_group = None self.slice_proc_group = self.slice_proc_group_gloo = None self.mp_group = [] - self.model_groups = self._topo.get_axis_comm_lists("model") + self.model_groups = self._topo.get_axis_comm_lists("tensor") for g in self.model_groups: proc_group = new_or_get_group(ranks=[rank_mapping[x] for x in g]) # NOTE: We must create the GLOO group for vLLM's usage. 
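Note on the topology.py hunks above: beyond the pure rename of the `model` axis to `tensor`, the two topology classes differ only in axis order, and that order determines how a flat rank maps to (pipe, data, tensor) coordinates. The standalone sketch below (not the repo's `ProcessTopology` API) decomposes ranks row-major to show why the inference layout makes data parallelism the outermost dimension, so each DP replica owns a contiguous block of ranks, while the training layout puts pipeline stages outermost.

```python
def rank_to_coord(axes, dims, rank):
    """Row-major decomposition of a flat rank into per-axis coordinates;
    the last axis varies fastest, mirroring a hybrid-parallel topology."""
    coord = {}
    for axis, dim in zip(reversed(axes), reversed(dims)):
        coord[axis] = rank % dim
        rank //= dim
    return {axis: coord[axis] for axis in axes}


if __name__ == "__main__":
    train = (("pipe", "data", "tensor"), (2, 2, 2))  # training: pipe outermost
    infer = (("data", "pipe", "tensor"), (2, 2, 2))  # inference: data outermost
    for r in range(8):
        print(r, rank_to_coord(*train, r), rank_to_coord(*infer, r))
```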
@@ -634,8 +634,8 @@ class ParallelGrid: def get_tensor_model_parallel_rank(self): if self.global_rank == -1: return -1 - if "model" in self._topo.get_axis_names(): - return self._topo.get_coord(rank=self.global_rank).model + if "tensor" in self._topo.get_axis_names(): + return self._topo.get_coord(rank=self.global_rank).tensor else: return 0 @@ -662,12 +662,12 @@ class FakeGrid: self.data_parallel_size = max(self._topo.get_dim("data"), 1) self.pipe_parallel_size = max(self._topo.get_dim("pipe"), 1) - self.model_parallel_size = max(self._topo.get_dim("model"), 1) + self.model_parallel_size = max(self._topo.get_dim("tensor"), 1) - self.coord: ProcessCoord = self._topo.get_coord(self.rank) + self.coord = self._topo.get_coord(self.rank) self.dp_id = self.coord.data self.pp_id = self.coord.pipe - self.mp_id = self.coord.model + self.mp_id = self.coord.tensor self.world_size = ( self.data_parallel_size * self.pipe_parallel_size * self.model_parallel_size diff --git a/realhf/experiments/async_exp/async_ppo_math_exp.py b/realhf/experiments/async_exp/async_ppo_math_exp.py index 00d6f44..baefa64 100644 --- a/realhf/experiments/async_exp/async_ppo_math_exp.py +++ b/realhf/experiments/async_exp/async_ppo_math_exp.py @@ -6,7 +6,11 @@ from typing import Any, Dict, List, Tuple import realhf.base.logging as logging from realhf.api.cli_args import ModelTrainEvalConfig, PPOMATHExperimentOptions -from realhf.api.core.config import AgentAbstraction, EnvServiceAbstraction +from realhf.api.core.config import ( + AgentAbstraction, + EnvServiceAbstraction, + ModelInterfaceAbstraction, +) from realhf.api.core.model_api import GenerationHyperparameters from realhf.api.quickstart.entrypoint import register_quickstart_exp from realhf.experiments.async_exp.async_rl_exp import AsyncRLExperimentConfig @@ -36,7 +40,7 @@ class AsyncPPOMATHConfig(AsyncRLExperimentConfig, PPOMATHConfig): @property def env(self) -> EnvServiceAbstraction: return EnvServiceAbstraction( - "math-single-step", args=dict(dataset_path=self.dataset.path) + "math-code-single-step", args=dict(dataset_path=self.dataset.path) ) @property @@ -71,6 +75,11 @@ class AsyncPPOMATHConfig(AsyncRLExperimentConfig, PPOMATHConfig): rpcs["ref_inf"].output_keys = ("packed_ref_logprobs",) if "rew_inf" in rpcs: rpcs.pop("rew_inf") + if self.no_training: + rpcs["actor_train"].interface_impl = ModelInterfaceAbstraction("null") + rpcs["actor_gen"].interface_impl = ModelInterfaceAbstraction("null") + if "actor_inf" in rpcs: + rpcs["actor_inf"].interface_impl = ModelInterfaceAbstraction("null") return rpcs @property diff --git a/realhf/experiments/async_exp/async_rl_exp.py b/realhf/experiments/async_exp/async_rl_exp.py old mode 100644 new mode 100755 index a6bbeac..64bd82f --- a/realhf/experiments/async_exp/async_rl_exp.py +++ b/realhf/experiments/async_exp/async_rl_exp.py @@ -60,7 +60,6 @@ GEN_WORKER_DEFAULT_CAPACITY = 512 @dataclasses.dataclass class AsyncRLExperimentConfig(CommonExperimentConfig, AsyncRLOptions): - @property def generation_config(self) -> GenerationHyperparameters: raise NotImplementedError() @@ -203,16 +202,17 @@ class AsyncRLExperimentConfig(CommonExperimentConfig, AsyncRLOptions): "config_from_hf_converter" ](hf_config) if ( - model_config.n_kv_heads % rpc_alloc.parallel.model_parallel_size + model_config.n_kv_heads % rpc_alloc.parallel.tensor_parallel_size != 0 ) or ( - model_config.n_q_heads % rpc_alloc.parallel.model_parallel_size != 0 + model_config.n_q_heads % rpc_alloc.parallel.tensor_parallel_size + != 0 ): raise ValueError( f"The number of KV 
heads {model_config.n_kv_heads} or " f"Q heads {model_config.n_q_heads} is not" f" divisible by the configured TP size " - f"({rpc_alloc.parallel.model_parallel_size}). " + f"({rpc_alloc.parallel.tensor_parallel_size}). " f"Please decrease TP size." ) mapping = rpc_alloc.device_mesh.mapping @@ -250,7 +250,7 @@ class AsyncRLExperimentConfig(CommonExperimentConfig, AsyncRLOptions): topo=topo, dp_rank=topo.get_coord(shard_idx).data, pp_rank=topo.get_coord(shard_idx).pipe, - mp_rank=topo.get_coord(shard_idx).model, + tp_rank=topo.get_coord(shard_idx).tensor, ), model=model, backend=backend, @@ -308,15 +308,18 @@ class AsyncRLExperimentConfig(CommonExperimentConfig, AsyncRLOptions): model_name = gen_rpc_alloc.rpc.model_name train_rpcs = [alloc.rpc for alloc in rpc_allocs if alloc.rpc.is_train()] assert all(rpc.n_seqs == train_rpcs[0].n_seqs for rpc in train_rpcs) + max_concurrent_rollouts = self.max_concurrent_rollouts + if max_concurrent_rollouts is None: + max_concurrent_rollouts = train_rpcs[0].n_seqs return [ GserverManager( model_name=model_name, flush_request_timeout=self.flush_request_timeout, n_servers=gen_world_size // gen_tp_size, - schedule_policy="round_robin", + schedule_policy=self.schedule_policy, max_head_offpolicyness=self.max_head_offpolicyness, train_batch_size=train_rpcs[0].n_seqs, - max_concurrent_rollouts=self.max_concurrent_rollouts, + max_concurrent_rollouts=max_concurrent_rollouts, ) ] diff --git a/realhf/experiments/benchmark/profile_exp.py b/realhf/experiments/benchmark/profile_exp.py index b4b1ae1..bd15b43 100644 --- a/realhf/experiments/benchmark/profile_exp.py +++ b/realhf/experiments/benchmark/profile_exp.py @@ -37,21 +37,21 @@ def default_parallel_config(n_gpus: int) -> List[Dict[str, Any]]: x = [ { "data_parallel_size": dp, - "model_parallel_size": mp, + "tensor_parallel_size": tp, "pipeline_parallel_size": pp, - "use_sequence_parallel": mp > 1, + "use_sequence_parallel": tp > 1, } - for dp, mp, pp in factors + for dp, tp, pp in factors ] x += [ { "data_parallel_size": dp, - "model_parallel_size": mp, + "tensor_parallel_size": tp, "pipeline_parallel_size": pp, "use_sequence_parallel": False, } - for dp, mp, pp in factors - if mp > 1 + for dp, tp, pp in factors + if tp > 1 ] return x @@ -122,7 +122,7 @@ class ProfileConfig(CommonExperimentConfig): k in [ "data_parallel_size", - "model_parallel_size", + "tensor_parallel_size", "pipeline_parallel_size", "use_sequence_parallel", ] @@ -130,7 +130,7 @@ class ProfileConfig(CommonExperimentConfig): ), pcfg.keys() assert (self.n_nodes * self.n_gpus_per_node) == ( pcfg.get("data_parallel_size", 1) - * pcfg.get("model_parallel_size", 1) + * pcfg.get("tensor_parallel_size", 1) * pcfg.get("pipeline_parallel_size", 1) ) @@ -246,8 +246,6 @@ class ProfileConfig(CommonExperimentConfig): model_name="default", input_keys=["packed_prompts"], log_return_value=False, - model_type=self._tmp_model.type, - model_path=self._tmp_model.path, balanced_dp=True, ) diff --git a/realhf/experiments/common/check.py b/realhf/experiments/common/check.py index 9cd3e6a..23f167e 100644 --- a/realhf/experiments/common/check.py +++ b/realhf/experiments/common/check.py @@ -70,7 +70,7 @@ def check_valid_parallel_batch_size(rpc_alloc: RPCAllocation): mb_spec = rpc.mb_spec dp_size = rpc_alloc.parallel.data_parallel_size - tp_size = rpc_alloc.parallel.model_parallel_size + tp_size = rpc_alloc.parallel.tensor_parallel_size pp_size = rpc_alloc.parallel.pipeline_parallel_size factor = 1 diff --git a/realhf/experiments/common/common.py 
b/realhf/experiments/common/common.py index 49f4429..7aa238e 100644 --- a/realhf/experiments/common/common.py +++ b/realhf/experiments/common/common.py @@ -44,7 +44,6 @@ from realhf.api.quickstart.device_mesh import ( ) from realhf.base.cluster import spec as cluster_spec from realhf.experiments.common.check import ( - check_is_realhf_native_model_interface, check_valid_model_and_path, check_valid_optimizer, check_valid_parallel_batch_size, @@ -61,7 +60,6 @@ from realhf.experiments.common.utils import ( resolve_replica_ids, resolve_rpc_hooks, ) -from realhf.search_engine.search import search_rpc_allocations # Register all HF models import realhf.api.from_hf # isort:skip @@ -144,10 +142,6 @@ class CommonExperimentConfig(BaseExperimentConfig, Experiment): """ return None - @property - def search_kwargs(self) -> Dict[str, Any]: - return {} - @property def global_device_mesh(self) -> DeviceMesh: return DeviceMesh( @@ -161,20 +155,6 @@ class CommonExperimentConfig(BaseExperimentConfig, Experiment): f"_heuristic_rpc_allocation is not implemented in {self.__class__}" ) - def _search(self): - # called in both api.main and controller - gradient_checkpointing = any( - model.gradient_checkpointing for model in self.models.values() - ) - rpc_allocs: List[RPCAllocation] = search_rpc_allocations( - device_mesh=self.global_device_mesh, - rpcs=list(self.rpcs.values()), - gradient_checkpointing=gradient_checkpointing, - use_cache=self.allocation_use_cache, - **self.search_kwargs, - ) - return rpc_allocs - def scheduling_setup(self) -> ExperimentScheduling: """The resourced occupied by each worker. @@ -221,24 +201,11 @@ class CommonExperimentConfig(BaseExperimentConfig, Experiment): self._check_legal_allocation_options() rpcs = self.rpcs - if self.allocation_mode == "search": - # assert self.mode == "slurm" - # assumes gradient checkpointing for all training RPCs if one is enabled - # for the simplicity of search configurations - rpc_allocs = self._search() - for rpc_alloc in rpc_allocs: - assert isinstance(rpc_alloc.rpc, str) - for rpc in rpcs.values(): - if rpc.name == rpc_alloc.rpc: - rpc_alloc.rpc = rpc - break - else: - raise ValueError(f"RPC {rpc_alloc.rpc} not found in rpcs.") - elif self._allocation_mode.is_decoupled(): + if self._allocation_mode.is_decoupled(): paras = self._allocation_mode.parallel_strat - gdp, gpp, gmp = paras["gen"]["d"], paras["gen"]["p"], paras["gen"]["m"] - gen_world_size = gdp * gpp * gmp + gdp, gpp, gtp = paras["gen"]["d"], paras["gen"]["p"], paras["gen"]["m"] + gen_world_size = gdp * gpp * gtp assert ( gen_world_size < self.n_gpus_per_node or gen_world_size % self.n_gpus_per_node == 0 @@ -268,7 +235,7 @@ class CommonExperimentConfig(BaseExperimentConfig, Experiment): parallel=ParallelismConfig( data_parallel_size=gdp, pipeline_parallel_size=gpp, - model_parallel_size=gmp, + tensor_parallel_size=gtp, use_sequence_parallel=False, ), ) @@ -276,7 +243,7 @@ class CommonExperimentConfig(BaseExperimentConfig, Experiment): else: rpc_name = rpc.name if rpc_name in paras: - dp, pp, mp = ( + dp, pp, tp = ( paras[rpc_name]["d"], paras[rpc_name]["p"], paras[rpc_name]["m"], @@ -287,9 +254,9 @@ class CommonExperimentConfig(BaseExperimentConfig, Experiment): f"RPC {rpc_name} parallel strategy not given, " "expect a `*` to specify the default parallel strategy." 
) - dp, pp, mp = paras["*"]["d"], paras["*"]["p"], paras["*"]["m"] + dp, pp, tp = paras["*"]["d"], paras["*"]["p"], paras["*"]["m"] if ( - dp * pp * mp + gdp * gpp * gmp + dp * pp * tp + gdp * gpp * gtp != self.n_nodes * self.n_gpus_per_node ): raise ValueError( @@ -297,7 +264,7 @@ class CommonExperimentConfig(BaseExperimentConfig, Experiment): "does not equal to the number of gpus. " "Note that the device mesh of vLLM/SGLang should be disjoint from the device mesh of other MFCs, " "so their summation should be equal to the total number of gpus. " - f"dp={dp}, pp={pp}, mp={mp}, gen.dp={gdp}, gen.pp={gpp}, gen.mp={gmp}, " + f"dp={dp}, pp={pp}, mp={tp}, gen.dp={gdp}, gen.pp={gpp}, gen.mp={gtp}, " f"n_nodes={self.n_nodes}, n_gpus_per_node={self.n_gpus_per_node}" ) alloc = RPCAllocation( @@ -306,10 +273,10 @@ class CommonExperimentConfig(BaseExperimentConfig, Experiment): parallel=ParallelismConfig( data_parallel_size=dp, pipeline_parallel_size=pp, - model_parallel_size=mp, + tensor_parallel_size=tp, use_sequence_parallel=( rpc.interface_type == ModelInterfaceType.TRAIN_STEP - and mp > 1 + and tp > 1 ), ), ) @@ -323,7 +290,7 @@ class CommonExperimentConfig(BaseExperimentConfig, Experiment): rpc_allocs = [] for rpc_name, rpc in self.rpcs.items(): if rpc_name in paras: - dp, pp, mp = ( + dp, pp, tp = ( paras[rpc_name]["d"], paras[rpc_name]["p"], paras[rpc_name]["m"], @@ -334,18 +301,18 @@ class CommonExperimentConfig(BaseExperimentConfig, Experiment): f"RPC {rpc_name} parallel strategy not given, " "expect a `*` to specify the default parallel strategy." ) - dp, pp, mp = paras["*"]["d"], paras["*"]["p"], paras["*"]["m"] - assert dp * pp * mp == self.n_nodes * self.n_gpus_per_node + dp, pp, tp = paras["*"]["d"], paras["*"]["p"], paras["*"]["m"] + assert dp * pp * tp == self.n_nodes * self.n_gpus_per_node alloc = RPCAllocation( rpc=rpc, device_mesh=self.global_device_mesh, parallel=ParallelismConfig( data_parallel_size=dp, pipeline_parallel_size=pp, - model_parallel_size=mp, + tensor_parallel_size=tp, use_sequence_parallel=( rpc.interface_type == ModelInterfaceType.TRAIN_STEP - and mp > 1 + and tp > 1 ), ), ) @@ -455,7 +422,7 @@ class CommonExperimentConfig(BaseExperimentConfig, Experiment): topo=topo, dp_rank=topo.get_coord(shard_idx).data, pp_rank=topo.get_coord(shard_idx).pipe, - mp_rank=topo.get_coord(shard_idx).model, + tp_rank=topo.get_coord(shard_idx).tensor, ), model=ModelAbstraction( "tokenizer", args=dict(tokenizer_path=model_cfg.path) @@ -464,7 +431,6 @@ class CommonExperimentConfig(BaseExperimentConfig, Experiment): gen_backend_name, args=dict( model_path=model_cfg.path, - dtype="bfloat16" if model_cfg.bf16 else "float16", **dict_args, ), ), @@ -503,16 +469,17 @@ class CommonExperimentConfig(BaseExperimentConfig, Experiment): "config_from_hf_converter" ](hf_config) if ( - model_config.n_kv_heads % rpc_alloc.parallel.model_parallel_size + model_config.n_kv_heads % rpc_alloc.parallel.tensor_parallel_size != 0 ) or ( - model_config.n_q_heads % rpc_alloc.parallel.model_parallel_size != 0 + model_config.n_q_heads % rpc_alloc.parallel.tensor_parallel_size + != 0 ): raise ValueError( f"The number of KV heads {model_config.n_kv_heads} or " f"Q heads {model_config.n_q_heads} is not" f" divisible by the configured TP size " - f"({rpc_alloc.parallel.model_parallel_size}). " + f"({rpc_alloc.parallel.tensor_parallel_size}). " f"Please decrease TP size." 
) mapping = rpc_alloc.device_mesh.mapping @@ -572,7 +539,7 @@ class CommonExperimentConfig(BaseExperimentConfig, Experiment): topo=topo, dp_rank=topo.get_coord(shard_idx).data, pp_rank=topo.get_coord(shard_idx).pipe, - mp_rank=topo.get_coord(shard_idx).model, + tp_rank=topo.get_coord(shard_idx).tensor, ), model=model, backend=backend, @@ -612,12 +579,9 @@ class CommonExperimentConfig(BaseExperimentConfig, Experiment): "please setup slurm for distributed runs." ) - if self.n_gpus_per_node != 8 and self.allocation_mode in [ - "search", - "heuristic", - ]: + if self.n_gpus_per_node != 8 and self.allocation_mode == "heuristic": raise ValueError( - f"Cannot run search or heuristic allocation with " + f"Cannot run heuristic allocation with " f"n_gpus_per_node {self.n_gpus_per_node}, " "please set n_gpus_per_node to 8." ) @@ -627,13 +591,6 @@ class CommonExperimentConfig(BaseExperimentConfig, Experiment): raise KeyError( f"RPC name {rpc_name} does not match the name in the MFCDef object {rpc.name}." ) - if not check_is_realhf_native_model_interface( - rpc.interface_impl.type_ - ) and self.allocation_mode in ["search"]: - raise ValueError( - f"RPC {rpc.name} interface is not a realhf native implementation. " - f"The search allocation mode are not available." - ) if self.allocation_mode == "manual" and rpc_name not in self.allocations: if rpc_name not in self.allocations: raise ValueError( diff --git a/realhf/experiments/common/math_code_eval_exp.py b/realhf/experiments/common/math_code_eval_exp.py index 7cdf0cc..1f2640d 100644 --- a/realhf/experiments/common/math_code_eval_exp.py +++ b/realhf/experiments/common/math_code_eval_exp.py @@ -65,8 +65,6 @@ class MathCodeEvalConfig(MathCodeEvalOptions, CommonExperimentConfig): model_name="actor", mb_spec=self.actor_gen.mb_spec, interface_type=ModelInterfaceType.GENERATE, - model_type=self.actor.type, - model_path=self.actor.path, interface_impl=actor_interface, input_keys=("packed_prompts", "task_ids"), output_keys=("packed_input_ids",), @@ -79,8 +77,6 @@ class MathCodeEvalConfig(MathCodeEvalOptions, CommonExperimentConfig): mb_spec=self.rew_inf.mb_spec, interface_type=ModelInterfaceType.INFERENCE, interface_impl=rw_interface, - model_type=self.rew.type, - model_path=self.rew.path, min_n_seqs_per_pass=1 / self.group_size, input_keys=("packed_input_ids", "packed_prompts", "task_ids"), output_keys=("rewards",), diff --git a/realhf/experiments/common/null_exp.py b/realhf/experiments/common/null_exp.py index 7e9297a..b5dfe5b 100644 --- a/realhf/experiments/common/null_exp.py +++ b/realhf/experiments/common/null_exp.py @@ -39,8 +39,6 @@ class NullSFTConfig(CommonExperimentConfig, SFTExperimentOptions): model_name="default", input_keys=("packed_input_ids", "prompt_mask"), log_return_value=True, - model_type=self.model.type, - model_path=self.model.path, ) return {"trainDefault": rpc} @@ -88,8 +86,6 @@ class NullPPOConfig(CommonExperimentConfig, NullPPOExperimentOptions): model_name="default", input_keys=("packed_prompts",), output_keys=("rewards",), - model_type=self.model.type, - model_path=self.model.path, ) rpc = MFCDef( n_seqs=self.dataset.train_bs_n_seqs, @@ -100,8 +96,6 @@ class NullPPOConfig(CommonExperimentConfig, NullPPOExperimentOptions): model_name="default", input_keys=("packed_prompts", "rewards"), log_return_value=True, - model_type=self.model.type, - model_path=self.model.path, ) return {"trainDefault": rpc, "reward": rw} diff --git a/realhf/experiments/common/ppo_math_exp.py b/realhf/experiments/common/ppo_math_exp.py index 2f289d3..45a67ac 
100644 --- a/realhf/experiments/common/ppo_math_exp.py +++ b/realhf/experiments/common/ppo_math_exp.py @@ -148,32 +148,31 @@ class PPOMATHConfig(CommonExperimentConfig, PPOMATHExperimentOptions): "packed_logprobs", "prompt_mask", ] - if self.ppo.recompute_logprob: + if self.ppo.recompute_logprob and not self.ppo.use_decoupled_loss: rollout_output_keys.remove("packed_logprobs") rollout = MFCDef( name="actor_gen", model_name="actor", mb_spec=self.actor_gen.mb_spec, interface_type=ModelInterfaceType.GENERATE, - model_type=self.actor.type, - model_path=self.actor.path, interface_impl=actor_interface, input_keys=("packed_prompts", "task_ids"), output_keys=tuple(rollout_output_keys), n_seqs=self.dataset.train_bs_n_seqs, ) + actor_inf_outputs = ("packed_logprobs",) + if self.ppo.use_decoupled_loss: + actor_inf_outputs = ("proximal_logprobs",) actor_inf = MFCDef( name="actor_inf", model_name="actor", mb_spec=self.actor_inf.mb_spec, interface_type=ModelInterfaceType.INFERENCE, - model_type=self.actor.type, - model_path=self.actor.path, interface_impl=actor_interface, input_keys=("packed_input_ids",), - output_keys=("packed_logprobs",), - output_key_remap=dict(logprobs="packed_logprobs"), + output_keys=actor_inf_outputs, + output_key_remap=dict(logprobs=actor_inf_outputs[0]), n_seqs=self.dataset.train_bs_n_seqs, ) @@ -200,8 +199,6 @@ class PPOMATHConfig(CommonExperimentConfig, PPOMATHExperimentOptions): model_name="ref", mb_spec=self.ref_inf.mb_spec, interface_type=ModelInterfaceType.INFERENCE, - model_type=self.ref.type, - model_path=self.ref.path, interface_impl=ref_interface, min_n_seqs_per_pass=1 / self.group_size, input_keys=tuple(inf_ref_inputs), @@ -216,8 +213,6 @@ class PPOMATHConfig(CommonExperimentConfig, PPOMATHExperimentOptions): mb_spec=self.critic_inf.mb_spec, interface_type=ModelInterfaceType.INFERENCE, interface_impl=critic_interface, - model_type=self.critic.type, - model_path=self.critic.path, min_n_seqs_per_pass=1 / self.group_size, input_keys=("packed_input_ids", "seq_no_eos_mask"), output_keys=("values",), @@ -238,13 +233,13 @@ class PPOMATHConfig(CommonExperimentConfig, PPOMATHExperimentOptions): train_actor_inputs.remove("values") if self.ppo.kl_ctl == 0: train_actor_inputs.remove("packed_ref_logprobs") + if self.ppo.use_decoupled_loss: + train_actor_inputs.append("proximal_logprobs") train_actor = MFCDef( name="actor_train", model_name="actor", mb_spec=self.actor_train.mb_spec, interface_type=ModelInterfaceType.TRAIN_STEP, - model_type=self.actor.type, - model_path=self.actor.path, interface_impl=actor_interface, input_keys=tuple(train_actor_inputs), log_return_value=True, @@ -269,8 +264,6 @@ class PPOMATHConfig(CommonExperimentConfig, PPOMATHExperimentOptions): mb_spec=self.critic_train.mb_spec, interface_type=ModelInterfaceType.TRAIN_STEP, interface_impl=critic_interface, - model_type=self.critic.type, - model_path=self.critic.path, input_keys=tuple(train_critic_inputs), log_return_value=True, min_n_seqs_per_pass=self.ppo.ppo_n_minibatches / self.group_size, @@ -289,7 +282,7 @@ class PPOMATHConfig(CommonExperimentConfig, PPOMATHExperimentOptions): if self.ppo.disable_value: rpcs.pop("critic_inf") rpcs.pop("critic_train") - if not self.ppo.recompute_logprob: + if not self.ppo.recompute_logprob and not self.ppo.use_decoupled_loss: rpcs.pop("actor_inf") if self.ppo.kl_ctl == 0: rpcs.pop("ref_inf") @@ -311,7 +304,7 @@ class PPOMATHConfig(CommonExperimentConfig, PPOMATHExperimentOptions): if self.ppo.disable_value: allocs.pop("critic_inf") allocs.pop("critic_train") - if not 
self.ppo.recompute_logprob: + if not self.ppo.recompute_logprob and not self.ppo.use_decoupled_loss: allocs.pop("actor_inf") if self.ppo.kl_ctl == 0: allocs.pop("ref_inf") @@ -337,14 +330,6 @@ class PPOMATHConfig(CommonExperimentConfig, PPOMATHExperimentOptions): def tokenizer_name_or_path(self) -> str: return self.actor.path - @property - def search_kwargs(self): - return { - "num_gen_tokens": self.ppo.gen.max_new_tokens, - "n_ppo_minibatches": self.ppo.ppo_n_minibatches, - "seq_len": self.dataset.max_prompt_len, - } - @property def max_prompt_len(self): return self.dataset.max_prompt_len diff --git a/realhf/experiments/common/sft_exp.py b/realhf/experiments/common/sft_exp.py index 255de65..a420f1e 100644 --- a/realhf/experiments/common/sft_exp.py +++ b/realhf/experiments/common/sft_exp.py @@ -36,8 +36,6 @@ class SFTConfig(CommonExperimentConfig, SFTExperimentOptions): model_name="default", input_keys=("packed_input_ids", "prompt_mask"), log_return_value=True, - model_type=self.model.type, - model_path=self.model.path, ) return {"trainDefault": rpc} diff --git a/realhf/experiments/common/utils.py b/realhf/experiments/common/utils.py index 5881f27..3522e79 100644 --- a/realhf/experiments/common/utils.py +++ b/realhf/experiments/common/utils.py @@ -34,8 +34,8 @@ from realhf.api.core.dfg import OffloadHook, ParamReallocHook from realhf.api.quickstart.device_mesh import RPCAllocation from realhf.base import logging from realhf.base.topology import ( - DataPipeModelParallelTopology, - PipeDataModelParallelTopology, + DataPipeTensorParallelTopology, + PipeDataTensorParallelTopology, ProcessTopology, ) @@ -73,8 +73,8 @@ def get_topo( max_prompt_len: Optional[int] = None, ) -> ProcessTopology: if is_train: - return PipeDataModelParallelTopology( - num_mp=parallel.model_parallel_size, + return PipeDataTensorParallelTopology( + num_tp=parallel.tensor_parallel_size, num_pp=parallel.pipeline_parallel_size, num_dp=parallel.data_parallel_size, sequence_parallel=parallel.use_sequence_parallel, @@ -82,8 +82,8 @@ def get_topo( max_prompt_len=max_prompt_len, gradient_accumulation_fusion=gradient_accumulation_fusion, ) - return DataPipeModelParallelTopology( - num_mp=parallel.model_parallel_size, + return DataPipeTensorParallelTopology( + num_tp=parallel.tensor_parallel_size, num_pp=parallel.pipeline_parallel_size, num_dp=parallel.data_parallel_size, sequence_parallel=parallel.use_sequence_parallel, @@ -93,7 +93,7 @@ def get_topo( def get_world_size(parallel: ParallelismConfig) -> int: return ( - parallel.model_parallel_size + parallel.tensor_parallel_size * parallel.pipeline_parallel_size * parallel.data_parallel_size ) @@ -247,9 +247,8 @@ class AllocationType(enum.Enum): GLOBAL_HYBRID = 2 MANUAL = 3 HEURISTIC = 4 - SEARCH = 5 - DECOUPLED_SGLANG = 6 - DECOUPLED_MOCK = 7 + DECOUPLED_SGLANG = 5 + DECOUPLED_MOCK = 6 @dataclasses.dataclass @@ -293,8 +292,6 @@ class AllocationMode: return cls(AllocationType.MANUAL, None) if allocation_mode == "heuristic": return cls(AllocationType.HEURISTIC, None) - if allocation_mode == "search": - return cls(AllocationType.SEARCH, None) alloc_3d = AllocationMode.extract_3d_alloc(allocation_mode) alloc_hybrid = AllocationMode.extract_key_value_alloc(allocation_mode) diff --git a/realhf/impl/environment/__init__.py b/realhf/impl/environment/__init__.py index f801f1f..3f6c583 100644 --- a/realhf/impl/environment/__init__.py +++ b/realhf/impl/environment/__init__.py @@ -1 +1 @@ -import realhf.impl.environment.math_single_step_env +import 
realhf.impl.environment.math_code_single_step_env diff --git a/realhf/impl/environment/math_code_single_step_env.py b/realhf/impl/environment/math_code_single_step_env.py new file mode 100644 index 0000000..c07278f --- /dev/null +++ b/realhf/impl/environment/math_code_single_step_env.py @@ -0,0 +1,75 @@ +# Copyright 2025 Ant Group Inc. + +import asyncio +import os +import re +from typing import List, Tuple + +from functioncall.code.local_verify import code_verify as local_code_verify +from functioncall.code.verify import code_verify +from functioncall.math.verify import math_verify +from realhf.api.core.env_api import EnvironmentService, register_environment +from realhf.base import logging +from realhf.impl.dataset.math_code_dataset import load_metadata +from realhf.impl.dataset.math_parser import parse_lines_in_parallel + +ENABLE_FUNCTION_CALL = True if os.getenv("FUNCTIONCALL_SERVICE_DOMAIN", "") else False +math_verify_call = math_verify if ENABLE_FUNCTION_CALL else parse_lines_in_parallel +code_verify_call = code_verify if ENABLE_FUNCTION_CALL else local_code_verify + +logger = logging.getLogger("Math Single Step Environment") + + +def extract_code(text, min_length=20): + code_pattern = r"(?i)```(?:python|py|cpp|CPP)?\s*\n?(.*?)\n?```" + code_blocks = re.findall(code_pattern, text, re.DOTALL) + valid_blocks = [] + for block in code_blocks: + clean_block = block.strip() + if len(clean_block) < min_length: + continue + + valid_blocks.append(clean_block) + + if not valid_blocks: + # logger.warning(f"failed to extract python code from {text}") + return None + # return the last code block + return valid_blocks[-1] + + +class MathCodeSingleStepEnv(EnvironmentService): + def __init__(self, dataset_path: str): + self.id2info, _ = load_metadata(dataset_path) + + async def reset(self, seed=None, options=None): + return None, {} + + async def step(self, action: Tuple[str, List[str]]): + qid, answers = action + group_size = len(answers) + qid = qid.split("@")[0] + cur_task = self.id2info[qid]["task"] + + if cur_task == "math": + format_rewards = await asyncio.to_thread( + math_verify_call, + self.id2info, + answers, + [qid for _ in range(group_size)], + ) + elif cur_task == "code": + answers = [extract_code(x) for x in answers] + format_rewards = await asyncio.to_thread( + code_verify_call, + self.id2info, + answers, + [qid for _ in range(group_size)], + ) + else: + raise NotImplementedError() + + return None, format_rewards, True, False, {} + + +register_environment("math-code-single-step", MathCodeSingleStepEnv) diff --git a/realhf/impl/environment/math_single_step_env.py b/realhf/impl/environment/math_single_step_env.py deleted file mode 100644 index 5b34ce6..0000000 --- a/realhf/impl/environment/math_single_step_env.py +++ /dev/null @@ -1,38 +0,0 @@ -# Copyright 2025 Ant Group Inc. 
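Note on the new realhf/impl/environment/math_code_single_step_env.py shown above: code answers are routed through `extract_code`, which keeps only the last fenced code block that is at least `min_length` characters long. The snippet below is a condensed, self-contained replay of that function plus a usage example on a hypothetical model response; its behavior is intended to match the version added above.

```python
import re


def extract_code(text, min_length=20):
    """Return the last fenced code block with at least `min_length` characters,
    as in math_code_single_step_env.extract_code; None if nothing qualifies."""
    code_pattern = r"(?i)```(?:python|py|cpp|CPP)?\s*\n?(.*?)\n?```"
    blocks = [b.strip() for b in re.findall(code_pattern, text, re.DOTALL)]
    blocks = [b for b in blocks if len(b) >= min_length]
    return blocks[-1] if blocks else None


if __name__ == "__main__":
    # Hypothetical model response: a tiny scratch block followed by the answer.
    sample = (
        "Scratch work:\n"
        "```python\nprint('hi')\n```\n"
        "Final solution:\n"
        "```python\ndef solve(n):\n    return sum(range(n))\n```\n"
    )
    print(extract_code(sample))  # prints the `solve` block only
```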
- -import asyncio -import os -from typing import List, Tuple - -from functioncall.math.verify import math_verify -from realhf.api.core.env_api import EnvironmentService, register_environment -from realhf.base import logging -from realhf.impl.dataset.math_code_dataset import load_metadata -from realhf.impl.dataset.math_parser import parse_lines_in_parallel - -ENABLE_FUNCTION_CALL = True if os.getenv("FUNCTIONCALL_SERVICE_DOMAIN", "") else False -math_verify_call = math_verify if ENABLE_FUNCTION_CALL else parse_lines_in_parallel - -logger = logging.getLogger("Math Single Step Environment") - - -class MathSingleStepEnv(EnvironmentService): - def __init__(self, dataset_path: str): - self.id2info, _ = load_metadata(dataset_path) - - async def reset(self, seed=None, options=None): - return None, {} - - async def step(self, action: Tuple[str, List[str]]): - qid, answers = action - group_size = len(answers) - format_rewards = await asyncio.to_thread( - math_verify_call, - self.id2info, - answers, - [qid for _ in range(group_size)], - ) - return None, format_rewards, True, False, {} - - -register_environment("math-single-step", MathSingleStepEnv) diff --git a/realhf/impl/model/__init__.py b/realhf/impl/model/__init__.py index 8e49cf1..655703b 100644 --- a/realhf/impl/model/__init__.py +++ b/realhf/impl/model/__init__.py @@ -13,7 +13,7 @@ import realhf.api.from_hf import realhf.base.logging as logging from realhf.api.core.model_api import HF_MODEL_FAMILY_REGISTRY from realhf.base.importing import import_module -from realhf.base.pkg_version import is_version_less +from realhf.base.pkg_version import is_available, is_version_less from realhf.impl.model.conversion.hf_registry import HFModelRegistry from realhf.impl.model.nn.real_llm_api import ReaLModel @@ -27,8 +27,9 @@ import_module(os.path.join(_filepath, "nn"), _p) # NOTE: skip importing vLLM for now to avoid an # "invalid device context" issue for the 25.01 image -if is_version_less("vllm", "0.6.4"): +if is_available("vllm") and is_version_less("vllm", "0.6.4"): import realhf.impl.model.backend.vllm + import realhf.impl.model.backend.inference import realhf.impl.model.backend.megatron import realhf.impl.model.backend.mock_train diff --git a/realhf/impl/model/backend/inference.py b/realhf/impl/model/backend/inference.py index 5a32312..3a68d0a 100644 --- a/realhf/impl/model/backend/inference.py +++ b/realhf/impl/model/backend/inference.py @@ -62,7 +62,7 @@ class PipelinableInferenceEngine(model_api.PipelinableEngine): f"num_layers(this stage)={self.module.num_layers} " f"pp_size={constants.pipe_parallel_world_size()} " f"dp_size={constants.data_parallel_world_size()} " - f"mp_size={constants.model_parallel_world_size()} " + f"tp_size={constants.tensor_parallel_world_size()} " ) if constants.data_parallel_rank() == 0: logger.info( diff --git a/realhf/impl/model/backend/megatron.py b/realhf/impl/model/backend/megatron.py index 1ed7944..eb1a4d8 100644 --- a/realhf/impl/model/backend/megatron.py +++ b/realhf/impl/model/backend/megatron.py @@ -126,7 +126,7 @@ def megatron_ctx(): # Build the tensor model-parallel groups. 
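Note on the realhf/impl/model/__init__.py hunk above: the vLLM backend import is now additionally guarded by `is_available("vllm")`, so environments without vLLM no longer fail at import time. A generic sketch of that guard using only the standard library plus `packaging` (an assumption on my part, not something this diff imports):

```python
import importlib.util
from importlib import metadata
from typing import Optional

from packaging.version import Version  # assumption: `packaging` is installed


def optional_backend_ok(pkg: str, upper_exclusive: Optional[str] = None) -> bool:
    """True if `pkg` can be imported and, optionally, is older than `upper_exclusive`."""
    if importlib.util.find_spec(pkg) is None:
        return False
    if upper_exclusive is None:
        return True
    try:
        return Version(metadata.version(pkg)) < Version(upper_exclusive)
    except metadata.PackageNotFoundError:
        return False


# Mirrors the intent of the guarded import in realhf/impl/model/__init__.py:
# only touch the vLLM backend when vLLM exists and is older than 0.6.4.
if optional_backend_ok("vllm", "0.6.4"):
    print("vLLM backend would be imported")
```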
parallel_state._TENSOR_MODEL_PARALLEL_GROUP = g if pkg_version.is_version_greater_or_equal("megatron.core", "0.11.0"): - g = constants.model_parallel_group() + g = constants.tensor_parallel_group() parallel_state._TENSOR_MODEL_PARALLEL_GLOBAL_RANKS = ( dist.get_process_group_ranks(g) ) @@ -155,7 +155,7 @@ def megatron_ctx(): if pkg_version.is_version_greater_or_equal("megatron.core", "0.11.0"): # Build the tensor + context parallel groups parallel_state._TENSOR_AND_CONTEXT_PARALLEL_GROUP = ( - constants.model_parallel_group() + constants.tensor_parallel_group() ) # Build expert parallel groups. @@ -173,7 +173,7 @@ def megatron_ctx(): ) else: parallel_state._TENSOR_AND_EXPERT_PARALLEL_GROUP = ( - constants.model_parallel_group() + constants.tensor_parallel_group() ) parallel_state._DATA_MODULO_EXPERT_PARALLEL_GROUP = ( constants.data_parallel_group() @@ -227,7 +227,7 @@ class MegatronEngine: def _all_reduce_layernorm_grads(self): if not ( - constants.sequence_parallel() and constants.model_parallel_world_size() > 1 + constants.sequence_parallel() and constants.tensor_parallel_world_size() > 1 ): return real_model: ReaLModel = self.ddp.module @@ -255,7 +255,7 @@ class MegatronEngine: assert all(x is not None for x in grads) coalesced = _flatten_dense_tensors(grads) - dist.all_reduce(coalesced, group=constants.model_parallel_group()) + dist.all_reduce(coalesced, group=constants.tensor_parallel_group()) for buf, synced in zip(grads, _unflatten_dense_tensors(coalesced, grads)): buf.copy_(synced) @@ -362,7 +362,10 @@ class PipeTrainInstrSetForMegatron(PipeTrainInstrSet): ) dist.all_reduce(grad_norm, group=constants.tp_and_pp_group()) grad_norm /= constants.tp_and_pp_world_size() - if constants.data_parallel_rank() == 0 and constants.model_parallel_rank() == 0: + if ( + constants.data_parallel_rank() == 0 + and constants.tensor_parallel_rank() == 0 + ): logger.info( f"Model name {constants.model_name()}, " f"Pipeline rank {constants.pipe_parallel_rank()}. " @@ -539,7 +542,10 @@ class ReaLMegatronEngine(model_api.PipelinableEngine): ) dist.all_reduce(grad_norm, group=constants.tp_and_pp_group()) grad_norm /= constants.tp_and_pp_world_size() - if constants.data_parallel_rank() == 0 and constants.model_parallel_rank() == 0: + if ( + constants.data_parallel_rank() == 0 + and constants.tensor_parallel_rank() == 0 + ): logger.info( f"Megatron backend update success? {update_successful}. " f"Grad Norm: {grad_norm}. " @@ -700,7 +706,8 @@ class MegatronTrainBackend(model_api.ModelBackend, MegatronConfig): # Deleting models directly will not release the memory. # We must disable hooks at first. if pkg_version.is_version_greater_or_equal("megatron.core", "0.11.0"): - model.module.engine.ddp.disable_forward_pre_hook() + if self.ddp.use_distributed_optimizer and self.ddp.overlap_param_gather: + model.module.engine.ddp.disable_forward_pre_hook() else: optimizer = model.module.engine.optim if self.ddp.use_distributed_optimizer and self.ddp.overlap_param_gather: @@ -726,7 +733,7 @@ class MegatronTrainBackend(model_api.ModelBackend, MegatronConfig): sd = optimizer.state_dict() dp = constants.data_parallel_rank() pp = constants.pipe_parallel_rank() - tp = constants.model_parallel_rank() + tp = constants.tensor_parallel_rank() # HACK: (bowei) I'm not sure whether there's duplicated information. 
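Note on the Megatron backend hunk above: `_all_reduce_layernorm_grads` coalesces the affected gradients into a single buffer, all-reduces it once over the (renamed) tensor-parallel group, and scatters the result back. A minimal sketch of that flatten/all-reduce/unflatten pattern, using the same private torch helpers the backend already relies on:

```python
import torch
import torch.distributed as dist
from torch._utils import _flatten_dense_tensors, _unflatten_dense_tensors


def allreduce_coalesced(grads, group=None):
    """Flatten a list of gradient tensors, all-reduce the single buffer, and
    copy the synced slices back into the original tensors."""
    coalesced = _flatten_dense_tensors(grads)
    if dist.is_available() and dist.is_initialized():
        dist.all_reduce(coalesced, group=group)
    for buf, synced in zip(grads, _unflatten_dense_tensors(coalesced, grads)):
        buf.copy_(synced)


if __name__ == "__main__":
    # Single-process demo: without an initialized process group the tensors
    # simply round-trip through the flatten/unflatten path unchanged.
    grads = [torch.randn(4), torch.randn(2, 3)]
    allreduce_coalesced(grads)
    print([tuple(g.shape) for g in grads])
```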
torch.save( sd, pathlib.Path(save_dir) / f"megatron_optim_sd_d{dp}p{pp}t{tp}.mckpt" @@ -742,7 +749,7 @@ class MegatronTrainBackend(model_api.ModelBackend, MegatronConfig): dp = constants.data_parallel_rank() pp = constants.pipe_parallel_rank() - tp = constants.model_parallel_rank() + tp = constants.tensor_parallel_rank() sd = torch.load( pathlib.Path(load_dir) / f"megatron_optim_sd_d{dp}p{pp}t{tp}.mckpt" diff --git a/realhf/impl/model/backend/pipe_runner.py b/realhf/impl/model/backend/pipe_runner.py index 44409ca..dcc30b0 100644 --- a/realhf/impl/model/backend/pipe_runner.py +++ b/realhf/impl/model/backend/pipe_runner.py @@ -82,7 +82,7 @@ def _split_and_prefill_pipe_input( raise PipelineError( "Partitioned seqlens are not equal across pipeline parallel ranks. " f"Current rank (dp={constants.data_parallel_rank()}," - f"tp={constants.model_parallel_rank()},pp={constants.pipe_parallel_rank()}), " + f"tp={constants.tensor_parallel_rank()},pp={constants.pipe_parallel_rank()}), " f"gathered batch seqlens={_batch_seqlen_all_gathered}, " f"Have you ensured that the order of dataset across ranks is the same?", ) @@ -118,7 +118,7 @@ def _split_and_prefill_pipe_input( total_len = ( packed_input_ids.shape[0] if not constants.sequence_parallel() - else packed_input_ids.shape[0] // constants.model_parallel_world_size() + else packed_input_ids.shape[0] // constants.tensor_parallel_world_size() ) mb_seq_lens.append(total_len) return (x, ys) @@ -569,7 +569,7 @@ class PipeGenInstrSet: "batch_lengths", micro_batch_id, remove=False ) batch_length = ( - batch_length // constants.model_parallel_world_size() + batch_length // constants.tensor_parallel_world_size() if constants.sequence_parallel() else batch_length ) diff --git a/realhf/impl/model/backend/sglang.py b/realhf/impl/model/backend/sglang.py index 789521e..1c8c480 100644 --- a/realhf/impl/model/backend/sglang.py +++ b/realhf/impl/model/backend/sglang.py @@ -198,8 +198,8 @@ class SGLangGenerationEngine(PipelinableEngine): hybrid_train: bool, request_timeout: int = 1800, ): - if constants.model_parallel_rank() != 0: - dist.barrier(group=constants.model_parallel_cpu_group()) + if constants.tensor_parallel_rank() != 0: + dist.barrier(group=constants.tensor_parallel_cpu_group()) return # Start the serving process self.server_proc = mp.Process( @@ -224,8 +224,8 @@ class SGLangGenerationEngine(PipelinableEngine): if server_args_dict["enable_metrics"]: dp_rank = constants.data_parallel_rank() pp_rank = constants.pipe_parallel_rank() - mp_rank = constants.model_parallel_rank() - metric_server_name = f"d{dp_rank}p{pp_rank}m{mp_rank}" + tp_rank = constants.tensor_parallel_rank() + metric_server_name = f"d{dp_rank}p{pp_rank}t{tp_rank}" key = names.metric_server( constants.experiment_name(), constants.trial_name(), @@ -243,7 +243,7 @@ class SGLangGenerationEngine(PipelinableEngine): # offload weights/cache self.hybrid_train = hybrid_train - dist.barrier(group=constants.model_parallel_cpu_group()) + dist.barrier(group=constants.tensor_parallel_cpu_group()) def __del__(self): if hasattr(self, "server_proc"): @@ -381,8 +381,8 @@ class SGLangGenerationEngine(PipelinableEngine): "NOTE: passing in an arbitrary `min_new_tokens` will lead to a bug for SGLang v0.4.3 " "because we force to skip_tokenizer_init." 
) - if constants.model_parallel_rank() != 0: - dist.barrier(group=constants.model_parallel_cpu_group()) + if constants.tensor_parallel_rank() != 0: + dist.barrier(group=constants.tensor_parallel_cpu_group()) return None, None, None results = asyncio.run( @@ -393,12 +393,12 @@ class SGLangGenerationEngine(PipelinableEngine): gconfig=gconfig, ) ) - dist.barrier(group=constants.model_parallel_cpu_group()) + dist.barrier(group=constants.tensor_parallel_cpu_group()) return results def update_weights_from_disk(self, path): - if constants.model_parallel_rank() != 0: - dist.barrier(group=constants.model_parallel_cpu_group()) + if constants.tensor_parallel_rank() != 0: + dist.barrier(group=constants.tensor_parallel_cpu_group()) return async def _fn(): @@ -409,18 +409,17 @@ class SGLangGenerationEngine(PipelinableEngine): await client.async_update_weights_from_disk(path) asyncio.run(_fn()) - dist.barrier(group=constants.model_parallel_cpu_group()) + dist.barrier(group=constants.tensor_parallel_cpu_group()) @dataclasses.dataclass class SGLangGenerationBackend(ModelBackend, SGLangConfig): model_path: str = "" - dtype: str = "float16" def _initialize(self, model: Model, spec: FinetuneSpec) -> Model: if constants.pipe_parallel_world_size() != 1: raise RuntimeError("SGLang does not support pipe parallel size > 1.") - if constants.model_parallel_world_size() > cluster.spec.n_gpus_per_node: + if constants.tensor_parallel_world_size() > cluster.spec.n_gpus_per_node: raise RuntimeError( "AReaL's SGLang integration does not support model parallel size > n_gpus_per_node." ) @@ -436,7 +435,13 @@ class SGLangGenerationBackend(ModelBackend, SGLangConfig): ) != len(datapack.flat2d(ports)): dist.all_gather_object( ports, - network.find_multiple_free_ports(2, low=20000, high=40000), + network.find_multiple_free_ports( + 2, + low=10000, + high=60000, + experiment_name=constants.experiment_name(), + trial_name=constants.trial_name(), + ), group=constants.data_parallel_group(), ) api_server_port, dist_port = ports[constants.data_parallel_rank()] @@ -450,13 +455,12 @@ class SGLangGenerationBackend(ModelBackend, SGLangConfig): tokenizer_mode="auto", load_format="auto", trust_remote_code=True, - kv_cache_dtype="auto", device="cuda", served_model_name=f"{constants.experiment_name()}/{constants.trial_name()}/{constants.model_name().role}", is_embedding=False, skip_tokenizer_init=True, # Other runtime options - tp_size=constants.model_parallel_world_size(), + tp_size=constants.tensor_parallel_world_size(), # Because we have set CUDA_VISIBLE_DEVICES to a single GPU in each process base_gpu_id=int(os.environ["CUDA_VISIBLE_DEVICES"]), file_storage_path=os.path.join( diff --git a/realhf/impl/model/backend/thirdparty/vllm/context.py b/realhf/impl/model/backend/thirdparty/vllm/context.py index 92cbb22..b97ddf8 100644 --- a/realhf/impl/model/backend/thirdparty/vllm/context.py +++ b/realhf/impl/model/backend/thirdparty/vllm/context.py @@ -36,7 +36,7 @@ def _vllm_group_rank(group_type: _vLLMGroupType): if group_type == _vLLMGroupType.WORLD: return constants.tp_and_pp_rank() elif group_type == _vLLMGroupType.TP: - return constants.model_parallel_rank() + return constants.tensor_parallel_rank() elif group_type == _vLLMGroupType.PP: return constants.pipe_parallel_rank() @@ -45,7 +45,7 @@ def _vllm_group_size(group_type: _vLLMGroupType): if group_type == _vLLMGroupType.WORLD: return constants.tp_and_pp_world_size() elif group_type == _vLLMGroupType.TP: - return constants.model_parallel_world_size() + return 
constants.tensor_parallel_world_size() elif group_type == _vLLMGroupType.PP: return constants.pipe_parallel_world_size() @@ -54,7 +54,7 @@ def _vllm_parallel_group(group_type: _vLLMGroupType): if group_type == _vLLMGroupType.WORLD: return constants.tp_and_pp_group() elif group_type == _vLLMGroupType.TP: - return constants.model_parallel_group() + return constants.tensor_parallel_group() elif group_type == _vLLMGroupType.PP: return constants.pipe_parallel_group() diff --git a/realhf/impl/model/backend/thirdparty/vllm/executor.py b/realhf/impl/model/backend/thirdparty/vllm/executor.py index d8a3cbe..21033dc 100644 --- a/realhf/impl/model/backend/thirdparty/vllm/executor.py +++ b/realhf/impl/model/backend/thirdparty/vllm/executor.py @@ -213,7 +213,7 @@ class GPUExecutor_(GPUExecutor): tok = time.perf_counter() after_mem = float(pynvml.nvmlDeviceGetMemoryInfo(handle).used) is_dp_head = ( - constants.is_last_pipe_stage() and constants.model_parallel_rank() == 0 + constants.is_last_pipe_stage() and constants.tensor_parallel_rank() == 0 ) if is_dp_head: logger.info( @@ -241,7 +241,7 @@ class GPUExecutor_(GPUExecutor): tok = time.perf_counter() after_mem = float(pynvml.nvmlDeviceGetMemoryInfo(handle).used) is_dp_head = ( - constants.is_last_pipe_stage() and constants.model_parallel_rank() == 0 + constants.is_last_pipe_stage() and constants.tensor_parallel_rank() == 0 ) if is_dp_head: logger.info( diff --git a/realhf/impl/model/backend/vllm.py b/realhf/impl/model/backend/vllm.py index ac50bb8..c65f0c1 100644 --- a/realhf/impl/model/backend/vllm.py +++ b/realhf/impl/model/backend/vllm.py @@ -166,7 +166,6 @@ class vLLMGenerationEngine(model_api.PipelinableEngine, LLM): @dataclasses.dataclass class vLLMGenerationBackend(vLLMConfig, model_api.ModelBackend): model_path: str = "" - dtype: str = "bfloat16" def _initialize( self, model: model_api.Model, spec: model_api.FinetuneSpec @@ -192,7 +191,7 @@ class vLLMGenerationBackend(vLLMConfig, model_api.ModelBackend): kv_cache_dtype=self.kv_cache_type, device=constants.current_device(), # Parallelism. - tensor_parallel_size=constants.model_parallel_world_size(), + tensor_parallel_size=constants.tensor_parallel_world_size(), pipeline_parallel_size=constants.pipe_parallel_world_size(), # KV cahce and scheduling. 
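Note on the SGLang backend hunks above: only tensor-parallel rank 0 talks to the generation server; every other rank in the (renamed) tensor-parallel CPU group immediately enters a barrier, and rank 0 joins the same barrier once its request finishes, so the group leaves the call in lockstep. A small standalone sketch of that gateway pattern; the helper name and arguments are illustrative, not the repo's API.

```python
import torch.distributed as dist


def run_on_tp_rank0(fn, tp_rank, tp_cpu_group=None):
    """Execute `fn` on tensor-parallel rank 0 only; all ranks in the CPU group
    meet at a barrier afterwards so they exit together."""
    result = None
    if tp_rank == 0:
        result = fn()
    if dist.is_available() and dist.is_initialized():
        dist.barrier(group=tp_cpu_group)
    return result


if __name__ == "__main__":
    # Single-process demo (no process group): rank 0 runs, "rank 1" skips.
    print(run_on_tp_rank0(lambda: "generated", tp_rank=0))
    print(run_on_tp_rank0(lambda: "generated", tp_rank=1))
```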
num_scheduler_steps=self.num_scheduler_steps, diff --git a/realhf/impl/model/comm/global_comm.py b/realhf/impl/model/comm/global_comm.py index b28535c..f0820cd 100644 --- a/realhf/impl/model/comm/global_comm.py +++ b/realhf/impl/model/comm/global_comm.py @@ -100,7 +100,7 @@ def setup_global_comm( if worker_index == 0: host_ip = socket.gethostbyname(socket.gethostname()) - port = network.find_free_port() + port = network.find_free_port(experiment_name=expr_name, trial_name=trial_name) pg_init_addr = f"tcp://{host_ip}:{port}" name_resolve.add(pg_master_name, pg_init_addr, keepalive_ttl=300) else: diff --git a/realhf/impl/model/comm/param_realloc.py b/realhf/impl/model/comm/param_realloc.py index c2a161b..cdc3ade 100644 --- a/realhf/impl/model/comm/param_realloc.py +++ b/realhf/impl/model/comm/param_realloc.py @@ -43,10 +43,10 @@ def is_trainable(model_name: ModelName) -> bool: class ParamReallocPair: src: ModelName src_dp_rank: int - src_mp_rank: int + src_tp_rank: int src_pp_rank: int dst: ModelName - dst_mp_rank: int + dst_tp_rank: int dst_pp_rank: int @@ -171,32 +171,34 @@ def _create_param_realloc_groups( range(from_topo.get_dim("pipe")), range(to_topo.get_dim("pipe")) ): # create tensor reshard groups - src_mp_size = from_topo.get_dim("model") - dst_mp_size = to_topo.get_dim("model") + src_tp_size = from_topo.get_dim("tensor") + dst_tp_size = to_topo.get_dim("tensor") - for mp_j in range(dst_mp_size): + for tp_j in range(dst_tp_size): _all_dst_ranks = filter_match_mwids( - dst, to_topo, msid2mwid, pipe=pp_j, model=mp_j + dst, to_topo, msid2mwid, pipe=pp_j, tensor=tp_j ) - if src_mp_size > dst_mp_size: - factor = src_mp_size // dst_mp_size - mp_is = list(range(factor * mp_j, factor * (mp_j + 1))) + if src_tp_size > dst_tp_size: + factor = src_tp_size // dst_tp_size + tp_is = list(range(factor * tp_j, factor * (tp_j + 1))) _all_src_ranks = [ - filter_match_mwids(src, from_topo, msid2mwid, model=mp_i, pipe=pp_i) - for mp_i in mp_is + filter_match_mwids( + src, from_topo, msid2mwid, tensor=tp_i, pipe=pp_i + ) + for tp_i in tp_is ] else: - factor = dst_mp_size // src_mp_size + factor = dst_tp_size // src_tp_size _all_src_ranks = [ filter_match_mwids( src, from_topo, msid2mwid, - model=mp_j // factor, + tensor=tp_j // factor, pipe=pp_i, ) ] - # All GPUs in _src_ranks have the data required by (pp_j, mp_j) + # All GPUs in _src_ranks have the data required by (pp_j, tp_j) for _src_ranks in _all_src_ranks: # NOTE: inter-node communication cost is significantly larger than intra-node communication cost. # We only select one sender per host/node to prevent multiple senders occupying the same network bandwidth. 
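Note on the param_realloc.py hunk above: the comment explains that only one sender is picked per host so parallel senders do not compete for the same NIC bandwidth. The assignment code itself is not part of this hunk; the sketch below is a toy illustration of that stated intent, with made-up rank-to-host data.

```python
from collections import defaultdict
from typing import Dict, Iterable, List


def pick_one_sender_per_host(
    src_ranks: Iterable[int],
    dst_ranks: List[int],
    rank_to_host: Dict[int, str],
) -> Dict[int, List[int]]:
    """Illustrative only: keep one source rank per host and spread destination
    ranks across those senders, so no host runs two concurrent senders."""
    senders, seen_hosts = [], set()
    for r in sorted(src_ranks):
        host = rank_to_host[r]
        if host not in seen_hosts:
            seen_hosts.add(host)
            senders.append(r)
    assignment = defaultdict(list)
    for i, dst in enumerate(dst_ranks):
        assignment[senders[i % len(senders)]].append(dst)
    return dict(assignment)


if __name__ == "__main__":
    hosts = {0: "node0", 1: "node0", 2: "node1", 3: "node1"}
    print(pick_one_sender_per_host([0, 1, 2, 3], [4, 5, 6, 7], hosts))
```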
@@ -209,42 +211,42 @@ def _create_param_realloc_groups( ) _idle_src_ranks = [r for r in _src_ranks if r not in assignment] for _src_rank in _idle_src_ranks: - dp_i, mp_i = ( + dp_i, tp_i = ( from_topo.get_coord( mwid2msid[_src_rank][src].parallelism_rank ).data, from_topo.get_coord( mwid2msid[_src_rank][src].parallelism_rank - ).model, + ).tensor, ) key = ParamReallocPair( src=src, src_dp_rank=dp_i, - src_mp_rank=mp_i, + src_tp_rank=tp_i, src_pp_rank=pp_i, dst=dst, - dst_mp_rank=mp_j, + dst_tp_rank=tp_j, dst_pp_rank=pp_j, ) param_realloc_dst_ranks[key] = [] param_realloc_groups[key] = None param_realloc_src_ranks[key] = _src_rank for _src_rank, _dst_ranks in assignment.items(): - dp_i, mp_i = ( + dp_i, tp_i = ( from_topo.get_coord( mwid2msid[_src_rank][src].parallelism_rank ).data, from_topo.get_coord( mwid2msid[_src_rank][src].parallelism_rank - ).model, + ).tensor, ) key = ParamReallocPair( src=src, src_dp_rank=dp_i, - src_mp_rank=mp_i, + src_tp_rank=tp_i, src_pp_rank=pp_i, dst=dst, - dst_mp_rank=mp_j, + dst_tp_rank=tp_j, dst_pp_rank=pp_j, ) param_realloc_dst_ranks[key] = _dst_ranks @@ -315,8 +317,8 @@ def setup_param_realloc( @dataclasses.dataclass class ReparallelizeSenderStep: rank: int - sender_mp_portion_id: int - receiver_mp_portion_id: int + sender_tp_portion_id: int + receiver_tp_portion_id: int param_keys: List[str] param_intervals_cpu: List[Tuple[int, int]] param_intervals_cuda: torch.Tensor @@ -330,8 +332,8 @@ class ReparallelizeSenderStep: @dataclasses.dataclass class ReparallelizeReceiverStep: rank: int - sender_mp_portion_id: int - receiver_mp_portion_id: int + sender_tp_portion_id: int + receiver_tp_portion_id: int sender_param_intervals_cpu: List[Tuple[int, int]] sender_param_intervals_cuda: torch.Tensor sender_max_interval_size: int @@ -356,9 +358,9 @@ def _derive_reparallelize_comm_plan( pg_info: ParamReallocInfo, dtype: Optional[torch.dtype] = torch.float16, ) -> List[ReparallelizeReceiverStep | ReparallelizeSenderStep]: - src_mp_size = from_topo.get_dim("model") - dst_mp_size = to_topo.get_dim("model") - assert src_mp_size % dst_mp_size == 0 or dst_mp_size % src_mp_size == 0 + src_tp_size = from_topo.get_dim("tensor") + dst_tp_size = to_topo.get_dim("tensor") + assert src_tp_size % dst_tp_size == 0 or dst_tp_size % src_tp_size == 0 for k, v in dataclasses.asdict(to_model_config).items(): if k not in ["is_critic"] and v != getattr(from_model_config, k): raise ValueError( @@ -366,8 +368,8 @@ def _derive_reparallelize_comm_plan( f"value in checkpoint is `{v}`, current value is `{getattr(from_model_config, k)}`)." 
) if from_model_config.n_kv_heads > 1 and ( - from_model_config.n_kv_heads % src_mp_size == 0 - ) != (from_model_config.n_kv_heads % dst_mp_size == 0): + from_model_config.n_kv_heads % src_tp_size == 0 + ) != (from_model_config.n_kv_heads % dst_tp_size == 0): raise ValueError("Whether to partition kv heads should remain the same.") from_layer_mapping = partition_pipeline_layers( @@ -400,7 +402,7 @@ def _derive_reparallelize_comm_plan( from_model_param_specs, _ = build_param_spec( from_layer_indices, from_model_config, - mp_size=from_topo.get_dim("model"), + tp_size=from_topo.get_dim("tensor"), dp_size=from_topo.get_dim("data"), pp_size=from_topo.get_dim("pipe"), head_param_point_to_embedding=from_model_head_param_point_to_embedding, @@ -411,7 +413,7 @@ def _derive_reparallelize_comm_plan( to_model_param_specs, _ = build_param_spec( to_layer_indices, to_model_config, - mp_size=to_topo.get_dim("model"), + tp_size=to_topo.get_dim("tensor"), pp_size=to_topo.get_dim("pipe"), dp_size=to_topo.get_dim("data"), head_param_point_to_embedding=to_model_head_param_point_to_embedding, @@ -428,25 +430,25 @@ def _derive_reparallelize_comm_plan( if len(layer_indices) == 0: continue - for mp_i in range(src_mp_size): - if dst_mp_size > src_mp_size: - factor = dst_mp_size // src_mp_size - mp_js = [i + factor * mp_i for i in range(factor)] - receiver_mp_portion_id = 0 + for tp_i in range(src_tp_size): + if dst_tp_size > src_tp_size: + factor = dst_tp_size // src_tp_size + tp_js = [i + factor * tp_i for i in range(factor)] + receiver_tp_portion_id = 0 else: - factor = src_mp_size // dst_mp_size - mp_js = [mp_i // factor] - receiver_mp_portion_id = mp_i % factor - for sender_mp_portion_id, mp_j in enumerate(mp_js): + factor = src_tp_size // dst_tp_size + tp_js = [tp_i // factor] + receiver_tp_portion_id = tp_i % factor + for sender_tp_portion_id, tp_j in enumerate(tp_js): for dp_i in range(src_dp_size): key = ParamReallocPair( src=from_model_name, src_dp_rank=dp_i, - src_mp_rank=mp_i, + src_tp_rank=tp_i, src_pp_rank=pp_i, dst=to_model_name, - dst_mp_rank=mp_j, + dst_tp_rank=tp_j, dst_pp_rank=pp_j, ) src = pg_info.param_realloc_src_ranks[key] @@ -462,10 +464,10 @@ def _derive_reparallelize_comm_plan( ) param_size = param_size_from_keys( config=from_model_config, - src_mp_size=src_mp_size, + src_tp_size=src_tp_size, sd_keys=param_keys, - src2dst_tp_size=max(dst_mp_size // src_mp_size, 1), - src2dst_tp_rank=sender_mp_portion_id, + src2dst_tp_size=max(dst_tp_size // src_tp_size, 1), + src2dst_tp_rank=sender_tp_portion_id, head_param_point_to_embedding=from_model_head_param_point_to_embedding, ) if torch.distributed.is_initialized(): @@ -474,11 +476,11 @@ def _derive_reparallelize_comm_plan( param_intervals_cpu = param_intervals_from_keys( model_name=from_model_name, config=from_model_config, - mp_size=src_mp_size, + tp_size=src_tp_size, param_spec=from_model_param_specs, sd_keys=param_keys, - portion_size=max(dst_mp_size // src_mp_size, 1), - portion_rank=sender_mp_portion_id, + portion_size=max(dst_tp_size // src_tp_size, 1), + portion_rank=sender_tp_portion_id, head_param_point_to_embedding=from_model_head_param_point_to_embedding, ) param_intervals_cuda = torch.tensor( @@ -493,11 +495,11 @@ def _derive_reparallelize_comm_plan( receiver_param_intervals_cpu = param_intervals_from_keys( model_name=to_model_name, config=to_model_config, - mp_size=dst_mp_size, + tp_size=dst_tp_size, param_spec=to_model_param_specs, sd_keys=param_keys, - portion_size=max(src_mp_size // dst_mp_size, 1), - 
portion_rank=receiver_mp_portion_id, + portion_size=max(src_tp_size // dst_tp_size, 1), + portion_rank=receiver_tp_portion_id, head_param_point_to_embedding=to_model_head_param_point_to_embedding, ) receiver_param_intervals_cuda = torch.tensor( @@ -513,8 +515,8 @@ def _derive_reparallelize_comm_plan( comm_plan.append( ReparallelizeReceiverStep( rank=dst_rank, - sender_mp_portion_id=sender_mp_portion_id, - receiver_mp_portion_id=receiver_mp_portion_id, + sender_tp_portion_id=sender_tp_portion_id, + receiver_tp_portion_id=receiver_tp_portion_id, param_keys=param_keys, sender_param_intervals_cpu=param_intervals_cpu, sender_param_intervals_cuda=param_intervals_cuda, @@ -532,8 +534,8 @@ def _derive_reparallelize_comm_plan( comm_plan.append( ReparallelizeSenderStep( rank=src, - sender_mp_portion_id=sender_mp_portion_id, - receiver_mp_portion_id=receiver_mp_portion_id, + sender_tp_portion_id=sender_tp_portion_id, + receiver_tp_portion_id=receiver_tp_portion_id, param_keys=param_keys, param_intervals_cpu=param_intervals_cpu, param_intervals_cuda=param_intervals_cuda, diff --git a/realhf/impl/model/conversion/hf_registry.py b/realhf/impl/model/conversion/hf_registry.py index 697f041..969e946 100644 --- a/realhf/impl/model/conversion/hf_registry.py +++ b/realhf/impl/model/conversion/hf_registry.py @@ -22,8 +22,8 @@ from realhf.base.saveload_utils import ( ) from realhf.impl.model.nn.real_llm_api import ReaLModel from realhf.impl.model.nn.real_llm_parallel import ( - mp_merge_key, - mp_partition_real_model_state_dict, + tp_merge_key, + tp_partition_real_model_state_dict, ) logger = logging.getLogger("HF Registry") @@ -141,11 +141,11 @@ class HFModelRegistry: partition_tik = time.perf_counter() sd = {k: v for k, v in sd.items() if k in required_hf_sd_names} sd = self.sd_from_hf_converter(sd, model.config) - psd = mp_partition_real_model_state_dict( + psd = tp_partition_real_model_state_dict( sd, model.config, - constants.model_parallel_world_size(), - constants.model_parallel_rank(), + constants.tensor_parallel_world_size(), + constants.tensor_parallel_rank(), ) return psd, partition_tik - load_tik, time.perf_counter() - partition_tik @@ -222,8 +222,8 @@ class HFModelRegistry: dp_rank = constants.data_parallel_rank() pp_rank = constants.pipe_parallel_rank() - mp_rank = constants.model_parallel_rank() - mp_size = constants.model_parallel_world_size() + tp_rank = constants.tensor_parallel_rank() + tp_size = constants.tensor_parallel_world_size() pp_size = constants.pipe_parallel_world_size() dp_size = constants.data_parallel_world_size() @@ -234,7 +234,7 @@ class HFModelRegistry: # of each pipeline stage into smaller shards. approx_param_size = ( sum(v.numel() * v.element_size() for v in model.state_dict().values()) - * mp_size + * tp_size ) # By default a shard is at most 1GB. A small size enables parallel saving during training. @@ -274,9 +274,9 @@ class HFModelRegistry: and k == f"{model.config.n_layers + 1}.weight" ): continue - gather_list = [torch.zeros_like(v) for _ in range(mp_size)] - dist.all_gather(gather_list, v, group=constants.model_parallel_group()) - gathered = mp_merge_key(k, gather_list, model.config) + gather_list = [torch.zeros_like(v) for _ in range(tp_size)] + dist.all_gather(gather_list, v, group=constants.tensor_parallel_group()) + gathered = tp_merge_key(k, gather_list, model.config) cpu_sd[k] = gathered.cpu() t2 = time.perf_counter() @@ -299,7 +299,7 @@ class HFModelRegistry: param_size = param_size.item() # Save tokenizer and huggingface model config. 
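Note on `_derive_reparallelize_comm_plan` above: when the tensor-parallel degree changes between source and destination models, each source TP rank either fans out to `dst_tp // src_tp` destination ranks or contributes one portion of a wider destination shard. The sketch below lifts just that factor arithmetic out of the hunk so the mapping can be printed and checked in isolation.

```python
def tp_shard_mapping(src_tp_size, dst_tp_size):
    """Return (src_tp_rank, sender_portion, dst_tp_rank, receiver_portion)
    tuples, following the factor logic in _derive_reparallelize_comm_plan."""
    assert src_tp_size % dst_tp_size == 0 or dst_tp_size % src_tp_size == 0
    plan = []
    for tp_i in range(src_tp_size):
        if dst_tp_size > src_tp_size:
            factor = dst_tp_size // src_tp_size
            tp_js = [i + factor * tp_i for i in range(factor)]
            receiver_portion = 0
        else:
            factor = src_tp_size // dst_tp_size
            tp_js = [tp_i // factor]
            receiver_portion = tp_i % factor
        for sender_portion, tp_j in enumerate(tp_js):
            plan.append((tp_i, sender_portion, tp_j, receiver_portion))
    return plan


if __name__ == "__main__":
    # Growing TP 2 -> 4: each source rank feeds two destination ranks.
    print(tp_shard_mapping(2, 4))
    # Shrinking TP 4 -> 2: two source ranks each carry half of one destination shard.
    print(tp_shard_mapping(4, 2))
```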
- if pp_rank == 0 and dp_rank == 0 and mp_rank == 0: + if pp_rank == 0 and dp_rank == 0 and tp_rank == 0: hf_config.save_pretrained(save_dir) if tokenizer is not None: tokenizer.save_pretrained(save_dir) @@ -307,7 +307,7 @@ class HFModelRegistry: # Dump parameters to disk. if len(pp_stage_n_shards) == 1 and pp_stage_n_shards[0] == 1: fn = "pytorch_model.bin" - if pp_rank == 0 and dp_rank == 0 and mp_rank == 0: + if pp_rank == 0 and dp_rank == 0 and tp_rank == 0: torch.save(hf_sd, os.path.join(save_dir, fn)) else: output_fn = ( @@ -326,8 +326,8 @@ class HFModelRegistry: bin_index["weight_map"] = {} weight_map = {} - mesh_size = dp_size * mp_size - mesh_idx = dp_rank * mp_size + mp_rank + mesh_size = dp_size * tp_size + mesh_idx = dp_rank * tp_size + tp_rank n_shards_per_gpu = (n_shards + mesh_size - 1) // mesh_size if mesh_idx < len(range(0, n_shards, n_shards_per_gpu)): s = list(range(0, n_shards, n_shards_per_gpu))[mesh_idx] @@ -357,7 +357,7 @@ class HFModelRegistry: for wm in weight_map_list: bin_index["weight_map"].update(wm) - if pp_rank == 0 and dp_rank == 0 and mp_rank == 0: + if pp_rank == 0 and dp_rank == 0 and tp_rank == 0: with open( os.path.join(save_dir, "pytorch_model.bin.index.json"), "w" ) as f: diff --git a/realhf/impl/model/interface/math_rw_interface.py b/realhf/impl/model/interface/math_rw_interface.py index 0f15533..3c3c08b 100644 --- a/realhf/impl/model/interface/math_rw_interface.py +++ b/realhf/impl/model/interface/math_rw_interface.py @@ -240,7 +240,7 @@ class MultiTaskRewardInterface(model_api.ModelInterface): return data local_rank = constants.grid().topo.get_rank( data=constants.data_parallel_rank(), - model=0, + tensor=0, pipe=constants.pipe_parallel_world_size() - 1, ) dst = constants.to_global_pg_rank(local_rank) diff --git a/realhf/impl/model/interface/ppo_interface.py b/realhf/impl/model/interface/ppo_interface.py index 76e44f5..54997fd 100644 --- a/realhf/impl/model/interface/ppo_interface.py +++ b/realhf/impl/model/interface/ppo_interface.py @@ -3,7 +3,7 @@ # Licensed under the Apache License, Version 2.0 (the "License"). 
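The shard-saving hunk above assigns HF checkpoint shards to the (data, tensor) ranks of one pipeline stage via mesh_idx. A small sketch of that assignment, runnable on its own (helper name is illustrative):

# Sketch of dividing n_shards among the dp x tp mesh, following the
# mesh_idx arithmetic in the hunk above.
def shard_range(n_shards: int, dp_rank: int, tp_rank: int, dp_size: int, tp_size: int):
    mesh_size = dp_size * tp_size
    mesh_idx = dp_rank * tp_size + tp_rank
    n_per_gpu = (n_shards + mesh_size - 1) // mesh_size
    starts = list(range(0, n_shards, n_per_gpu))
    if mesh_idx >= len(starts):
        return None  # this rank writes nothing
    s = starts[mesh_idx]
    return s, min(s + n_per_gpu, n_shards)

# e.g. 10 shards over a 2x4 mesh: the first five ranks write two shards each.
for dp in range(2):
    for tp in range(4):
        print((dp, tp), shard_range(10, dp, tp, 2, 4))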
import dataclasses -from typing import Dict, Literal, Optional +from typing import Dict, List, Literal, Optional import torch import torch.distributed as dist @@ -86,6 +86,7 @@ def _ppo_actor_loss_from_model_outputs( eps_clip=eps_clip, loss_mask=ppo_loss_mask, c_clip=c_clip, + proximal_logprobs=input_.data.get("prox_logp", None), ) # Log training statistics @@ -106,13 +107,20 @@ def _ppo_actor_loss_from_model_outputs( dual_clip_ratio=ppo_stat["dual_clip_mask"].float(), denominator="n_valid_tokens", ) + if "behave_imp_weight" in ppo_stat: + stats_tracker.denominator(unclipped_behave_tokens=ppo_stat["behave_mask"]) + stats_tracker.stat( + behave_imp_weight=ppo_stat["behave_imp_weight"], + behave_approx_kl=ppo_stat["behave_approx_kl"], + denominator="unclipped_behave_tokens", + ) vocab_min_logits = logits.detach().min(-1).values.float() vocab_max_logits = logits.detach().max(-1).values.float() dist.all_reduce( - vocab_min_logits, group=constants.model_parallel_group(), op=dist.ReduceOp.MIN + vocab_min_logits, group=constants.tensor_parallel_group(), op=dist.ReduceOp.MIN ) dist.all_reduce( - vocab_max_logits, group=constants.model_parallel_group(), op=dist.ReduceOp.MAX + vocab_max_logits, group=constants.tensor_parallel_group(), op=dist.ReduceOp.MAX ) stats_tracker.stat( vocab_min_logits=vocab_min_logits, @@ -505,7 +513,7 @@ class PPOActorInterface(model_api.ModelInterface): model: model_api.Model, input_: SequenceSample, mb_spec: MicroBatchSpec, - ) -> Dict: + ) -> Dict | List[Dict]: module = model.module # We call module.eval() because dropout causes the computation of incorrect of log probs. module.eval() @@ -656,15 +664,20 @@ class PPOActorInterface(model_api.ModelInterface): advantages = torch.cat(adv_list, 0) # Prepare data to be splitted into mini-batches. + flat_data = dict( + advantages=advantages, + old_logp=old_logp, + ppo_loss_mask=loss_mask, + packed_input_ids=input_.data["packed_input_ids"], + kl_rewards=kl_rewards, + ) + use_prox_logp = "proximal_logprobs" in input_.data + if use_prox_logp: + flat_data["prox_logp"] = input_.data["proximal_logprobs"].float() + flat_input = SequenceSample.from_default( ids=list(range(input_.bs * self.group_size)), - data=dict( - advantages=advantages, - old_logp=old_logp, - ppo_loss_mask=loss_mask, - packed_input_ids=input_.data["packed_input_ids"], - kl_rewards=kl_rewards, - ), + data=flat_data, seqlens=[int(x) for x in input_lens.cpu().numpy().tolist()], ) @@ -672,6 +685,7 @@ class PPOActorInterface(model_api.ModelInterface): dense_reward_score = dense_reward_score[shift_one_indices] ### Logging code starts. 
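Because the output head is column-parallel over the vocabulary, each tensor-parallel rank only sees a slice of the logits; the ReduceOp.MIN / ReduceOp.MAX all-reduces above recover global per-token vocab statistics for logging. A single-process illustration of that equivalence (a sketch, not the repo's code path):

# Single-process check: folding per-shard min/max with elementwise
# minimum/maximum equals the min/max of the full vocab dimension.
import torch

torch.manual_seed(0)
full_logits = torch.randn(4, 32)              # [tokens, vocab]
shards = torch.chunk(full_logits, 4, dim=-1)  # what each of 4 TP ranks holds

global_min = torch.stack([s.min(-1).values for s in shards]).min(0).values  # all_reduce(MIN)
global_max = torch.stack([s.max(-1).values for s in shards]).max(0).values  # all_reduce(MAX)

assert torch.equal(global_min, full_logits.min(-1).values)
assert torch.equal(global_max, full_logits.max(-1).values)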
### + all_stats = [] with stats_tracker.scope("ppo_actor"): assert ( task_ids.shape == reward_score.shape @@ -682,12 +696,13 @@ class PPOActorInterface(model_api.ModelInterface): for idx, task in enumerate(RL_TASKS) } - stats_tracker.denominator( + global_denominators = dict( n_seqs=torch.ones_like(reward_score, dtype=torch.bool), n_tokens=torch.ones_like(prompt_mask, dtype=torch.bool), n_valid_tokens=loss_mask.bool(), **task_denominators, ) + stats_tracker.denominator(**global_denominators) for task in RL_TASKS: stats_tracker.stat( @@ -721,6 +736,22 @@ class PPOActorInterface(model_api.ModelInterface): **seq_stats, denominator="n_seqs", ) + scalars = dict( + disable_value=self.disable_value, + mask_no_eos_with_zero=self.mask_no_eos_with_zero, + eps_clip=self.eps_clip, + use_prox_logp=use_prox_logp, + ) + if self.c_clip is not None: + scalars["c_clip"] = self.c_clip + scalars["use_dual_clip"] = 1 + else: + scalars["use_dual_clip"] = 0 + stats_tracker.scalar(**scalars) + + global_stats = stats_tracker.export() + for k in global_denominators: + global_stats.pop(f"ppo_actor/{k}") # Run mini-batched PPO training! def _loss_fn(logits, input_): @@ -736,43 +767,37 @@ class PPOActorInterface(model_api.ModelInterface): ) for reuse in range(self.sample_reuse): - with stats_tracker.scope(f"reuse{reuse}"): - # NOTE: We split PPO minibatches in terms of #seqs instead of #tokens. - flat_input = SequenceSample.shuffled(flat_input) - bs = flat_input.bs - sizes = [0 for _ in range(self.n_minibatches)] - for idx in range(bs): - sizes[idx % self.n_minibatches] += 1 - spec = SequenceSplitSpec(sizes=sizes) - datas = flat_input.split_with_spec(spec) - logger.info( - f"PPO minibatch split (size {self.n_minibatches}): " - f"#seqs: {[s.bs for s in datas]}, " - f"#tokens: {[sum([sum(lens) for lens in s.seqlens[s._get_split_key()]]) for s in datas]}" + # NOTE: We split PPO minibatches in terms of #seqs instead of #tokens. + flat_input = SequenceSample.shuffled(flat_input) + bs = flat_input.bs + sizes = [0 for _ in range(self.n_minibatches)] + for idx in range(bs): + sizes[idx % self.n_minibatches] += 1 + spec = SequenceSplitSpec(sizes=sizes) + datas = flat_input.split_with_spec(spec) + logger.info( + f"PPO minibatch split (size {self.n_minibatches}): " + f"#seqs: {[s.bs for s in datas]}, " + f"#tokens: {[sum([sum(lens) for lens in s.seqlens[s._get_split_key()]]) for s in datas]}" + ) + for mb_i, data in enumerate(datas): + train_stat = module.train_batch( + input_=data, + mb_spec=mb_spec, + version_steps=model.version.global_step, + loss_fn=_loss_fn, + loss_weight_fn=lambda x: x.data[ + "ppo_loss_mask" + ].count_nonzero(), + token_normalize_scope=self.token_normalize_scope, ) - for mb_i, data in enumerate(datas): - with stats_tracker.scope(f"mb{mb_i}"): - train_stat = module.train_batch( - input_=data, - mb_spec=mb_spec, - version_steps=model.version.global_step, - loss_fn=_loss_fn, - loss_weight_fn=lambda x: x.data[ - "ppo_loss_mask" - ].count_nonzero(), - token_normalize_scope=self.token_normalize_scope, - ) - stats_tracker.scalar(**train_stat) + stats_tracker.scalar(**train_stat) + all_stats.append(stats_tracker.export()) - stats_tracker.scalar( - disable_value=self.disable_value, - mask_no_eos_with_zero=self.mask_no_eos_with_zero, - c_clip=self.c_clip if self.c_clip is not None else float("nan"), - eps_clip=self.eps_clip, - ) model.inc_version() + all_stats[0].update(global_stats) - return stats_tracker.export() + return all_stats # Mock methods for profiling only. 
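The train_step above now returns a list of stats dicts, one per PPO minibatch, with the call-level scalars folded into the first entry. A hedged sketch of that return convention using a stand-in tracker (the real stats_tracker API differs):

# Sketch of the new List[Dict] return shape: one dict per minibatch,
# global scalars merged into the first entry. Names are illustrative.
from typing import Dict, List

def collect_minibatch_stats(minibatch_stats: List[Dict[str, float]],
                            global_stats: Dict[str, float]) -> List[Dict[str, float]]:
    all_stats = [dict(s) for s in minibatch_stats]
    all_stats[0].update(global_stats)  # global scalars ride along with the first minibatch
    return all_stats

stats = collect_minibatch_stats(
    [{"ppo_actor/loss": 0.7}, {"ppo_actor/loss": 0.6}],
    {"ppo_actor/eps_clip": 0.2, "ppo_actor/use_prox_logp": 1.0},
)
print(len(stats), sorted(stats[0]))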
def _mock_inference( @@ -1033,7 +1058,7 @@ class PPOCriticInterface(model_api.ModelInterface): model: model_api.Model, input_: SequenceSample, mb_spec: MicroBatchSpec, - ) -> Dict: + ) -> Dict | List[Dict]: assert model.module.module.config.is_critic if self.disable_value: diff --git a/realhf/impl/model/interface/sft_interface.py b/realhf/impl/model/interface/sft_interface.py index 2eccfea..7536695 100644 --- a/realhf/impl/model/interface/sft_interface.py +++ b/realhf/impl/model/interface/sft_interface.py @@ -3,7 +3,7 @@ # Licensed under the Apache License, Version 2.0 (the "License"). import dataclasses -from typing import Dict, Literal +from typing import Dict, List, Literal import torch import torch.distributed as dist @@ -68,10 +68,10 @@ def compute_packed_sft_loss( vocab_min_logits = logits.detach().min(-1).values.float() vocab_max_logits = logits.detach().max(-1).values.float() dist.all_reduce( - vocab_min_logits, group=constants.model_parallel_group(), op=dist.ReduceOp.MIN + vocab_min_logits, group=constants.tensor_parallel_group(), op=dist.ReduceOp.MIN ) dist.all_reduce( - vocab_max_logits, group=constants.model_parallel_group(), op=dist.ReduceOp.MAX + vocab_max_logits, group=constants.tensor_parallel_group(), op=dist.ReduceOp.MAX ) stats_tracker.stat( vocab_min_logits=vocab_min_logits, @@ -88,7 +88,7 @@ class SFTInterface(model_api.ModelInterface): def train_step( self, model: model_api.Model, data: SequenceSample, mb_spec: MicroBatchSpec - ) -> Dict: + ) -> Dict | List[Dict]: module = model.module module.train() diff --git a/realhf/impl/model/modules/attn.py b/realhf/impl/model/modules/attn.py index c633fd0..8f428db 100644 --- a/realhf/impl/model/modules/attn.py +++ b/realhf/impl/model/modules/attn.py @@ -10,14 +10,14 @@ import torch.utils.checkpoint import realhf.base.constants as constants import realhf.base.logging as logging -from realhf.impl.model.parallelism.model_parallel.modules import RowParallelLinear +from realhf.impl.model.parallelism.tensor_parallel.modules import RowParallelLinear from realhf.impl.model.utils.functional import ( apply_rotary_varlen, compute_varlen_position_indices, torch_attn_func, ) -from .mlp import LayerNormQKVLinear +from .mlp import GemmaRMSNorm, LayerNormQKVLinear, LlamaRMSNorm from .rotary import RotaryEmbedding try: @@ -53,6 +53,8 @@ class CausalSelfAttentionLayer(nn.Module): layer_norm_type: Optional[str] = None, # opt applies layer norm after attn do_layernorm_before: bool = True, + # qk layer norm (Qwen3) + qk_layernorm: bool = False, # rotary embedding apply_rotary: bool = False, rotary_base: float = 10000.0, @@ -67,7 +69,7 @@ class CausalSelfAttentionLayer(nn.Module): super().__init__() if dtype is None: dtype = torch.float16 - assert hidden_dim % head_dim == 0 + assert hidden_dim % head_dim == 0, (hidden_dim, head_dim) self.c_attn = LayerNormQKVLinear( input_dim=hidden_dim, head_dim=head_dim, @@ -82,7 +84,7 @@ class CausalSelfAttentionLayer(nn.Module): layer_index=layer_index, ) - if constants.model_parallel_world_size() > 1: + if constants.tensor_parallel_world_size() > 1: self.c_proj = RowParallelLinear( n_q_heads * head_dim, hidden_dim, @@ -100,6 +102,21 @@ class CausalSelfAttentionLayer(nn.Module): device=device, ) + self.qk_layernorm = qk_layernorm + if qk_layernorm: + if layer_norm_type is None: + layer_norm_fn = nn.LayerNorm + elif layer_norm_type == "rms": + layer_norm_fn = LlamaRMSNorm + elif layer_norm_type == "gemma": + layer_norm_fn = GemmaRMSNorm + self.q_ln = layer_norm_fn( + head_dim, eps=layer_norm_epsilon, dtype=dtype, 
device=device + ) + self.k_ln = layer_norm_fn( + head_dim, eps=layer_norm_epsilon, dtype=dtype, device=device + ) + self.resid_dropout = nn.Dropout(resid_pdrop) self.attn_pdrop = attn_pdrop @@ -173,6 +190,10 @@ class CausalSelfAttentionLayer(nn.Module): q, k, v = self.c_attn(hidden_states) + if self.qk_layernorm: + q = self.q_ln(q) + k = self.k_ln(k) + if self.apply_rotary and (k_cache is None or str(q.device) == "cpu"): # otherwise, we input rotary cos/sin directly into flash_attn_with_kvcache rotary_cache_len = max_seqlen diff --git a/realhf/impl/model/modules/embedding.py b/realhf/impl/model/modules/embedding.py index 5a64274..50da1a6 100644 --- a/realhf/impl/model/modules/embedding.py +++ b/realhf/impl/model/modules/embedding.py @@ -8,7 +8,7 @@ import torch import torch.nn as nn from torch.nn import init -from realhf.impl.model.parallelism.model_parallel.modules import ParallelEmbedding +from realhf.impl.model.parallelism.tensor_parallel.modules import ParallelEmbedding class OffsetPositionalEmbedding(nn.Embedding): diff --git a/realhf/impl/model/modules/mlp.py b/realhf/impl/model/modules/mlp.py index 7daf4c9..9f2e478 100644 --- a/realhf/impl/model/modules/mlp.py +++ b/realhf/impl/model/modules/mlp.py @@ -15,7 +15,7 @@ from transformers.activations import ACT2FN import realhf.base.constants as constants import realhf.base.logging as logging -from realhf.impl.model.parallelism.model_parallel.modules import ( +from realhf.impl.model.parallelism.tensor_parallel.modules import ( ColumnParallelLinear, RowParallelLinear, merged_linear_with_grad_accumulation_and_async_allreduce, @@ -49,10 +49,10 @@ class LayerNormQKVLinear(nn.Module): layer_index=None, ): super().__init__() - model_parallel = constants.model_parallel_world_size() > 1 + tensor_parallel = constants.tensor_parallel_world_size() > 1 sequence_parallel = constants.sequence_parallel() gradient_accumulation_fusion = constants.gradient_accumulation_fusion() - if not model_parallel and (sequence_parallel or gradient_accumulation_fusion): + if not tensor_parallel and (sequence_parallel or gradient_accumulation_fusion): global SEQUENCE_PARALLEL_WARNED if not SEQUENCE_PARALLEL_WARNED: logger.warning( @@ -73,16 +73,16 @@ class LayerNormQKVLinear(nn.Module): input_dim, eps=layer_norm_epsilon, dtype=dtype, device=device ) - self.model_parallel = model_parallel + self.tensor_parallel = tensor_parallel self.layer_index = layer_index - self.mp_worldsize = constants.model_parallel_world_size() - assert n_q_heads % self.mp_worldsize == 0, ( + self.tp_worldsize = constants.tensor_parallel_world_size() + assert n_q_heads % self.tp_worldsize == 0, ( f"n_q_heads {n_q_heads} must be divisible by " - f"mp_worldsize {self.mp_worldsize}" + f"tp_worldsize {self.tp_worldsize}" ) - assert n_kv_heads % self.mp_worldsize == 0, ( + assert n_kv_heads % self.tp_worldsize == 0, ( f"n_kv_heads {n_kv_heads} must be divisible by " - f"mp_worldsize {self.mp_worldsize}" + f"tp_worldsize {self.tp_worldsize}" ) hidden_dim = input_dim # TODO: we can fuse the forward of qkv attention @@ -141,9 +141,9 @@ class LayerNormQKVLinear(nn.Module): self.v_attn.weight, self.v_attn.bias, ) - q = q.view(*q.shape[:-1], self.nq // self.mp_worldsize, self.d) - k = k.view(*k.shape[:-1], self.nkv // self.mp_worldsize, self.d) - v = v.view(*v.shape[:-1], self.nkv // self.mp_worldsize, self.d) + q = q.view(*q.shape[:-1], self.nq // self.tp_worldsize, self.d) + k = k.view(*k.shape[:-1], self.nkv // self.tp_worldsize, self.d) + v = v.view(*v.shape[:-1], self.nkv // self.tp_worldsize, self.d) 
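The new qk_layernorm path (Qwen3-style) normalizes each attention head's query and key vectors over head_dim right after the QKV projection, so on a tensor-parallel rank the tensors have shape [..., n_heads_local, head_dim] and the norm weight is only head_dim wide. A minimal sketch with an RMSNorm stand-in (not the repo's LlamaRMSNorm):

# Minimal sketch of per-head QK-LayerNorm applied after the QKV projection.
import torch
import torch.nn as nn

class RMSNorm(nn.Module):
    def __init__(self, dim: int, eps: float = 1e-6):
        super().__init__()
        self.weight = nn.Parameter(torch.ones(dim))
        self.eps = eps
    def forward(self, x):
        var = x.float().pow(2).mean(-1, keepdim=True)
        return (x.float() * torch.rsqrt(var + self.eps)).type_as(x) * self.weight

total_tokens, n_q_heads_local, n_kv_heads_local, head_dim = 8, 4, 2, 16  # heads already divided by tp_size
q = torch.randn(total_tokens, n_q_heads_local, head_dim)
k = torch.randn(total_tokens, n_kv_heads_local, head_dim)

q_ln, k_ln = RMSNorm(head_dim), RMSNorm(head_dim)
q, k = q_ln(q), k_ln(k)   # normalization is per head, over head_dim only
print(q.shape, k.shape)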
return q, k, v @@ -163,10 +163,10 @@ class LayerNormMLP(nn.Module): device: Optional[Union[str, torch.device]] = None, ): super().__init__() - model_parallel = constants.model_parallel_world_size() > 1 + tensor_parallel = constants.tensor_parallel_world_size() > 1 sequence_parallel = constants.sequence_parallel() gradient_accumulation_fusion = constants.gradient_accumulation_fusion() - if not model_parallel and (sequence_parallel or gradient_accumulation_fusion): + if not tensor_parallel and (sequence_parallel or gradient_accumulation_fusion): global SEQUENCE_PARALLEL_WARNED if not SEQUENCE_PARALLEL_WARNED: logger.warning( @@ -228,12 +228,12 @@ class LlamaLayerNormMLP(nn.Module): device: Optional[Union[str, torch.device]] = None, ): super().__init__() - self.model_parallel = constants.model_parallel_world_size() > 1 + self.tensor_parallel = constants.tensor_parallel_world_size() > 1 gradient_accumulation_fusion = constants.gradient_accumulation_fusion() self.is_expert = is_expert # when used as experts the MLP always compute without sequence parallel sequence_parallel = constants.sequence_parallel() and not is_expert - if not self.model_parallel and ( + if not self.tensor_parallel and ( sequence_parallel or gradient_accumulation_fusion ): global SEQUENCE_PARALLEL_WARNED @@ -418,13 +418,13 @@ if constants.use_te_impl(): eps=layer_norm_epsilon, sequence_parallel=constants.sequence_parallel(), return_bias=False, - tp_group=constants.model_parallel_group(), - tp_size=constants.model_parallel_world_size(), + tp_group=constants.tensor_parallel_group(), + tp_size=constants.tensor_parallel_world_size(), bias=False, normalization="RMSNorm", activation="swiglu", fuse_wgrad_accumulation=constants.gradient_accumulation_fusion(), params_dtype=dtype, - set_parallel_mode=constants.model_parallel_world_size() > 1, + set_parallel_mode=constants.tensor_parallel_world_size() > 1, device=device, ) diff --git a/realhf/impl/model/modules/moe/experts.py b/realhf/impl/model/modules/moe/experts.py index 6b3e612..10b356b 100644 --- a/realhf/impl/model/modules/moe/experts.py +++ b/realhf/impl/model/modules/moe/experts.py @@ -10,11 +10,11 @@ from torch.nn.parameter import Parameter import realhf.base.constants as constants from realhf.api.core.model_api import ReaLModelConfig from realhf.impl.model.modules.mlp import LlamaLayerNormMLP, get_activation_fn -from realhf.impl.model.parallelism.model_parallel.mappings import ( +from realhf.impl.model.parallelism.tensor_parallel.mappings import ( copy_to_tensor_model_parallel_region, reduce_from_tensor_model_parallel_region, ) -from realhf.impl.model.parallelism.model_parallel.utils import divide +from realhf.impl.model.parallelism.tensor_parallel.utils import divide from realhf.impl.model.utils.random import _initialize_affine_weight_gpu try: @@ -125,7 +125,7 @@ class GroupedMLP(torch.nn.Module): self.activation_func = get_activation_fn(self.config.activation_function) # How many feature each rank holds for fc1 and fc2, respectively. 
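In these MLP blocks the gate/up projections are column-parallel (the intermediate dimension is split across tensor-parallel ranks) and the down projection is row-parallel (its input dimension is split), so each rank holds intermediate_dim // tp_size of the hidden features. A sketch of the resulting per-rank weight shapes; the key names here are illustrative, not the repo's state-dict keys:

# Per-rank weight shapes for a column-parallel gate/up + row-parallel down MLP.
def mlp_shard_shapes(hidden_dim: int, intermediate_dim: int, tp_size: int):
    assert intermediate_dim % tp_size == 0
    inter_local = intermediate_dim // tp_size
    return {
        "gate_proj.weight": (inter_local, hidden_dim),   # ColumnParallelLinear
        "up_proj.weight": (inter_local, hidden_dim),     # ColumnParallelLinear
        "down_proj.weight": (hidden_dim, inter_local),   # RowParallelLinear
    }

print(mlp_shard_shapes(hidden_dim=4096, intermediate_dim=11008, tp_size=4))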
- tp_size = constants.model_parallel_world_size() + tp_size = constants.tensor_parallel_world_size() intermediate_dim_per_partition = divide(self.config.intermediate_dim, tp_size) # Note: The current kernel implementations of grouped_gemm @@ -186,7 +186,7 @@ class GroupedMLP(torch.nn.Module): ): tokens_per_expert = tokens_per_expert.cpu() if permuted_local_hidden_states.nelement() != 0: - if constants.model_parallel_world_size() > 1: + if constants.tensor_parallel_world_size() > 1: permuted_local_hidden_states = copy_to_tensor_model_parallel_region( permuted_local_hidden_states ) @@ -208,7 +208,7 @@ class GroupedMLP(torch.nn.Module): output = grouped_gemm.ops.gmm( inter, self.grouped_down_proj, tokens_per_expert, trans_b=False ) - if constants.model_parallel_world_size() > 1: + if constants.tensor_parallel_world_size() > 1: output = reduce_from_tensor_model_parallel_region(output) else: # No token is allocated for local experts. diff --git a/realhf/impl/model/modules/moe/router.py b/realhf/impl/model/modules/moe/router.py index 872cbfe..2e28acf 100644 --- a/realhf/impl/model/modules/moe/router.py +++ b/realhf/impl/model/modules/moe/router.py @@ -8,7 +8,7 @@ import torch.nn.init as init import realhf.base.constants as constants from realhf.api.core.model_api import ReaLModelConfig -from realhf.impl.model.parallelism.model_parallel.mappings import ( +from realhf.impl.model.parallelism.tensor_parallel.mappings import ( gather_from_sequence_parallel_region, ) from realhf.impl.model.utils.moe import ( @@ -117,10 +117,10 @@ class TopKRouter(torch.nn.Module): torch.Tensor: The activation tensor with the attached gradient function. """ moe_aux_loss_coeff = self.config.moe.aux_loss_coeff - moe_aux_loss_coeff /= constants.model_parallel_world_size() + moe_aux_loss_coeff /= constants.tensor_parallel_world_size() scale_for_logging = 1.0 if constants.sequence_parallel(): - scale_for_logging *= constants.model_parallel_world_size() + scale_for_logging *= constants.tensor_parallel_world_size() aux_loss = switch_load_balancing_loss_func( probs, @@ -128,7 +128,7 @@ class TopKRouter(torch.nn.Module): self.config.moe.top_k, moe_aux_loss_coeff, sequence_partition_group=( - constants.model_parallel_group() + constants.tensor_parallel_group() if constants.sequence_parallel() else None ), @@ -155,7 +155,7 @@ class TopKRouter(torch.nn.Module): """ if self.config.moe.z_loss_coeff > 0: moe_z_loss_coeff = ( - self.config.moe.z_loss_coeff / constants.model_parallel_world_size() + self.config.moe.z_loss_coeff / constants.tensor_parallel_world_size() ) z_loss = z_loss_func(logits, moe_z_loss_coeff) logits = MoEAuxLossAutoScaler.apply(logits, z_loss) diff --git a/realhf/impl/model/modules/moe/token_dispatcher.py b/realhf/impl/model/modules/moe/token_dispatcher.py index e106aca..d9bac41 100644 --- a/realhf/impl/model/modules/moe/token_dispatcher.py +++ b/realhf/impl/model/modules/moe/token_dispatcher.py @@ -7,7 +7,7 @@ import torch import realhf.base.constants as constants from realhf.api.core.model_api import ReaLModelConfig -from realhf.impl.model.parallelism.model_parallel.mappings import ( +from realhf.impl.model.parallelism.tensor_parallel.mappings import ( gather_from_sequence_parallel_region, scatter_to_sequence_parallel_region, ) diff --git a/realhf/impl/model/nn/flatten_param.py b/realhf/impl/model/nn/flatten_param.py index 0810a8c..236dad2 100644 --- a/realhf/impl/model/nn/flatten_param.py +++ b/realhf/impl/model/nn/flatten_param.py @@ -23,8 +23,8 @@ from .real_llm_base import ReaLModelParamKeys from 
.real_llm_parallel import ( get_real_model_param_shape, intervals_partition_fn, - mp_partition_key, shape_partition_fn, + tp_partition_key, ) try: @@ -188,7 +188,7 @@ def set_intervals( def param_size_from_keys( config: model_api.ReaLModelConfig, - src_mp_size: int, + src_tp_size: int, sd_keys: List[str], src2dst_tp_size: int, src2dst_tp_rank: int, @@ -202,9 +202,9 @@ def param_size_from_keys( and "0.wte.weight" in sd_keys ): continue - new_shape = mp_partition_key( + new_shape = tp_partition_key( k, - get_real_model_param_shape(k, config, src_mp_size), + get_real_model_param_shape(k, config, src_tp_size), src2dst_tp_rank, src2dst_tp_size, config, @@ -218,7 +218,7 @@ def build_param_spec( layer_indices: List[int], config: model_api.ReaLModelConfig, dp_size: int, - mp_size: int, + tp_size: int, pp_size: int, head_param_point_to_embedding: bool, bucket_size: int = 40000000, @@ -273,7 +273,7 @@ def build_param_spec( if head_param_point_to_embedding and k == f"{config.n_layers + 1}.weight": continue - shape = get_real_model_param_shape(k, config, mp_size) + shape = get_real_model_param_shape(k, config, tp_size) numel = int(np.prod(shape)) data_end_index = data_start_index + numel @@ -307,14 +307,14 @@ def param_intervals_from_keys( config: model_api.ReaLModelConfig, head_param_point_to_embedding: bool, param_spec: Dict[str, ContiguousParamSpec], - mp_size: int, + tp_size: int, sd_keys: List[str], portion_size: int, portion_rank: int, ) -> List[int]: param_size = param_size_from_keys( config=config, - src_mp_size=mp_size, + src_tp_size=tp_size, sd_keys=sd_keys, src2dst_tp_size=portion_size, src2dst_tp_rank=portion_rank, @@ -333,13 +333,13 @@ def param_intervals_from_keys( if ( model_name, k.split(".", 1)[1], - mp_size, + tp_size, portion_rank, portion_size, ) not in _FLAT_PARAM_INDICES_CACHE: - zero_start_intervals = mp_partition_key( + zero_start_intervals = tp_partition_key( k, - get_real_model_param_shape(k, config, mp_size), + get_real_model_param_shape(k, config, tp_size), portion_rank, portion_size, config, @@ -349,7 +349,7 @@ def param_intervals_from_keys( ( model_name, k.split(".", 1)[1], - mp_size, + tp_size, portion_rank, portion_size, ) @@ -359,7 +359,7 @@ def param_intervals_from_keys( ( model_name, k.split(".", 1)[1], - mp_size, + tp_size, portion_rank, portion_size, ) diff --git a/realhf/impl/model/nn/real_llm_api.py b/realhf/impl/model/nn/real_llm_api.py index 905c80c..bfc70ef 100644 --- a/realhf/impl/model/nn/real_llm_api.py +++ b/realhf/impl/model/nn/real_llm_api.py @@ -167,7 +167,7 @@ class ReaLModel(nn.Module): self._param_spec, self._param_size = build_param_spec( list(range(self.layer_idx_start, self.layer_idx_end)), self.config, - mp_size=constants.model_parallel_world_size(), + tp_size=constants.tensor_parallel_world_size(), pp_size=constants.pipe_parallel_world_size(), dp_size=constants.data_parallel_world_size(), head_param_point_to_embedding=self.head_param_point_to_embedding, @@ -282,7 +282,7 @@ class ReaLModel(nn.Module): device=device, dtype=dtype, ) - elif not config.is_critic and constants.model_parallel_world_size() > 1: + elif not config.is_critic and constants.tensor_parallel_world_size() > 1: l = ParallelActorHead( config.hidden_dim, config.vocab_size, @@ -428,14 +428,14 @@ class ReaLModel(nn.Module): x.cu_seqlens = x.cu_seqlens.int() # Copy input tensor to a pinned buffer. 
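build_param_spec lays a rank's parameters out in one contiguous flat buffer, giving each key a [start, end) interval whose size is the product of its tensor-parallel (per-rank) shape; the cache above memoizes intervals per (key, tp_size, portion) combination. A simplified sketch of the offset assignment, with invented shapes:

# Simplified sketch of assigning contiguous flat-buffer offsets to parameters
# from their per-rank shapes. Shapes below are illustrative only.
import numpy as np

def build_flat_spec(shapes: dict):
    spec, start = {}, 0
    for key, shape in shapes.items():
        numel = int(np.prod(shape))
        spec[key] = (start, start + numel, shape)
        start += numel
    return spec, start  # per-key intervals and total flat size

spec, total = build_flat_spec({
    "0.wte.weight": (32000 // 4, 4096),         # vocab sharded over tp_size=4
    "1.attn.c_proj.weight": (4096, 4096 // 4),  # row-parallel attention projection
    "1.mlp.ln.weight": (4096,),                 # replicated norm weight
})
print(total, spec["1.attn.c_proj.weight"][:2])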
- mp_size = constants.model_parallel_world_size() + tp_size = constants.tensor_parallel_world_size() batch_length = None if ys[0].packed_input_ids is not None: batch_length = ys[0].packed_input_ids.shape[0] if x.pp_input is not None: batch_length = x.pp_input.shape[0] assert batch_length is not None - padded_batch_length = (batch_length + mp_size - 1) // mp_size * mp_size + padded_batch_length = (batch_length + tp_size - 1) // tp_size * tp_size pad_size = padded_batch_length - batch_length if ( @@ -609,7 +609,7 @@ class ReaLModel(nn.Module): to_param_spec, to_param_size = build_param_spec( to_layer_indices, to_model_config, - mp_size=to_topo.get_dim("model"), + tp_size=to_topo.get_dim("tensor"), dp_size=to_topo.get_dim("data"), pp_size=to_topo.get_dim("pipe"), head_param_point_to_embedding=to_model_head_param_point_to_embedding, diff --git a/realhf/impl/model/nn/real_llm_base.py b/realhf/impl/model/nn/real_llm_base.py index 96837e4..231bcfd 100644 --- a/realhf/impl/model/nn/real_llm_base.py +++ b/realhf/impl/model/nn/real_llm_base.py @@ -17,7 +17,7 @@ import transformers import realhf.base.constants as constants import realhf.base.logging as logging -import realhf.impl.model.parallelism.model_parallel.mappings as tensor_parallel +import realhf.impl.model.parallelism.tensor_parallel.mappings as tensor_parallel from realhf.api.core import model_api from realhf.impl.model.modules import ( CausalSelfAttentionLayer, @@ -28,9 +28,8 @@ from realhf.impl.model.modules import ( LlamaRMSNorm, OffsetParallelPositionalEmbedding, OffsetPositionalEmbedding, - scatter_to_sequence_parallel_region, ) -from realhf.impl.model.parallelism.model_parallel.modules import ( +from realhf.impl.model.parallelism.tensor_parallel.modules import ( ColumnParallelLinear, ParallelEmbedding, gather_from_sequence_parallel_region, @@ -139,6 +138,7 @@ class ReaLModelBlock(nn.Module): use_attention_bias=config.use_attention_bias, use_attn_proj_bias=config.use_attn_proj_bias, do_layernorm_before=config.do_layernorm_before, + qk_layernorm=config.qk_layernorm, apply_rotary=config.apply_rotary, rotary_base=config.rotary_base, rotary_interleaved=config.rotary_interleaved, @@ -281,8 +281,8 @@ class VocabPositionEmbedding(nn.Module): self.n_positions = config.n_positions self.hidden_dim = config.hidden_dim - model_parallel = constants.model_parallel_world_size() > 1 - if model_parallel: + tensor_parallel = constants.tensor_parallel_world_size() > 1 + if tensor_parallel: embed_cls = ParallelEmbedding else: embed_cls = nn.Embedding @@ -295,7 +295,7 @@ class VocabPositionEmbedding(nn.Module): if self.apply_abs_pos_embed: p_embed_cls = ( OffsetParallelPositionalEmbedding - if model_parallel + if tensor_parallel else OffsetPositionalEmbedding ) self.wpe = p_embed_cls( @@ -416,7 +416,7 @@ class ParallelActorHead(ColumnParallelLinear): def _forward(self, x: torch.Tensor): weight = self.weight if self._norm_head: - from realhf.impl.model.parallelism.model_parallel.mappings import ( + from realhf.impl.model.parallelism.tensor_parallel.mappings import ( gather_from_sequence_parallel_region, ) @@ -431,7 +431,7 @@ class ParallelActorHead(ColumnParallelLinear): ).transpose(1, 0) head_norm = unnormed_head.norm(dim=0, keepdim=True, p=2) normed_head = unnormed_head / (head_norm + 1e-7) - weight = scatter_to_sequence_parallel_region(normed_head) + weight = tensor_parallel.scatter_to_sequence_parallel_region(normed_head) output = parallel_lm_logits( x, @@ -486,6 +486,12 @@ class ReaLModelParamKeys: keys += [f"{idx + 1}.attn.c_proj.weight"] if 
config.use_attn_proj_bias: keys += [f"{idx + 1}.attn.c_proj.bias"] + if config.qk_layernorm: + keys += [f"{idx + 1}.attn.q_ln.weight"] + keys += [f"{idx + 1}.attn.k_ln.weight"] + if config.layer_norm_type is None: + keys += [f"{idx + 1}.attn.q_ln.bias"] + keys += [f"{idx + 1}.attn.k_ln.bias"] keys += [f"{idx + 1}.mlp.ln.weight"] if config.layer_norm_type is None: keys += [f"{idx + 1}.mlp.ln.bias"] diff --git a/realhf/impl/model/nn/real_llm_generate.py b/realhf/impl/model/nn/real_llm_generate.py index b3e8c12..5dae9d9 100644 --- a/realhf/impl/model/nn/real_llm_generate.py +++ b/realhf/impl/model/nn/real_llm_generate.py @@ -55,8 +55,8 @@ def genstep( unfinished_sequences: Bool tensor indicator of whether a sequence is finished. Shape [bs]. """ - if constants.model_parallel_world_size() > 1: - from realhf.impl.model.parallelism.model_parallel.mappings import ( + if constants.tensor_parallel_world_size() > 1: + from realhf.impl.model.parallelism.tensor_parallel.mappings import ( gather_from_tensor_model_parallel_region, ) @@ -95,20 +95,20 @@ def genstep( next_tokens = distrb.mode if gconfig.greedy else distrb.sample() logprob = distrb.log_prob(next_tokens) - if constants.model_parallel_world_size() > 1: - if constants.model_parallel_rank() > 0: + if constants.tensor_parallel_world_size() > 1: + if constants.tensor_parallel_rank() > 0: logprob[:] = 0 next_tokens[:] = 0 handle = torch.distributed.all_reduce( logprob, torch.distributed.ReduceOp.SUM, async_op=True, - group=constants.model_parallel_group(), + group=constants.tensor_parallel_group(), ) torch.distributed.all_reduce( next_tokens, torch.distributed.ReduceOp.SUM, - group=constants.model_parallel_group(), + group=constants.tensor_parallel_group(), ) if tokenizer.eos_token_id is not None: @@ -139,7 +139,7 @@ def genstep( if not logits_mask.any(): logits_mask = None - if constants.model_parallel_world_size() > 1: + if constants.tensor_parallel_world_size() > 1: handle.wait() return next_tokens, logprob, logits_mask, terminate, unfinished_sequences diff --git a/realhf/impl/model/nn/real_llm_parallel.py b/realhf/impl/model/nn/real_llm_parallel.py index 4eb876d..6dbd8a1 100644 --- a/realhf/impl/model/nn/real_llm_parallel.py +++ b/realhf/impl/model/nn/real_llm_parallel.py @@ -42,27 +42,27 @@ if constants.use_te_impl(): def tensor_slice_partition_fn( tensor: torch.Tensor, - mp_rank: Optional[int], - mp_world_size: int, + tp_rank: Optional[int], + tp_world_size: int, dim: Optional[int], ) -> Union[List[torch.Tensor], torch.Tensor]: """Partition a tensor by slicing along a dimension for tensor-model parallelism.""" if dim is None: - splits = [tensor for _ in range(mp_world_size)] + splits = [tensor for _ in range(tp_world_size)] else: - assert tensor.shape[dim] % mp_world_size == 0 - splits = torch.split(tensor, tensor.shape[dim] // mp_world_size, dim=dim) - if mp_rank is None: + assert tensor.shape[dim] % tp_world_size == 0 + splits = torch.split(tensor, tensor.shape[dim] // tp_world_size, dim=dim) + if tp_rank is None: return [s.contiguous() for s in splits] else: - return splits[mp_rank].contiguous() + return splits[tp_rank].contiguous() def intervals_partition_fn( shape: torch.Size, - mp_rank: Optional[int], - mp_world_size: int, + tp_rank: Optional[int], + tp_world_size: int, dim: Optional[int], ) -> Union[List[torch.Tensor], torch.Tensor]: """Get the intervals of a MP-partitioned tensor in the flatten view. @@ -74,34 +74,34 @@ def intervals_partition_fn( Used by parameter reallocation. 
Return a numpy array of shape [N, 2], where N is the number of intervals. """ - assert mp_rank is not None + assert tp_rank is not None param_size = int(np.prod(shape)) if dim is None: return np.array([(0, param_size)], dtype=np.int64) if dim < 0: dim = len(shape) + dim - assert shape[dim] % mp_world_size == 0 + assert shape[dim] % tp_world_size == 0 if len(shape) == 1: assert dim == 0 - partition_size = shape[0] // mp_world_size + partition_size = shape[0] // tp_world_size return np.array( - [(partition_size * mp_rank, partition_size * (mp_rank + 1))], + [(partition_size * tp_rank, partition_size * (tp_rank + 1))], dtype=np.int64, ) else: assert len(shape) == 2, shape if dim == 0: - row_start = mp_rank * shape[0] // mp_world_size - row_end = (mp_rank + 1) * shape[0] // mp_world_size + row_start = tp_rank * shape[0] // tp_world_size + row_end = (tp_rank + 1) * shape[0] // tp_world_size return np.array( [(row_start * shape[1], row_end * shape[1])], dtype=np.int64 ) else: assert dim == 1 - col_start = mp_rank * shape[1] // mp_world_size - col_end = (mp_rank + 1) * shape[1] // mp_world_size + col_start = tp_rank * shape[1] // tp_world_size + col_end = (tp_rank + 1) * shape[1] // tp_world_size return np.arange(shape[0], dtype=np.int64)[:, None] * shape[1] + np.array( [(col_start, col_end)], dtype=np.int64 ) @@ -109,32 +109,32 @@ def intervals_partition_fn( def shape_partition_fn( shape: torch.Size, - mp_rank: Optional[int], - mp_world_size: int, + tp_rank: Optional[int], + tp_world_size: int, dim: Optional[int], ): """Get the partitioned shape of a tensor for tensor-model parallelism.""" if dim is None: - splits = [shape for _ in range(mp_world_size)] + splits = [shape for _ in range(tp_world_size)] else: if dim < 0: dim = len(shape) + dim - assert shape[dim] % mp_world_size == 0 + assert shape[dim] % tp_world_size == 0 splits = [ - (*shape[:dim], shape[dim] // mp_world_size, *shape[dim + 1 :]) - for _ in range(mp_world_size) + (*shape[:dim], shape[dim] // tp_world_size, *shape[dim + 1 :]) + for _ in range(tp_world_size) ] - if mp_rank is None: + if tp_rank is None: return [s for s in splits] else: - return splits[mp_rank] + return splits[tp_rank] -def mp_partition_key( +def tp_partition_key( key: str, tensor_or_shape: torch.Tensor | torch.Size, - mp_rank: Optional[int], - mp_size: Optional[int], + tp_rank: Optional[int], + tp_size: Optional[int], config: model_api.ReaLModelConfig, partition_fn: Callable[ [torch.Tensor, Optional[int], int, Optional[int]], @@ -149,7 +149,7 @@ def mp_partition_key( if any([ek in key for ek in EMBEDDING_KEYS]): assert "weight" in key - return partition_fn(tensor_or_shape, mp_rank, mp_size, dim=0) + return partition_fn(tensor_or_shape, tp_rank, tp_size, dim=0) elif key == f"{config.n_layers + 1}.weight": # output head if ( isinstance(tensor_or_shape, torch.Tensor) and tensor_or_shape.shape[0] == 1 @@ -157,88 +157,90 @@ def mp_partition_key( not isinstance(tensor_or_shape, torch.Tensor) and tensor_or_shape[0] == 1 ): assert config.is_critic - return partition_fn(tensor_or_shape, mp_rank, mp_size, dim=None) + return partition_fn(tensor_or_shape, tp_rank, tp_size, dim=None) else: - return partition_fn(tensor_or_shape, mp_rank, mp_size, dim=0) + return partition_fn(tensor_or_shape, tp_rank, tp_size, dim=0) elif any([ck in key for ck in COLUMN_LINEAR_KEYS]): if ( ("k_attn" in key) or ("v_attn" in key) - ) and config.n_kv_heads % mp_size != 0: - return partition_fn(tensor_or_shape, mp_rank, mp_size, dim=None) + ) and config.n_kv_heads % tp_size != 0: + return 
partition_fn(tensor_or_shape, tp_rank, tp_size, dim=None) # partition both weight and bias - return partition_fn(tensor_or_shape, mp_rank, mp_size, dim=0) + return partition_fn(tensor_or_shape, tp_rank, tp_size, dim=0) elif any([rk in key for rk in ROW_LINEAR_KEYS]): # only partition weight if "weight" in key: - return partition_fn(tensor_or_shape, mp_rank, mp_size, dim=1) + return partition_fn(tensor_or_shape, tp_rank, tp_size, dim=1) else: assert "bias" in key, key - return partition_fn(tensor_or_shape, mp_rank, mp_size, dim=None) + return partition_fn(tensor_or_shape, tp_rank, tp_size, dim=None) else: - return partition_fn(tensor_or_shape, mp_rank, mp_size, dim=None) + return partition_fn(tensor_or_shape, tp_rank, tp_size, dim=None) -def mp_partition_real_model_state_dict( +def tp_partition_real_model_state_dict( state_dict: Dict[str, torch.Tensor], config: model_api.ReaLModelConfig, - mp_size: int, - mp_rank: Optional[int] = None, + tp_size: int, + tp_rank: Optional[int] = None, ) -> Union[Dict, List[Dict]]: - """A helper function to partition a state dict using `mp_partition_key`.""" - if mp_size == 1: - if mp_rank is None: + """A helper function to partition a state dict using `tp_partition_key`.""" + if tp_size == 1: + if tp_rank is None: return [state_dict] else: return state_dict new_state_dict = {} for k, v in state_dict.items(): - new_state_dict[k] = mp_partition_key(k, v, mp_rank, mp_size, config) + new_state_dict[k] = tp_partition_key(k, v, tp_rank, tp_size, config) - if mp_rank is None: + if tp_rank is None: return [ - {k: v[mp_rank] for k, v in new_state_dict.items()} - for mp_rank in range(mp_size) + {k: v[tp_rank] for k, v in new_state_dict.items()} + for tp_rank in range(tp_size) ] else: return new_state_dict def get_real_model_param_shape( - k: str, config: model_api.ReaLModelConfig, mp_size: int + k: str, config: model_api.ReaLModelConfig, tp_size: int ) -> Tuple: if "wte.weight" in k: - assert config.vocab_size % mp_size == 0 - return (config.vocab_size // mp_size, config.hidden_dim) + assert config.vocab_size % tp_size == 0 + return (config.vocab_size // tp_size, config.hidden_dim) elif "wpe.weight" in k: - assert config.n_positions % mp_size == 0 - if (config.n_positions + config.abs_position_embedding_offset) % mp_size != 0: + assert config.n_positions % tp_size == 0 + if (config.n_positions + config.abs_position_embedding_offset) % tp_size != 0: raise ValueError( f"The dimenstion of position embedding " f"({config.n_positions} + offset {config.abs_position_embedding_offset}) " - f"is not divisible by mp_size ({mp_size}). " + f"is not divisible by tp_size ({tp_size}). " "Models like this (e.g. OPT-350m) inherently do not support tensor parallelism." ) return ( - (config.n_positions + config.abs_position_embedding_offset) // mp_size, + (config.n_positions + config.abs_position_embedding_offset) // tp_size, config.hidden_dim, ) elif ".ln." in k or ".ln_f." in k: return (config.hidden_dim,) + elif ".q_ln." in k or ".k_ln." 
in k: + return (config.head_dim,) elif k == f"{config.n_layers + 1}.weight": # output head if config.is_critic: return (1, config.hidden_dim) - elif mp_size > 1: - assert config.vocab_size % mp_size == 0 - return (config.vocab_size // mp_size, config.hidden_dim) + elif tp_size > 1: + assert config.vocab_size % tp_size == 0 + return (config.vocab_size // tp_size, config.hidden_dim) else: return (config.vocab_size, config.hidden_dim) elif any([ck in k for ck in COLUMN_LINEAR_KEYS]): if "k_attn" in k or "v_attn" in k: if "weight" in k: - if config.n_kv_heads % mp_size == 0: + if config.n_kv_heads % tp_size == 0: return ( - config.head_dim * config.n_kv_heads // mp_size, + config.head_dim * config.n_kv_heads // tp_size, config.hidden_dim, ) else: @@ -248,27 +250,27 @@ def get_real_model_param_shape( ) else: assert "bias" in k - if config.n_kv_heads % mp_size == 0: - return (config.head_dim * config.n_kv_heads // mp_size,) + if config.n_kv_heads % tp_size == 0: + return (config.head_dim * config.n_kv_heads // tp_size,) else: return (config.head_dim * config.n_kv_heads,) if "mlp" in k: if "weight" in k: - return (config.intermediate_dim // mp_size, config.hidden_dim) + return (config.intermediate_dim // tp_size, config.hidden_dim) else: assert "bias" in k - return (config.intermediate_dim // mp_size,) + return (config.intermediate_dim // tp_size,) if "weight" in k: - assert config.n_q_heads % mp_size == 0 - return (config.n_q_heads * config.head_dim // mp_size, config.hidden_dim) + assert config.n_q_heads % tp_size == 0 + return (config.n_q_heads * config.head_dim // tp_size, config.hidden_dim) else: assert "bias" in k - return (config.n_q_heads * config.head_dim // mp_size,) + return (config.n_q_heads * config.head_dim // tp_size,) elif any([rk in k for rk in ROW_LINEAR_KEYS]): if "mlp" in k and "weight" in k: - return (config.hidden_dim, config.intermediate_dim // mp_size) + return (config.hidden_dim, config.intermediate_dim // tp_size) elif "attn" in k and "weight" in k: - return (config.hidden_dim, config.n_q_heads * config.head_dim // mp_size) + return (config.hidden_dim, config.n_q_heads * config.head_dim // tp_size) elif "bias" in k: return (config.hidden_dim,) else: @@ -280,7 +282,7 @@ def get_real_model_param_shape( raise NotImplementedError(f"unkown shape of key {k}.") -def mp_merge_key( +def tp_merge_key( k: str, tensors: List[torch.Tensor], config: model_api.ReaLModelConfig, @@ -297,17 +299,17 @@ def mp_merge_key( return tensors[0] -def mp_merge_real_model_state_dict( +def tp_merge_real_model_state_dict( state_dicts: List[Dict[str, torch.Tensor]], config: model_api.ReaLModelConfig, ) -> Dict: - mp_size = len(state_dicts) - if mp_size == 1: + tp_size = len(state_dicts) + if tp_size == 1: return state_dicts[0] new_state_dict = {} for k in state_dicts[0].keys(): - new_state_dict[k] = mp_merge_key(k, [sd[k] for sd in state_dicts], config) + new_state_dict[k] = tp_merge_key(k, [sd[k] for sd in state_dicts], config) return new_state_dict @@ -317,37 +319,37 @@ class ReaLModelParamCount: @staticmethod def _derive_count_from_keys( - keys: List[str], config: model_api.ReaLModelConfig, mp_size: int + keys: List[str], config: model_api.ReaLModelConfig, tp_size: int ) -> int: count = 0 for k in keys: - count += np.prod(get_real_model_param_shape(k, config, mp_size)) + count += np.prod(get_real_model_param_shape(k, config, tp_size)) return int(count) @staticmethod - def embed(config: model_api.ReaLModelConfig, mp_size: int) -> int: + def embed(config: model_api.ReaLModelConfig, tp_size: int) -> 
int: return ReaLModelParamCount._derive_count_from_keys( - ReaLModelParamKeys.embed(config), config, mp_size + ReaLModelParamKeys.embed(config), config, tp_size ) @staticmethod - def tblock(config: model_api.ReaLModelConfig, idx: int, mp_size: int) -> int: + def tblock(config: model_api.ReaLModelConfig, idx: int, tp_size: int) -> int: return ReaLModelParamCount._derive_count_from_keys( - ReaLModelParamKeys.tblock(config, idx), config, mp_size + ReaLModelParamKeys.tblock(config, idx), config, tp_size ) @staticmethod - def head(config: model_api.ReaLModelConfig, mp_size: int) -> int: + def head(config: model_api.ReaLModelConfig, tp_size: int) -> int: return ReaLModelParamCount._derive_count_from_keys( - ReaLModelParamKeys.head(config), config, mp_size + ReaLModelParamKeys.head(config), config, tp_size ) @staticmethod - def total(config: model_api.ReaLModelConfig, idx: int, mp_size: int) -> int: + def total(config: model_api.ReaLModelConfig, idx: int, tp_size: int) -> int: return ( - config.n_layers * ReaLModelParamCount.tblock(config, idx, mp_size) - + ReaLModelParamCount.head(config, mp_size) - + ReaLModelParamCount.embed(config, mp_size) + config.n_layers * ReaLModelParamCount.tblock(config, idx, tp_size) + + ReaLModelParamCount.head(config, tp_size) + + ReaLModelParamCount.embed(config, tp_size) ) @@ -356,7 +358,7 @@ def partition_pipeline_layers( num_stages: int, method: str = "parameters", ) -> Dict[int, Tuple[int, int]]: - # Ignoring mp_size in param count because tensor parallel equally partitions parameters. + # Ignoring tp_size in param count because tensor parallel equally partitions parameters. # It is irrelevant to how we partition pipeline stages. param_counts = ( [ReaLModelParamCount.embed(config, 1)] diff --git a/realhf/impl/model/parallelism/model_parallel/mappings.py b/realhf/impl/model/parallelism/tensor_parallel/mappings.py similarity index 90% rename from realhf/impl/model/parallelism/model_parallel/mappings.py rename to realhf/impl/model/parallelism/tensor_parallel/mappings.py index 841a544..fef023c 100644 --- a/realhf/impl/model/parallelism/model_parallel/mappings.py +++ b/realhf/impl/model/parallelism/tensor_parallel/mappings.py @@ -13,11 +13,11 @@ def _reduce(input_): """All-reduce the input tensor across model parallel group.""" # Bypass the function if we are using only 1 GPU. - if constants.model_parallel_world_size() == 1: + if constants.tensor_parallel_world_size() == 1: return input_ # All-reduce. - torch.distributed.all_reduce(input_, group=constants.model_parallel_group()) + torch.distributed.all_reduce(input_, group=constants.tensor_parallel_group()) return input_ @@ -25,7 +25,7 @@ def _split_along_last_dim(input_): """Split the tensor along its last dimension and keep the corresponding slice.""" - world_size = constants.model_parallel_world_size() + world_size = constants.tensor_parallel_world_size() # Bypass the function if we are using only 1 GPU. if world_size == 1: return input_ @@ -34,7 +34,7 @@ def _split_along_last_dim(input_): input_list = split_tensor_along_last_dim(input_, world_size) # Note: torch.split does not create contiguous tensors by default. 
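partition_pipeline_layers balances pipeline stages by parameter count (embedding, per-block, head), and it can ignore tp_size because tensor parallelism scales every entry by the same factor. A rough greedy sketch under that cost model; this is illustrative only and not the repo's exact algorithm:

# Rough sketch: split per-layer parameter counts into num_stages balanced spans.
def partition_by_params(costs, num_stages):
    """costs: param counts for [embed, block_1, ..., block_N, head]."""
    target = sum(costs) / num_stages
    stages, start, acc, stage = {}, 0, 0, 0
    for i, c in enumerate(costs):
        acc += c
        remaining_layers = len(costs) - (i + 1)
        remaining_stages = num_stages - (stage + 1)
        if stage < num_stages - 1 and acc >= target and remaining_layers >= remaining_stages:
            stages[stage] = (start, i + 1)
            start, acc, stage = i + 1, 0, stage + 1
    stages[stage] = (start, len(costs))
    return stages

costs = [130] + [202] * 32 + [130]   # embed, 32 blocks, head (invented counts, in millions)
print(partition_by_params(costs, num_stages=4))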
- rank = constants.model_parallel_rank() + rank = constants.tensor_parallel_rank() output = input_list[rank].contiguous() return output @@ -44,7 +44,7 @@ def _split_along_first_dim(input_): """Split the tensor along its first dimension and keep the corresponding slice.""" - world_size = constants.model_parallel_world_size() + world_size = constants.tensor_parallel_world_size() # Bypass the function if we are using only 1 GPU. if world_size == 1: return input_ @@ -55,7 +55,7 @@ def _split_along_first_dim(input_): dim_size % world_size == 0 ), "First dimension of the tensor should be divisible by tensor parallel size" local_dim_size = dim_size // world_size - rank = constants.model_parallel_rank() + rank = constants.tensor_parallel_rank() dim_offset = rank * local_dim_size output = input_[dim_offset : dim_offset + local_dim_size].contiguous() @@ -66,19 +66,19 @@ def _split_along_first_dim(input_): def _gather_along_last_dim(input_): """Gather tensors and concatinate along the last dimension.""" - world_size = constants.model_parallel_world_size() + world_size = constants.tensor_parallel_world_size() # Bypass the function if we are using only 1 GPU. if world_size == 1: return input_ # Size and dimension. last_dim = input_.dim() - 1 - rank = constants.model_parallel_rank() + rank = constants.tensor_parallel_rank() tensor_list = [torch.empty_like(input_) for _ in range(world_size)] tensor_list[rank] = input_ torch.distributed.all_gather( - tensor_list, input_, group=constants.model_parallel_group() + tensor_list, input_, group=constants.tensor_parallel_group() ) # Note: torch.cat already creates a contiguous tensor. @@ -90,7 +90,7 @@ def _gather_along_last_dim(input_): def _gather_along_first_dim(input_): """Gather tensors and concatinate along the first dimension.""" - world_size = constants.model_parallel_world_size() + world_size = constants.tensor_parallel_world_size() # Bypass the function if we are using only 1 GPU. if world_size == 1: return input_ @@ -102,7 +102,7 @@ def _gather_along_first_dim(input_): dim_size, dtype=input_.dtype, device=constants.current_device() ) torch.distributed._all_gather_base( - output, input_.contiguous(), group=constants.model_parallel_group() + output, input_.contiguous(), group=constants.tensor_parallel_group() ) return output @@ -110,7 +110,7 @@ def _gather_along_first_dim(input_): def _reduce_scatter_along_first_dim(input_): """Reduce-scatter the input tensor across model parallel group.""" - world_size = constants.model_parallel_world_size() + world_size = constants.tensor_parallel_world_size() # Bypass the function if we are using only 1 GPU. 
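These mappings implement the scatter/gather pairs used by sequence parallelism over the tensor-parallel group. In single-process terms, splitting along a dimension and re-gathering in rank order is an exact round trip; the real functions just do the same with NCCL collectives. A runnable illustration of the semantics:

# Single-process illustration of the split/gather round trip that the
# sequence-parallel mappings implement with collectives.
import torch

world_size = 4
x = torch.arange(64, dtype=torch.float32).reshape(8, 8)

# _split_along_first_dim: each rank keeps a contiguous block of rows;
# _gather_along_first_dim: the all-gather re-concatenates them in rank order.
row_chunks = torch.chunk(x, world_size, dim=0)
assert torch.equal(torch.cat(row_chunks, dim=0), x)

# _split_along_last_dim / _gather_along_last_dim do the same on the last dim.
col_chunks = torch.chunk(x, world_size, dim=-1)
assert torch.equal(torch.cat(col_chunks, dim=-1), x)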
if world_size == 1: return input_ @@ -128,7 +128,7 @@ def _reduce_scatter_along_first_dim(input_): dim_size, dtype=input_.dtype, device=constants.current_device() ) torch.distributed._reduce_scatter_base( - output, input_.contiguous(), group=constants.model_parallel_group() + output, input_.contiguous(), group=constants.tensor_parallel_group() ) return output diff --git a/realhf/impl/model/parallelism/model_parallel/modules.py b/realhf/impl/model/parallelism/tensor_parallel/modules.py similarity index 96% rename from realhf/impl/model/parallelism/model_parallel/modules.py rename to realhf/impl/model/parallelism/tensor_parallel/modules.py index 1437bb6..cdbc42b 100644 --- a/realhf/impl/model/parallelism/model_parallel/modules.py +++ b/realhf/impl/model/parallelism/tensor_parallel/modules.py @@ -44,7 +44,7 @@ except ImportError: import realhf.base.logging as logging -logger = logging.getLogger("model_parallel.modules") +logger = logging.getLogger("tensor_parallel.modules") def get_activation_fn(activation_function: str) -> Callable: @@ -95,12 +95,12 @@ class ParallelEmbedding(torch.nn.Module): self.scale_grad_by_freq = False self.sparse = False self._weight = None - self.tensor_model_parallel_size = constants.model_parallel_world_size() + self.tensor_model_parallel_size = constants.tensor_parallel_world_size() # Divide the weight matrix along the vocaburaly dimension. self.vocab_start_index, self.vocab_end_index = ( VocabUtility.vocab_range_from_global_vocab_size( self.num_embeddings, - constants.model_parallel_rank(), + constants.tensor_parallel_rank(), self.tensor_model_parallel_size, ) ) @@ -110,7 +110,7 @@ class ParallelEmbedding(torch.nn.Module): logger.debug( f"ParallelEmbedding: num_embeddings={num_embeddings}, per_partition={self.num_embeddings_per_partition}, embedding_dim={embedding_dim}," - f"tp_rank={constants.model_parallel_rank()},tp_world_size={constants.model_parallel_world_size()}" + f"tp_rank={constants.tensor_parallel_rank()},tp_world_size={constants.tensor_parallel_world_size()}" ) # Allocate weights and initialize. 
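ParallelEmbedding shards the vocabulary by rows: each tensor-parallel rank stores rows [vocab_start, vocab_end), zeroes lookups for token ids outside its slice, and the subsequent all-reduce sums the partial embeddings. A single-process sketch of that mask-and-sum (the loop over simulated ranks stands in for the collective):

# Single-process sketch of vocab-parallel embedding lookup.
import torch

vocab, dim, tp_size = 16, 4, 2
torch.manual_seed(0)
full_weight = torch.randn(vocab, dim)
shards = torch.chunk(full_weight, tp_size, dim=0)   # rows held by each rank
part = vocab // tp_size

ids = torch.tensor([0, 3, 9, 15])
out = torch.zeros(len(ids), dim)
for rank, w in enumerate(shards):
    start, end = rank * part, (rank + 1) * part
    mask = (ids < start) | (ids >= end)             # ids this rank does not own
    local_ids = (ids - start).clamp(min=0)
    local_ids[mask] = 0
    partial = w[local_ids]
    partial[mask] = 0.0                             # zero out-of-range lookups
    out += partial                                  # the all-reduce in the real code

assert torch.allclose(out, full_weight[ids])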
self.weight = Parameter( @@ -264,7 +264,7 @@ class LinearWithGradAccumulationAndAsyncCommunication(torch.autograd.Function): assert ( not ctx.async_grad_allreduce ), "async_grad_allreduce and sequence_parallel can not be both True" - world_size = constants.model_parallel_world_size() + world_size = constants.tensor_parallel_world_size() dim_size = list(input.size()) dim_size[0] = dim_size[0] * world_size @@ -272,7 +272,7 @@ class LinearWithGradAccumulationAndAsyncCommunication(torch.autograd.Function): dim_size, input.dtype, "mpu" ) torch.distributed._all_gather_base( - all_gather_buffer, input, group=constants.model_parallel_group() + all_gather_buffer, input, group=constants.tensor_parallel_group() ) total_input = all_gather_buffer else: @@ -290,7 +290,7 @@ class LinearWithGradAccumulationAndAsyncCommunication(torch.autograd.Function): use_bias = ctx.use_bias if ctx.sequence_parallel: - world_size = constants.model_parallel_world_size() + world_size = constants.tensor_parallel_world_size() dim_size = list(input.size()) dim_size[0] = dim_size[0] * world_size @@ -300,7 +300,7 @@ class LinearWithGradAccumulationAndAsyncCommunication(torch.autograd.Function): handle = torch.distributed._all_gather_base( all_gather_buffer, input, - group=constants.model_parallel_group(), + group=constants.tensor_parallel_group(), async_op=True, ) @@ -327,7 +327,7 @@ class LinearWithGradAccumulationAndAsyncCommunication(torch.autograd.Function): # Asynchronous all-reduce handle = torch.distributed.all_reduce( grad_input, - group=constants.model_parallel_group(), + group=constants.tensor_parallel_group(), async_op=True, ) # Here we rely on CUDA_DEVICE_MAX_CONNECTIONS=1 to ensure that the @@ -346,7 +346,7 @@ class LinearWithGradAccumulationAndAsyncCommunication(torch.autograd.Function): handle = torch.distributed._reduce_scatter_base( sub_grad_input, grad_input, - group=constants.model_parallel_group(), + group=constants.tensor_parallel_group(), async_op=True, ) # Here we rely on CUDA_DEVICE_MAX_CONNECTIONS=1 to ensure that the @@ -525,7 +525,7 @@ class MergedLinearWithGradAccumulationAndAsyncCommunication(torch.autograd.Funct assert ( not ctx.async_grad_allreduce ), "async_grad_allreduce and sequence_parallel can not be both True" - world_size = constants.model_parallel_world_size() + world_size = constants.tensor_parallel_world_size() dim_size = list(input.size()) dim_size[0] = dim_size[0] * world_size @@ -533,7 +533,7 @@ class MergedLinearWithGradAccumulationAndAsyncCommunication(torch.autograd.Funct dim_size, input.dtype, "mpu" ) torch.distributed._all_gather_base( - all_gather_buffer, input, group=constants.model_parallel_group() + all_gather_buffer, input, group=constants.tensor_parallel_group() ) total_input = all_gather_buffer else: @@ -557,7 +557,7 @@ class MergedLinearWithGradAccumulationAndAsyncCommunication(torch.autograd.Funct is_w_parallel = ctx.is_w_parallel if ctx.sequence_parallel: - world_size = constants.model_parallel_world_size() + world_size = constants.tensor_parallel_world_size() dim_size = list(input.size()) dim_size[0] = dim_size[0] * world_size @@ -567,7 +567,7 @@ class MergedLinearWithGradAccumulationAndAsyncCommunication(torch.autograd.Funct handle = torch.distributed._all_gather_base( all_gather_buffer, input, - group=constants.model_parallel_group(), + group=constants.tensor_parallel_group(), async_op=True, ) @@ -578,7 +578,7 @@ class MergedLinearWithGradAccumulationAndAsyncCommunication(torch.autograd.Funct total_input = input grad_input = 0 for w, is_parallel, grad in 
zip(weights, is_w_parallel, grads): - if is_parallel or constants.model_parallel_rank() == 0: + if is_parallel or constants.tensor_parallel_rank() == 0: grad_input = grad_input + grad.matmul(w) if ctx.sequence_parallel: @@ -597,7 +597,7 @@ class MergedLinearWithGradAccumulationAndAsyncCommunication(torch.autograd.Funct # Asynchronous all-reduce handle = torch.distributed.all_reduce( grad_input, - group=constants.model_parallel_group(), + group=constants.tensor_parallel_group(), async_op=True, ) # Here we rely on CUDA_DEVICE_MAX_CONNECTIONS=1 to ensure that the @@ -616,7 +616,7 @@ class MergedLinearWithGradAccumulationAndAsyncCommunication(torch.autograd.Funct handle = torch.distributed._reduce_scatter_base( sub_grad_input, grad_input, - group=constants.model_parallel_group(), + group=constants.tensor_parallel_group(), async_op=True, ) # Here we rely on CUDA_DEVICE_MAX_CONNECTIONS=1 to ensure that the @@ -785,7 +785,7 @@ class ColumnParallelLinear(torch.nn.Module): self.output_size = output_size self.gather_output = gather_output # Divide the weight matrix along the last dimension. - world_size = constants.model_parallel_world_size() + world_size = constants.tensor_parallel_world_size() self.output_size_per_partition = divide(output_size, world_size) self.skip_bias_add = skip_bias_add self.is_expert = is_expert @@ -852,7 +852,7 @@ class ColumnParallelLinear(torch.nn.Module): # in expert MLPs always behave as sequence parallel is not enabled. sequence_parallel = constants.sequence_parallel() and not self.is_expert async_tensor_model_parallel_allreduce = ( - constants.model_parallel_world_size() > 1 and not sequence_parallel + constants.tensor_parallel_world_size() > 1 and not sequence_parallel ) if sequence_parallel: @@ -942,7 +942,7 @@ class RowParallelLinear(torch.nn.Module): self.output_size = output_size self.input_is_parallel = input_is_parallel # Divide the weight matrix along the last dimension. - world_size = constants.model_parallel_world_size() + world_size = constants.tensor_parallel_world_size() self.input_size_per_partition = divide(input_size, world_size) self.skip_bias_add = skip_bias_add self.gradient_accumulation_fusion = gradient_accumulation_fusion @@ -1030,9 +1030,9 @@ def parallel_lm_logits( bias=None, ): """LM logits using word embedding weights.""" - model_parallel = constants.model_parallel_world_size() > 1 + tensor_parallel = constants.tensor_parallel_world_size() > 1 sequence_parallel = constants.sequence_parallel() - async_grad_allreduce = not sequence_parallel and model_parallel + async_grad_allreduce = not sequence_parallel and tensor_parallel # Parallel logits. if sequence_parallel: input_parallel = input_ @@ -1066,7 +1066,7 @@ class _VocabParallelCrossEntropy(torch.autograd.Function): torch.distributed.all_reduce( logits_max, op=torch.distributed.ReduceOp.MAX, - group=constants.model_parallel_group(), + group=constants.tensor_parallel_group(), ) # Subtract the maximum value. 
vocab_parallel_logits = vocab_parallel_logits - logits_max.unsqueeze(dim=-1) @@ -1074,8 +1074,8 @@ class _VocabParallelCrossEntropy(torch.autograd.Function): # Get the partition's vocab indecies get_vocab_range = VocabUtility.vocab_range_from_per_partition_vocab_size partition_vocab_size = vocab_parallel_logits.size()[-1] - rank = constants.model_parallel_rank() - world_size = constants.model_parallel_world_size() + rank = constants.tensor_parallel_rank() + world_size = constants.tensor_parallel_world_size() vocab_start_index, vocab_end_index = get_vocab_range( partition_vocab_size, rank, world_size ) @@ -1101,7 +1101,7 @@ class _VocabParallelCrossEntropy(torch.autograd.Function): torch.distributed.all_reduce( predicted_logits, op=torch.distributed.ReduceOp.SUM, - group=constants.model_parallel_group(), + group=constants.tensor_parallel_group(), ) # Sum of exponential of logits along vocab dimension across all GPUs. @@ -1111,7 +1111,7 @@ class _VocabParallelCrossEntropy(torch.autograd.Function): torch.distributed.all_reduce( sum_exp_logits, op=torch.distributed.ReduceOp.SUM, - group=constants.model_parallel_group(), + group=constants.tensor_parallel_group(), ) # Loss = log(sum(exp(logits))) - predicted-logit. diff --git a/realhf/impl/model/parallelism/model_parallel/utils.py b/realhf/impl/model/parallelism/tensor_parallel/utils.py similarity index 93% rename from realhf/impl/model/parallelism/model_parallel/utils.py rename to realhf/impl/model/parallelism/tensor_parallel/utils.py index 4336a71..89d7ffa 100644 --- a/realhf/impl/model/parallelism/model_parallel/utils.py +++ b/realhf/impl/model/parallelism/tensor_parallel/utils.py @@ -6,11 +6,7 @@ from typing import List, Sequence import numpy as np import torch -from realhf.base.constants import ( - model_parallel_group, - model_parallel_rank, - model_parallel_world_size, -) +import realhf.base.constants as constants _MODEL_PARALLEL_ATTRIBUTE_DEFAULTS = { "tensor_model_parallel": False, @@ -22,7 +18,7 @@ _MODEL_PARALLEL_ATTRIBUTE_DEFAULTS = { def param_is_not_model_parallel_duplicate(param): return ( hasattr(param, "tensor_model_parallel") and param.tensor_model_parallel - ) or (model_parallel_rank() == 0) + ) or (constants.tensor_parallel_rank() == 0) def set_tensor_model_parallel_attributes(tensor, is_parallel, dim, stride): @@ -110,8 +106,8 @@ def split_tensor_into_1d_equal_chunks(tensor, new_buffer=False): If False, returns a view into the existing Tensor. Default is False """ - partition_size = torch.numel(tensor) // model_parallel_world_size() - start_index = partition_size * model_parallel_rank() + partition_size = torch.numel(tensor) // constants.tensor_parallel_world_size() + start_index = partition_size * constants.tensor_parallel_rank() end_index = start_index + partition_size if new_buffer: data = torch.empty( @@ -135,7 +131,7 @@ def gather_split_1d_tensor(tensor): Arguments: tensor: A Tensor or view of this rank's portion of the data. """ - numel_gathered = torch.numel(tensor) * model_parallel_world_size() + numel_gathered = torch.numel(tensor) * constants.tensor_parallel_world_size() gathered = torch.empty( numel_gathered, dtype=tensor.dtype, @@ -147,7 +143,9 @@ def gather_split_1d_tensor(tensor): # as opposed to torch.distributed.all_gather for efficiency reasons. # This API calls directly NCCL all-gather versus the former does # internal copies and can potentially cause slow down. 
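_VocabParallelCrossEntropy never materializes the full vocabulary on one rank: the per-token max and the sum of exponentials are all-reduced, and only the rank owning the target token contributes the predicted logit. A single-process check of that decomposition (a sketch; the all-reduces become plain sums over simulated shards):

# loss = log(sum_exp_logits) - predicted_logit, with the global max subtracted
# for numerical stability, computed shard-by-shard.
import torch
import torch.nn.functional as F

torch.manual_seed(0)
vocab, tp_size = 12, 3
logits = torch.randn(5, vocab)
target = torch.randint(0, vocab, (5,))

shards = torch.chunk(logits, tp_size, dim=-1)
part = vocab // tp_size
global_max = torch.stack([s.max(-1).values for s in shards]).max(0).values   # all_reduce(MAX)

sum_exp = torch.zeros(5)
predicted = torch.zeros(5)
for rank, s in enumerate(shards):
    shifted = s - global_max.unsqueeze(-1)
    sum_exp += shifted.exp().sum(-1)                                          # all_reduce(SUM)
    in_range = (target >= rank * part) & (target < (rank + 1) * part)
    local_t = (target - rank * part).clamp(0, part - 1)
    local_pred = shifted.gather(-1, local_t.unsqueeze(-1)).squeeze(-1)
    predicted += torch.where(in_range, local_pred, torch.zeros(5))            # all_reduce(SUM)

loss = sum_exp.log() - predicted
assert torch.allclose(loss, F.cross_entropy(logits, target, reduction="none"), atol=1e-5)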
- torch.distributed._all_gather_base(gathered, tensor, group=model_parallel_group()) + torch.distributed._all_gather_base( + gathered, tensor, group=constants.tensor_parallel_group() + ) return gathered diff --git a/realhf/impl/model/utils/functional.py b/realhf/impl/model/utils/functional.py index cbdba37..d80b87f 100644 --- a/realhf/impl/model/utils/functional.py +++ b/realhf/impl/model/utils/functional.py @@ -210,11 +210,11 @@ def gather_packed_shifted_log_probs( """ labels = torch.nn.functional.pad(labels[1:], (0, 1), value=0) leave_one_indices = build_leave_one_indices(logits, cu_seqlens) - if constants.model_parallel_world_size() > 1: + if constants.tensor_parallel_world_size() > 1: # NOTE: logprobs is freaking sensitive to input_ids. If the input sequence is a natural sequence, everything will be fine. # However, if we input random token IDs, parallel cross entropy can produce VERY different results than the normal # torch.gather based version (e.g., the maximum absolute different can reach ~50). - from realhf.impl.model.parallelism.model_parallel.modules import ( + from realhf.impl.model.parallelism.tensor_parallel.modules import ( vocab_parallel_cross_entropy, ) @@ -239,14 +239,16 @@ def gather_packed_shifted_log_probs( def apply_logits_mask(logits: torch.HalfTensor, mask: torch.BoolTensor): - assert mask.shape[-1] == logits.shape[-1] * constants.model_parallel_world_size(), ( - constants.model_parallel_world_size(), + assert ( + mask.shape[-1] == logits.shape[-1] * constants.tensor_parallel_world_size() + ), ( + constants.tensor_parallel_world_size(), logits.shape, mask.shape, ) parallel_vocab_size = logits.shape[-1] - mp_rank = constants.model_parallel_rank() - mask = mask[:, mp_rank * parallel_vocab_size : (mp_rank + 1) * parallel_vocab_size] + tp_rank = constants.tensor_parallel_rank() + mask = mask[:, tp_rank * parallel_vocab_size : (tp_rank + 1) * parallel_vocab_size] logits.masked_fill_(mask, torch.finfo(logits.dtype).min) diff --git a/realhf/impl/model/utils/padding.py b/realhf/impl/model/utils/padding.py index b34a85f..0f4d79e 100644 --- a/realhf/impl/model/utils/padding.py +++ b/realhf/impl/model/utils/padding.py @@ -250,7 +250,7 @@ def pad_sequence_parallel_input( ): """Sequence parallel requires packed_input_ids has a shape of 1 dimension [total_seq_len], and total_seq_len should be divisible by - model_parallel_world_size. This function is used to pad packed_input_ids to + tensor_parallel_world_size. This function is used to pad packed_input_ids to suitable length with an empty sequence, and return new packed_input_ids, cu_seqlens and max_seqlen. @@ -262,10 +262,10 @@ def pad_sequence_parallel_input( Returns: (torch.Tensor, torch.Tensor, int, int): padded (packed_input_ids, cu_seqlens, max_seqlen, pad_size) """ - mp_world_size = constants.model_parallel_world_size() + tp_world_size = constants.tensor_parallel_world_size() pad_size = 0 - if len(packed_input_ids) % mp_world_size != 0: - pad_size = mp_world_size - len(packed_input_ids) % mp_world_size + if len(packed_input_ids) % tp_world_size != 0: + pad_size = tp_world_size - len(packed_input_ids) % tp_world_size packed_input_ids = torch.nn.functional.pad( packed_input_ids, (0, pad_size), value=1 ) @@ -281,9 +281,9 @@ def pad_sequence_parallel_generate_input( ): """Only for pipeline generate input when model+seq parallel is enabled. 
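`pad_sequence_parallel_input` above pads `packed_input_ids` with a dummy sequence so that its length is divisible by the (renamed) `tensor_parallel_world_size`, which sequence parallelism requires. A self-contained sketch of the same arithmetic; the helper name is hypothetical, and the pad value of 1 mirrors the diff:

```python
import torch

def pad_packed_to_tp_multiple(packed_input_ids: torch.Tensor,
                              cu_seqlens: torch.Tensor,
                              tp_world_size: int,
                              pad_token_id: int = 1):
    """Pad a packed 1-D token tensor so its length divides evenly across TP ranks."""
    pad_size = (-len(packed_input_ids)) % tp_world_size
    if pad_size:
        packed_input_ids = torch.nn.functional.pad(
            packed_input_ids, (0, pad_size), value=pad_token_id
        )
        # The padding counts as one extra dummy sequence at the end.
        cu_seqlens = torch.cat([cu_seqlens, cu_seqlens[-1:] + pad_size])
    return packed_input_ids, cu_seqlens, pad_size

# Example: 10 packed tokens with tp_world_size=4 -> 2 pad tokens appended.
ids, cu = torch.arange(10), torch.tensor([0, 6, 10])
padded, cu2, pad = pad_packed_to_tp_multiple(ids, cu, tp_world_size=4)
assert len(padded) % 4 == 0 and pad == 2 and cu2[-1].item() == 12
```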
To make sure inputs for seq parallel model have a shape with first dimension - divisible by model_parallel_world_size, the packed_input_ids should have - length divisible by model_parallel_world_size, and contains number of - sequences divisible by model_parallel_world_size. + divisible by tensor_parallel_world_size, the packed_input_ids should have + length divisible by tensor_parallel_world_size, and contains number of + sequences divisible by tensor_parallel_world_size. Args: packed_input_ids (torch.Tensor): unpadded packed_input_ids @@ -293,16 +293,16 @@ def pad_sequence_parallel_generate_input( Returns: (torch.Tensor, torch.Tensor, int, int, int): padded (packed_input_ids, cu_seqlens, max_seqlen, pad_size, pad_seq_size) """ - mp_world_size = constants.model_parallel_world_size() + tp_world_size = constants.tensor_parallel_world_size() pad_size, pad_seq_size = 0, 0 if ( - len(packed_input_ids) % mp_world_size != 0 - or (len(cu_seqlens) - 1) % mp_world_size != 0 + len(packed_input_ids) % tp_world_size != 0 + or (len(cu_seqlens) - 1) % tp_world_size != 0 ): - pad_size = mp_world_size - len(packed_input_ids) % mp_world_size - pad_seq_size = mp_world_size - (len(cu_seqlens) - 1) % mp_world_size + pad_size = tp_world_size - len(packed_input_ids) % tp_world_size + pad_seq_size = tp_world_size - (len(cu_seqlens) - 1) % tp_world_size if pad_size < pad_seq_size: - pad_size += mp_world_size + pad_size += tp_world_size pad_cu_seqlens = torch.tensor(list(range(1, pad_seq_size)) + [pad_size]) + len( packed_input_ids ) diff --git a/realhf/impl/model/utils/ppo_functional.py b/realhf/impl/model/utils/ppo_functional.py index 39b9976..a8c3211 100644 --- a/realhf/impl/model/utils/ppo_functional.py +++ b/realhf/impl/model/utils/ppo_functional.py @@ -55,6 +55,7 @@ def actor_loss_fn( eps_clip: float, loss_mask: Optional[torch.BoolTensor] = None, c_clip: Optional[float] = None, + proximal_logprobs: Optional[torch.FloatTensor] = None, ) -> Tuple[torch.Tensor, Dict]: """Compute PPO actor loss function. @@ -83,13 +84,22 @@ def actor_loss_fn( old_logprobs = old_logprobs.clone() if advantages.is_inference(): advantages = advantages.clone() - - if loss_mask is not None: - loss_mask_count = loss_mask.count_nonzero() or 1 - # For numerical stability. - ratio = torch.where(loss_mask, torch.exp(logprobs - old_logprobs), 0) + if proximal_logprobs is not None: + assert proximal_logprobs.dtype == torch.float32 + if proximal_logprobs.is_inference(): + proximal_logprobs = proximal_logprobs.clone() + denorm_logprobs = proximal_logprobs else: - ratio = torch.exp(logprobs - old_logprobs) + denorm_logprobs = old_logprobs + + # create mask + if loss_mask is None: + loss_mask = torch.ones_like(logprobs, dtype=torch.bool) + loss_mask: torch.BoolTensor + + loss_mask_count = loss_mask.count_nonzero() or 1 + # For numerical stability. 
+ ratio = torch.where(loss_mask, torch.exp(logprobs - denorm_logprobs), 0) clipped_ratio = torch.clamp(ratio, 1.0 - eps_clip, 1.0 + eps_clip) pg_loss1 = -advantages * ratio @@ -104,24 +114,34 @@ def actor_loss_fn( pg_loss = torch.min(pg_loss, pg_loss3) else: dual_clip_mask = torch.zeros_like(clip_mask) + if proximal_logprobs is not None: + behav_kl = proximal_logprobs - old_logprobs + behav_imp_weight = behav_kl.exp() + if c_clip is not None: + behav_mask = (behav_imp_weight <= c_clip).logical_and(loss_mask) + else: + behav_mask = loss_mask + behav_kl = torch.where(behav_mask, behav_kl, 0.0) + behav_imp_weight = torch.where(behav_mask, behav_imp_weight, 0.0) + pg_loss = pg_loss * behav_imp_weight logging_loss = pg_loss.detach() - if loss_mask is not None: - pg_loss = torch.where(loss_mask, pg_loss, 0).sum() / loss_mask_count - else: - pg_loss = pg_loss.mean() + pg_loss = torch.where(loss_mask, pg_loss, 0).sum() / loss_mask_count - if loss_mask is not None: - clip_mask.logical_and_(loss_mask) - dual_clip_mask.logical_and_(loss_mask) + clip_mask.logical_and_(loss_mask) + dual_clip_mask.logical_and_(loss_mask) # Remain torch.CudaTensor here for all-reduce after train step. stat = dict( loss=logging_loss, importance_weight=ratio.detach(), - approx_kl=(logprobs - old_logprobs).detach(), + approx_kl=(logprobs - denorm_logprobs).detach(), clip_mask=clip_mask, dual_clip_mask=dual_clip_mask, ) + if proximal_logprobs is not None: + stat["behave_imp_weight"] = behav_imp_weight + stat["behave_approx_kl"] = behav_kl + stat["behave_mask"] = behav_mask return pg_loss, stat diff --git a/realhf/impl/model/utils/random.py b/realhf/impl/model/utils/random.py index 2870ed7..434f70a 100644 --- a/realhf/impl/model/utils/random.py +++ b/realhf/impl/model/utils/random.py @@ -13,7 +13,7 @@ from torch.cuda import device as device_ctx_manager from torch.utils.checkpoint import detach_variable import realhf.base.constants as constants -from realhf.impl.model.parallelism.model_parallel.utils import ( +from realhf.impl.model.parallelism.tensor_parallel.utils import ( divide, gather_split_1d_tensor, safely_set_viewless_tensor_data, @@ -169,10 +169,10 @@ def model_parallel_cuda_manual_seed(seed): tensor-model-parallel state: This state is different among a set of model parallel GPUs, but the same across data parallel groups. This is used for example for dropout in model parallel regions. """ # 2718 is just for fun and any POSITIVE value will work. - model_parallel_rank = constants.model_parallel_rank() + tensor_parallel_rank = constants.tensor_parallel_rank() expert_parallel_rank = 0 offset = seed + 2718 - tensor_model_parallel_seed = offset + model_parallel_rank + tensor_model_parallel_seed = offset + tensor_parallel_rank # Data parallel gets the original seed. 
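The `actor_loss_fn` changes above make the loss mask unconditional and add an optional `proximal_logprobs` argument: the clipped PPO ratio is taken against the proximal policy, and the loss is additionally reweighted by the behavior importance weight `exp(proximal_logprobs - old_logprobs)`, optionally clipped by `c_clip`. A condensed sketch of that computation (a simplification of the diff, omitting dual clipping and the returned statistics):

```python
import torch

def decoupled_ppo_loss(logprobs, old_logprobs, advantages, eps_clip,
                       proximal_logprobs=None, loss_mask=None, c_clip=None):
    """Simplified sketch of the decoupled-PPO weighting added above."""
    if loss_mask is None:
        loss_mask = torch.ones_like(logprobs, dtype=torch.bool)
    denom = proximal_logprobs if proximal_logprobs is not None else old_logprobs

    ratio = torch.where(loss_mask, torch.exp(logprobs - denom), 0.0)
    clipped = torch.clamp(ratio, 1.0 - eps_clip, 1.0 + eps_clip)
    pg_loss = torch.max(-advantages * ratio, -advantages * clipped)

    if proximal_logprobs is not None:
        # Reweight by the (optionally clipped) behavior importance weight.
        behav_w = torch.exp(proximal_logprobs - old_logprobs)
        if c_clip is not None:
            loss_mask = loss_mask & (behav_w <= c_clip)
        pg_loss = pg_loss * torch.where(loss_mask, behav_w, 0.0)

    n = loss_mask.count_nonzero().clamp(min=1)
    return torch.where(loss_mask, pg_loss, 0.0).sum() / n
```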
data_parallel_seed = seed @@ -187,7 +187,7 @@ def model_parallel_cuda_manual_seed(seed): ) expert_parallel_seed = ( - seed + 1024 + 100 * expert_parallel_rank + model_parallel_rank + seed + 1024 + 100 * expert_parallel_rank + tensor_parallel_rank ) _CUDA_RNG_STATE_TRACKER.add(_EXPERT_PARALLEL_RNG_TRACKER_NAME, expert_parallel_seed) @@ -331,8 +331,8 @@ def _initialize_affine_weight_cpu( weight_list = torch.split( master_weight, per_partition_per_stride_size, dim=partition_dim ) - rank = constants.model_parallel_rank() - world_size = constants.model_parallel_world_size() + rank = constants.tensor_parallel_rank() + world_size = constants.tensor_parallel_world_size() my_weight_list = weight_list[rank::world_size] with torch.no_grad(): diff --git a/realhf/scheduler/slurm/client.py b/realhf/scheduler/slurm/client.py index f33d0ce..a586382 100644 --- a/realhf/scheduler/slurm/client.py +++ b/realhf/scheduler/slurm/client.py @@ -5,13 +5,18 @@ import fcntl import os import re +import select import subprocess +import threading import time from collections import defaultdict from typing import Dict, List, Literal, Optional, Tuple +import colorama + import realhf.base.logging as logging from realhf.base.cluster import spec as cluster_spec +from realhf.base.constants import LOG_ROOT from realhf.base.constants import SLURM_LOCK_FILE_NAME as LOCK_FILE_NAME from realhf.scheduler.client import JobException, JobInfo, JobState, SchedulerClient from realhf.scheduler.evaluator import AutomaticEvaluator @@ -29,6 +34,49 @@ SCHEDULING_TIMEOUT_MAX_SECONDS = 3600 * 24 SCHEDULER_WAIT_CHECK_TIME_INTERVAL = 5 +def monitor_log( + job_name: str, log_path: str, output_file: str, stop_event: threading.Event +): + """Monitor a log file and write its contents to the output file with job name prefix.""" + # Wait for log file to be created + while not os.path.exists(log_path) and not stop_event.is_set(): + time.sleep(0.1) + + if stop_event.is_set(): + return + + # Open the log file and follow it + with open(log_path, "r") as log_file, open(output_file, "a") as out_file: + # Store last position + position = 0 + line_pos = 0 + + while not stop_event.is_set(): + log_file.seek(position) + try: + new_lines = log_file.readlines() + except UnicodeDecodeError: + time.sleep(0.5) + continue + + if new_lines: + # Update position + position = log_file.tell() + + worker_type = job_name.split(":")[1] + # Write new lines to output file with job name prefix + for line in new_lines: + if line.strip(): # Skip empty lines + out_file.write( + f"{colorama.Fore.YELLOW + colorama.Style.DIM}({worker_type} Line {line_pos}){colorama.Style.RESET_ALL} {line}" + ) + line_pos += 1 + out_file.flush() + + # Sleep briefly to avoid CPU spinning + time.sleep(0.1) + + class SlurmSchedulerClient(SchedulerClient): """Uses Slurm (https://slurm.schedmd.com/overview.html).""" @@ -248,6 +296,26 @@ class SlurmSchedulerClient(SchedulerClient): # before wait, commit all remaining pending jobs # TODO: grab global file lock to avoid multi-experiment deadlocks self.__allocate_and_commit_pending_jobs() + # Start monitoring threads + threads = [] + stop_events = [] + + merged_log_path = os.path.join( + LOG_ROOT, self.expr_name, self.trial_name, "main.log" + ) + + for job_name, launch_info in self.__committed_jobs.items(): + stop_event = threading.Event() + stop_events.append(stop_event) + + # Thread for monitoring the log file + log_thread = threading.Thread( + target=monitor_log, + args=(job_name, launch_info.log_path, merged_log_path, stop_event), + ) + 
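`monitor_log` above tails each Slurm job's log on a background thread and appends prefixed lines to a merged `main.log`, with shutdown signalled through a shared `threading.Event`. A minimal sketch of the same follow-and-merge pattern (paths and prefix are placeholders):

```python
import threading
import time

def follow(path: str, out_path: str, stop: threading.Event, prefix: str):
    """Minimal tail -f: append new lines from `path` to `out_path` until stopped."""
    pos = 0
    while not stop.is_set():
        try:
            with open(path, "r") as src, open(out_path, "a") as out:
                src.seek(pos)
                for line in src.readlines():
                    out.write(f"({prefix}) {line}")
                pos = src.tell()
        except FileNotFoundError:
            pass  # the job has not created its log yet
        time.sleep(0.5)

stop = threading.Event()
t = threading.Thread(target=follow,
                     args=("/tmp/job.log", "/tmp/main.log", stop, "model_worker"),
                     daemon=True)
t.start()
# ... wait for jobs ...
stop.set()
t.join()
```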
threads.append(log_thread) + log_thread.start() + # begin wait deadline = None if timeout is None else time.time() + timeout left = set(self.__committed_jobs.keys()) @@ -256,44 +324,52 @@ class SlurmSchedulerClient(SchedulerClient): f"Waiting for {num_jobs_left} jobs. Jobs IDs: " f"{','.join(sorted([x.job_info.slurm_id for x in self.__committed_jobs.values()]))}." ) - while len(left) > 0: - if len(left) < num_jobs_left: - num_jobs_left = len(left) - logger.info(f"Waiting for {num_jobs_left} jobs.") - if self.__evaluator is not None: - self.__evaluator.step() - if deadline is not None and time.time() > deadline: - raise TimeoutError( - f"Timeout waiting for {self.run_name}: {', '.join(sorted(left))}" - ) - try: - self.__update_all() - except subprocess.CalledProcessError: - logger.warning( - "Calling squeue failed. Check slurm manually if you continue to see this warning." - ) - time.sleep(30) - continue - for job_slurm_name in list(left): - launch_info = self.__committed_jobs[job_slurm_name] - if launch_info.slurm_id is None: + logger.info( + f"All slurm logs will be merged. To check the real-time output, " + f"run\n\t`tail -f {merged_log_path}`." + ) + try: + while len(left) > 0: + if len(left) < num_jobs_left: + num_jobs_left = len(left) + logger.info(f"Waiting for {num_jobs_left} jobs.") + if self.__evaluator is not None: + self.__evaluator.step() + if deadline is not None and time.time() > deadline: + raise TimeoutError( + f"Timeout waiting for {self.run_name}: {', '.join(sorted(left))}" + ) + try: + self.__update_all() + except subprocess.CalledProcessError: + logger.warning( + "Calling squeue failed. Check slurm manually if you continue to see this warning." + ) + time.sleep(30) continue - if launch_info.job_info.state in check_status: - launch_info.show_log() - raise JobException( - run_name=self.run_name, - worker_type=launch_info.worker_type, - host=launch_info.job_info.host, - reason=launch_info.job_info.state, - ) - if launch_info.job_info.state in remove_status: - logger.info( - f"Job {launch_info.slurm_name} is {launch_info.job_info.state}.(Removed)" - ) - left.remove(job_slurm_name) - if update: - self.__committed_jobs.pop(job_slurm_name) - time.sleep(SCHEDULER_WAIT_CHECK_TIME_INTERVAL) + for job_slurm_name in list(left): + launch_info = self.__committed_jobs[job_slurm_name] + if launch_info.slurm_id is None: + continue + if launch_info.job_info.state in check_status: + launch_info.show_log() + raise JobException( + run_name=self.run_name, + worker_type=launch_info.worker_type, + host=launch_info.job_info.host, + reason=launch_info.job_info.state, + ) + if launch_info.job_info.state in remove_status: + logger.info( + f"Job {launch_info.slurm_name} is {launch_info.job_info.state}.(Removed)" + ) + left.remove(job_slurm_name) + if update: + self.__committed_jobs.pop(job_slurm_name) + time.sleep(SCHEDULER_WAIT_CHECK_TIME_INTERVAL) + finally: + [s.set() for s in stop_events] + [t.join() for t in threads] def __update_all(self): states = [] diff --git a/realhf/scheduler/slurm/utils.py b/realhf/scheduler/slurm/utils.py index 2b729f4..9cb79b7 100644 --- a/realhf/scheduler/slurm/utils.py +++ b/realhf/scheduler/slurm/utils.py @@ -294,21 +294,29 @@ class SlurmLaunchInfo: @property def multiprog_path(self) -> str: - return os.path.join( + path = os.path.join( LOG_ROOT, self.exper_name, self.trial_name, + "slurm", + "multiprog", f"{self.worker_type}-{self.worker_submission_idx}.multiprog", ) + os.makedirs(os.path.dirname(path), exist_ok=True) + return path @property def 
hostfile_path(self) -> str: - return os.path.join( + path = os.path.join( LOG_ROOT, self.exper_name, self.trial_name, + "slurm", + "hostfile", f"{self.worker_type}-{self.worker_submission_idx}.hostfile", ) + os.makedirs(os.path.dirname(path), exist_ok=True) + return path def show_log(self): try: diff --git a/realhf/search_engine/__init__.py b/realhf/search_engine/__init__.py deleted file mode 100644 index b2c9d3e..0000000 --- a/realhf/search_engine/__init__.py +++ /dev/null @@ -1,12 +0,0 @@ -# Copyright 2025 Ant Group Inc. -# Copyright 2024 Wei Fu & Zhiyu Mei -# Licensed under the Apache License, Version 2.0 (the "License"). - - -def import_profiler_registers(): - import realhf.search_engine.enumerate - import realhf.search_engine.estimate - import realhf.search_engine.layers - import realhf.search_engine.param_realloc - import realhf.search_engine.search - import realhf.search_engine.utils diff --git a/realhf/search_engine/enumerate.py b/realhf/search_engine/enumerate.py deleted file mode 100644 index f949768..0000000 --- a/realhf/search_engine/enumerate.py +++ /dev/null @@ -1,158 +0,0 @@ -# Copyright 2025 Ant Group Inc. -# Copyright 2024 Wei Fu & Zhiyu Mei -# Licensed under the Apache License, Version 2.0 (the "License"). - -from typing import List - -from realhf.api.core.dfg import MFCDef, ModelInterfaceType -from realhf.api.core.dfg import build_graph as build_dfg -from realhf.api.quickstart.device_mesh import DeviceMesh, find_parallel_strategies -from realhf.api.quickstart.search import MFCDef, RPCExecution, RPCInstance -from realhf.search_engine.estimate import ( - estimate_rpc_memory_cost, - estimate_rpc_time_cost, -) - -MEM_INDEX = 1.0 # heuristic value to scale estimated memory - - -def enumerate_rpc_executions( - rpc: MFCDef, - device_mesh: DeviceMesh, - seq_len: int, - num_gen_tokens: int, - n_ppo_minibatches: int, - gradient_checkpointing: bool, -) -> List[RPCExecution]: - sub_device_meshes = device_mesh.sub_device_meshes() - import pprint - - feasible = [] - for sub_device_mesh in sub_device_meshes: - ps = find_parallel_strategies(sub_device_mesh) - for parallel in ps: - num_dp = parallel.data_parallel_size - num_pp = parallel.pipeline_parallel_size - num_mp = parallel.model_parallel_size - bs = rpc.n_seqs - # seq_len = seq_len - min_bs = ( - 2 * num_dp * num_pp * n_ppo_minibatches - if rpc.interface_type == ModelInterfaceType.TRAIN_STEP - else num_dp * num_pp - ) - if min_bs > bs: - # batch size too small - continue - # heuristic to filter out inherent slow configurations - if ( - num_mp * num_dp > device_mesh.n_gpus_per_node - and rpc.interface_type == ModelInterfaceType.TRAIN_STEP - ): - continue - if num_mp > 8: - continue - if num_pp > max(device_mesh.n_nodes, 8): - continue - # memory and time estimation - mem_cost, static_mem = estimate_rpc_memory_cost( - rpc, - parallel, - bs, - seq_len, - gradient_checkpointing=gradient_checkpointing, - n_ppo_minibatches=n_ppo_minibatches, - num_gen_tokens=num_gen_tokens, - offload=rpc.model_name.role in ["ref", "reward"], - ) - mem_cost = int(mem_cost * MEM_INDEX) - static_mem = int(static_mem * MEM_INDEX) - time_cost = estimate_rpc_time_cost( - rpc, - parallel, - bs=bs, - seq_len=seq_len, - num_gen_tokens=num_gen_tokens, - gradient_checkpointing=gradient_checkpointing, - n_ppo_minibatches=n_ppo_minibatches, - ) - time_cost = int(time_cost) - if mem_cost < device_mesh.gpu_memory_capacity: - feasible.append( - RPCExecution( - rpc, - sub_device_mesh, - parallel, - time_cost, - mem_cost, - static_mem, - ) - ) - return feasible - - -def 
build_graph( - rpcs: List[MFCDef], - num_epoch: int = 5, - epoch_dependency_interval: int = 1, - if_print=False, -) -> List[RPCInstance]: - """Build model function call graph of multiple training epochs, - - args: - exp: ProfileExperiment, the experiment object - num_epoch: int, number of training epochs - epoch_dependency_interval: int, the interval of epoch dependency, - e.g. if epoch_dependency_interval = 2, then the graph will have - edges between epoch i and epoch i+2, i+4, ... - - returns: - rpc_instances: List[RPCInstance], the list of RPCInstance objects - """ - # one epoch dependency graph - rpcs, edges = build_dfg(rpcs) - rpc_names_mapping = {rpc.name: rpc for rpc in rpcs} - rpc_instances = [] - - # multi epoch graph - for epoch_id in range(num_epoch): - for rpc in rpcs: - children = [] - parents = [] - if rpc.is_src and epoch_id >= epoch_dependency_interval: - for other in rpcs: - if other.is_dst and other.model_name.role == rpc.model_name.role: - parents.append( - RPCInstance( - rpc, - epoch_id - epoch_dependency_interval, - [], - [], - ) - ) - if rpc.is_dst and rpc.model_name.role == rpc.model_name.role: - for other in rpcs: - if ( - other.is_src - and epoch_id + epoch_dependency_interval < num_epoch - ): - children.append( - RPCInstance( - rpc, - epoch_id + epoch_dependency_interval, - [], - [], - ) - ) - for parent in rpc.parents: - p = rpc_names_mapping[parent] - parents.append(RPCInstance(p, epoch_id, [], [])) - for child in rpc.children: - c = rpc_names_mapping[child] - children.append(RPCInstance(c, epoch_id, [], [])) - rpc_instance = RPCInstance(rpc, epoch_id, parents, children) - rpc_instances.append(rpc_instance) - if if_print: - for ri in rpc_instances: - print(ri) - return rpc_instances diff --git a/realhf/search_engine/estimate.py b/realhf/search_engine/estimate.py deleted file mode 100644 index a5b5db7..0000000 --- a/realhf/search_engine/estimate.py +++ /dev/null @@ -1,499 +0,0 @@ -# Copyright 2025 Ant Group Inc. -# Copyright 2024 Wei Fu & Zhiyu Mei -# Licensed under the Apache License, Version 2.0 (the "License"). 
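The deleted `enumerate_rpc_executions` above iterated over sub-device-meshes and (pp, tp, dp) factorizations, discarding configurations that are heuristically infeasible before estimating their time and memory cost. A toy sketch of that enumeration step (brute-force factorization; the filters are of the same flavour as the deleted heuristics, with illustrative thresholds):

```python
import itertools

def factorizations_3(n_gpus: int):
    """All (pp, tp, dp) with pp * tp * dp == n_gpus (brute force, fine for n <= 64)."""
    for pp, tp in itertools.product(range(1, n_gpus + 1), repeat=2):
        if n_gpus % (pp * tp) == 0:
            yield pp, tp, n_gpus // (pp * tp)

candidates = [
    (pp, tp, dp)
    for pp, tp, dp in factorizations_3(16)
    if tp <= 8 and pp <= 8   # prune inherently slow strategies before cost estimation
]
print(len(candidates), candidates[:5])
```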
- -# Estimate a fucntion-call level execution time for device mesh enumerate pruning -# assume one batch of data passes through all rpcs once -import argparse -import getpass -import itertools -import os -import pickle -from collections import defaultdict -from typing import Optional - -import numpy as np -import pandas as pd - -import realhf.base.cluster -import realhf.base.constants as constants -import realhf.base.logging as logging -from realhf.api.cli_args import ParallelismConfig -from realhf.api.core.dfg import MFCDef, ModelFamily, ModelInterfaceType -from realhf.api.core.model_api import ReaLModelConfig -from realhf.search_engine.param_realloc import estimate_param_realloc_time_cost -from realhf.search_engine.utils import load_model_config - -logger = logging.getLogger("estimate", "benchmark") - -PROFILE_RESULT_PATH = os.path.join( - realhf.base.cluster.spec.fileroot, - "logs", - getpass.getuser(), - "profile", - "profile", - "layer_stats", -) - - -def get_param_realloc_stats( - model_family: ModelFamily, - model_path: str, - n_nodes: int, - use_cache: bool = True, -): - non_critic = ModelFamily(model_family._class, model_family.size, False) - table_path = os.path.join( - constants.PROFILER_CACHE_PATH, - "param_realloc", - f"prtc_{non_critic}_n{n_nodes}.pkl", - ) - if not os.path.exists(table_path): - print( - f"Calculating estimation of param realloc time cost for {model_family} at {model_path}" - ) - estimate_param_realloc_time_cost(n_nodes, {non_critic: model_path}) - - print(f"Loading param realloc stats from {table_path}") - return pickle.load(open(table_path, "rb")) - - -def get_organized_op_stats( - model_family: ModelFamily, model_path: str, use_cache: bool = True -): - non_critic = ModelFamily(model_family._class, model_family.size, False) - # parse raw stats into list of OpInfo used for estimation - cache_path = os.path.join( - constants.PROFILER_CACHE_PATH, "organized_stats", f"{non_critic}.pkl" - ) - if use_cache and os.path.exists(cache_path): - with open(cache_path, "rb") as f: - return pickle.load(f) - - raw_result_path = os.path.join( - constants.PROFILER_CACHE_PATH, "layer_stats", str(non_critic) - ) - if not os.path.exists(raw_result_path): - from realhf.apps.main import _main_profile_layers - - _main_profile_layers(non_critic, model_path) - - raw_stats_list = [] - for fn in os.listdir(raw_result_path): - if not fn.startswith("layer-stats"): - continue - num_mp = int(fn.replace(".pkl", "").split("_")[1]) - rank = int(fn.replace(".pkl", "").split("_")[2]) - with open(os.path.join(raw_result_path, fn), "rb") as f: - stats = pickle.load(f) - - if isinstance(stats, dict): - stats = pd.DataFrame(stats) - elif isinstance(stats, pd.DataFrame): - pass - else: - raise ValueError(f"Unsupported stats type {type(stats)}") - - stats["num_mp"] = num_mp - stats["rank"] = rank - raw_stats_list.append(stats) - - raw_stats = pd.concat(raw_stats_list) - bs_list = raw_stats["bs"].unique() - seq_len_list = raw_stats["seq_len"].unique() - op_name_list = raw_stats["op_name"].unique() - layer_name_list = raw_stats["layer_name"].unique() - num_mp_list = raw_stats["num_mp"].unique() - organized_stats = defaultdict(list) - - for op_name, bs, seq_len, layer_name, num_mp in itertools.product( - op_name_list, bs_list, seq_len_list, layer_name_list, num_mp_list - ): - filter_cond = ( - (raw_stats["op_name"] == op_name) - & (raw_stats["bs"] == bs) - & (raw_stats["seq_len"] == seq_len) - & (raw_stats["layer_name"] == layer_name) - & (raw_stats["num_mp"] == num_mp) - ) - avg_time_ns = 
raw_stats[filter_cond]["time_ns"].mean() - x = int(bs) if op_name == "fwd_gen_1" else int(bs * seq_len) - - organized_stats["op_name"].append(op_name) - organized_stats["layer_name"].append(layer_name) - organized_stats["bs"].append(bs) - organized_stats["seq_len"].append(seq_len) - organized_stats["num_mp"].append(num_mp) - organized_stats["avg_time_ns"].append(avg_time_ns) - organized_stats["x"].append(x) - - df = pd.DataFrame(organized_stats) - if use_cache: - os.makedirs(os.path.dirname(cache_path), exist_ok=True) - with open(cache_path, "wb") as f: - pickle.dump(df, f) - - return df - - -def computation_instruction_time_cost( - op_stats: pd.DataFrame, - op_name: str, - num_layers: int, - parallel_strategy: ParallelismConfig, - bs: int, - seqlen: int, -): - # inst cost unit: ns - layer_names = ["embedding_layer", "block_0", "head"] - num_pp = parallel_strategy.pipeline_parallel_size - num_mp = parallel_strategy.model_parallel_size - op_stats = op_stats[ - (op_stats["op_name"] == op_name) & (op_stats["num_mp"] == num_mp) - ] - - op_cost = {} - embed_stats = op_stats[op_stats["layer_name"] == "embedding_layer"] - if embed_stats[ - (embed_stats["bs"] == bs) & (embed_stats["seq_len"] == seqlen) - ].empty: - # do linear interpolation for data points that does not exist - for layer_name in layer_names: - layer_stats = op_stats[op_stats["layer_name"] == layer_name] - assert layer_stats[ - (layer_stats["bs"] == bs) & (layer_stats["seq_len"] == seqlen) - ].empty - assert not layer_stats.empty, ( - layer_name, - op_name, - num_mp, - op_stats, - ) - xs = layer_stats["x"] - ys = layer_stats["avg_time_ns"] - x = int(bs) if op_name == "fwd_gen_1" else int(bs * seqlen) - y = np.interp(x, xs, ys) - if max(xs) < x or min(xs) > x: - logger.warning( - f"Interpolated value outside profiling range, " - f"parallel strategy {parallel_strategy}: " - f"{x} in {sorted(list(set(xs)))}" - ) - # estimate using largest or smallest value - if max(xs) < x: - y = ys.max() * (x / xs[ys.idxmax()]) - else: - y = ys.min() * (x / xs[ys.idxmin()]) - op_cost[layer_name] = y - else: - for layer_name in layer_names: - assert not op_stats[op_stats["layer_name"] == layer_name].empty - required_stats = op_stats[ - (op_stats["layer_name"] == layer_name) - & (op_stats["bs"] == bs) - & (op_stats["seq_len"] == seqlen) - ] - assert required_stats.shape[0] == 1 - op_cost[layer_name] = required_stats["avg_time_ns"].values[0] - - embedding_layer_cost = op_cost["embedding_layer"] - block_0_cost = op_cost["block_0"] - head_cost = op_cost["head"] - cost = (embedding_layer_cost + num_layers * block_0_cost + head_cost) / num_pp - return cost - - -def communication_instruction_time_cost(comm_stats, size, comm_type): - return size / comm_stats[comm_type] # unit: ns - - -def estimate_instruction_time_costs( - model_family: ModelFamily, - model_path: str, - num_layers: int, # model configuration, num transformers layer - parallel_strategy: ParallelismConfig, - hidden_dim: int, - batch_size: int, - seq_len: int, - n_ppo_minibatches: int = 1, -): - comm_stats = default_communication_stats() - op_stats = get_organized_op_stats(model_family, model_path, use_cache=True) - - num_mp = parallel_strategy.model_parallel_size - num_pp = parallel_strategy.pipeline_parallel_size - num_dp = parallel_strategy.data_parallel_size - num_gpus = num_dp * num_mp * num_pp - - train_mbs = ( - batch_size / (2 * num_pp * num_dp * n_ppo_minibatches) - if num_pp > 1 - else batch_size / (num_dp * n_ppo_minibatches) - ) - gen_mbs = batch_size / (num_pp * num_dp) - - # 
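`computation_instruction_time_cost` above falls back to linear interpolation over x = bs * seq_len (or x = bs for the single-token decode op) when the requested shape was never profiled, warning when x lies outside the profiled range. A toy illustration of that fallback with assumed profiling points:

```python
import numpy as np

# Assumed profiling points for one (layer, op, num_mp) slice: x = bs * seq_len.
xs = np.array([1024, 4096, 16384])      # profiled token counts
ys = np.array([2.1e6, 8.0e6, 31.5e6])   # avg_time_ns at those points
x = 8192                                # an unprofiled shape
print(np.interp(x, xs, ys))             # ~1.58e7 ns, linear between the neighbours
```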
pprint.pprint(op_cost, indent=4) - inst_keys = [ - "gen_fwd_0", - "gen_fwd_1", - "inf_fwd", - "train_fwd", - "train_bwd", - "train_opt", - ] - op_names = ["fwd_gen_0", "fwd_gen_1", "fwd", "fwd", "bwd", "opt"] - inst_stats = {} - - for inst_key, op_name in zip(inst_keys, op_names): - mbs = train_mbs if "train" in inst_key else gen_mbs - inst_stats[inst_key] = computation_instruction_time_cost( - op_stats, op_name, num_layers, parallel_strategy, mbs, seq_len - ) - - comm_type = "remote_send" if num_gpus // num_pp >= 8 else "local_send" - - inst_stats["act_p2p"] = communication_instruction_time_cost( - comm_stats, 2 * hidden_dim * train_mbs * seq_len, comm_type - ) - inst_stats["grad_p2p"] = communication_instruction_time_cost( - comm_stats, 2 * hidden_dim * train_mbs * seq_len, comm_type - ) - inst_stats["gen_act_p2p"] = communication_instruction_time_cost( - comm_stats, 2 * hidden_dim * gen_mbs, comm_type - ) - return inst_stats - - -def _estimate_rpc_time_cost( - inst_stats, - parallel_strategy: ParallelismConfig, - model_interface_type: ModelInterfaceType, - # model function call args - num_gen_tokens: int, - gradient_checkpointing: bool, - n_ppo_minibatches: int = 1, -): - # TODO: improve/remove heuristic - num_pp = parallel_strategy.pipeline_parallel_size - num_dp = parallel_strategy.data_parallel_size - num_mp = parallel_strategy.model_parallel_size - if model_interface_type == ModelInterfaceType.INFERENCE: - num_micro_batches = num_pp - compute_cost = inst_stats["inf_fwd"] * (num_pp + num_micro_batches - 1) - comm_cost = inst_stats["act_p2p"] * (num_pp + num_micro_batches - 2) * 2 - if num_mp > 1: # and num_pp * num_dp > 1: - compute_cost = compute_cost * (1 - num_mp * 0.03) - elif model_interface_type == ModelInterfaceType.TRAIN_STEP: - # TODO: add reduce grads, add ppo micro batches - num_micro_batches = num_pp * 2 if num_pp > 1 else 1 - compute_cost = (inst_stats["train_fwd"] + inst_stats["train_bwd"]) * ( - num_pp + num_micro_batches - 1 - ) + inst_stats["train_opt"] - if gradient_checkpointing: - compute_cost += inst_stats["train_fwd"] * (num_pp + num_micro_batches - 1) - comm_cost = ( - (inst_stats["grad_p2p"] + inst_stats["act_p2p"]) - * (num_pp + num_micro_batches - 2) - * 2 - ) - compute_cost = compute_cost * n_ppo_minibatches - comm_cost = comm_cost * n_ppo_minibatches - if num_pp * num_dp <= 1: - compute_cost = compute_cost * (1 - num_mp * 0.04) - if num_mp > 1: # and num_pp * num_dp > 1: - compute_cost = compute_cost * (1 - num_mp * 0.03) - elif model_interface_type == ModelInterfaceType.GENERATE: - num_micro_batches = num_pp - num_gen_tokens = num_gen_tokens - compute_cost = ( - inst_stats["gen_fwd_0"] * (num_pp + num_micro_batches - 1) - + inst_stats["gen_fwd_1"] * (num_gen_tokens - 1) * num_micro_batches - ) - - if num_dp * num_mp > 1: - compute_cost = compute_cost * (1 - min(num_dp * num_mp, 8) * 0.03) - comm_cost = 0 - - # dirty heuristic - if num_pp > 8: - compute_cost = compute_cost * (1 + num_pp * 0.01) - - # FIXME: disable comm cost for its not accurate - comm_cost = 0 - - return compute_cost + comm_cost - - -def estimate_rpc_time_cost( - rpc: MFCDef, - parallel_strategy: ParallelismConfig, - bs: int, - seq_len: int, - num_gen_tokens: int = 256, - gradient_checkpointing: bool = False, - n_ppo_minibatches: int = 1, -): - # time unit: miliseconds - # FIXME: n_ppo_minibatches > 1 will result in bad estimation - # when batch size is large enough, n_ppo_minibatches > 1 will not affect the estimation - n_ppo_minibatches = 1 - model_type = rpc.model_type - model_path = 
rpc.model_path - model_config = load_model_config(rpc.model_type._class, rpc.model_path) - - inst_cost = estimate_instruction_time_costs( - model_type, - model_path, - model_config.n_layers, - parallel_strategy, - model_config.hidden_dim, - bs, - seq_len, - n_ppo_minibatches=n_ppo_minibatches, - ) - return ( - _estimate_rpc_time_cost( - inst_cost, - parallel_strategy, - rpc.interface_type, - num_gen_tokens=num_gen_tokens, - gradient_checkpointing=gradient_checkpointing, - n_ppo_minibatches=n_ppo_minibatches, - ) - ) / 1e6 - - -def default_communication_stats(if_print=False): - # use default comm stats of cluster - r = dict( - # between GPUs on the same node - local_send=170, # unit: GB/s - local_recv=170, - # IB between GPUs on different nodes - remote_send=20, # unit: GB/s - remote_recv=20, - ) - if if_print: - print(r) - return r - - -def estimate_model_size(model_config: ReaLModelConfig): - h = model_config.hidden_dim - i = model_config.intermediate_dim - v = model_config.vocab_size - L = model_config.n_layers - # for llama actor only - n_params = 3 * v * h + (3 * h * i + 4 * h * h) * L - return 2 * n_params - - -def estimate_rpc_memory_cost( - rpc: MFCDef, - parallel_strategy: ParallelismConfig, - batch_size: int, - seq_len: int, - offload: bool = False, - gradient_checkpointing: bool = False, - offload_optimizer: bool = False, - n_ppo_minibatches: int = 1, - num_gen_tokens: int = 128, -): - # TODO: improve heuristic - interface_type = rpc.interface_type - model_config = load_model_config(rpc.model_type._class, rpc.model_path) - - h = model_config.hidden_dim - i = model_config.intermediate_dim - v = model_config.vocab_size - s = seq_len - gs = num_gen_tokens - b = batch_size - L = model_config.n_layers - # for llama actor only - n_params = 3 * v * h + (3 * h * i + 4 * h * h) * L - param_mem = 2 * n_params - grad_mem = 2 * n_params - optimizer_mem = 20 * n_params if not offload_optimizer else 0 - - num_pp = parallel_strategy.pipeline_parallel_size - num_mp = parallel_strategy.model_parallel_size - num_dp = parallel_strategy.data_parallel_size - # zero1, pp and mp divide evenly - # enable sequence parallel - if interface_type == ModelInterfaceType.TRAIN_STEP: - # gradient checkpointing is always enabled for flash attn - static_mem = (param_mem + grad_mem) // (num_pp * num_mp) + optimizer_mem // ( - num_pp * num_dp * num_mp - ) - micro_bs = b // (2 * num_pp * num_dp) if num_pp > 0 else b // (num_dp) - if gradient_checkpointing: - active_mem = (micro_bs * s * h * num_pp * 2) // (num_pp * num_mp) - else: - # FIXME: calculate other memory entries - active_mem = (micro_bs * s * h * num_pp * 2) * 2 * L // (num_pp * num_mp) - return static_mem + active_mem, static_mem - elif interface_type == ModelInterfaceType.INFERENCE: - static_mem = int(2 * param_mem // (num_pp * num_mp)) - # if num_dp > 4: - # static_mem = static_mem * 1.25 - if offload: - return static_mem, 0 # assume offload - else: - return static_mem, static_mem - elif interface_type == ModelInterfaceType.GENERATE: - static_mem = int(2 * param_mem // (num_pp * num_mp)) - if num_dp > 4 and num_dp * num_mp * num_pp <= 16: - static_mem = static_mem * 1.25 - if num_mp == 0 and num_pp == 0: - static_mem = static_mem * 1.25 - active_mem = ( - 2 * (2 * b * (gs + s) * h) * L // (num_pp * num_mp * num_dp) - ) # kv cache - return static_mem + active_mem, static_mem - - -def example(rpcs): - if args.model_size == 7: - n_nodes = 1 - elif args.model_size == 13: - n_nodes = 2 - elif args.model_size == 34: - n_nodes = 4 - elif args.model_size == 70: 
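The deleted estimators above size a LLaMA-style actor as n_params = 3·v·h + (3·h·i + 4·h²)·L and charge 2 bytes per fp16 parameter. Plugging LLaMA-7B-like dimensions into that heuristic (assumed values, for illustration only):

```python
# LLaMA-7B-like dimensions (assumed): hidden, intermediate, vocab, layers.
h, i, v, L = 4096, 11008, 32000, 32
n_params = 3 * v * h + (3 * h * i + 4 * h * h) * L
print(f"{n_params / 1e9:.2f}B params")         # ~6.87B
print(f"{2 * n_params / 2**30:.1f} GiB fp16")  # ~12.8 GiB of weights alone
```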
- n_nodes = 8 - - expr_name = f"profile-s{args.model_size}p{n_nodes}m1d8" - - rollout, inf, train = rpcs - - bs = 128 - seq_len = 1024 - - p1 = ParallelismConfig( - pipeline_parallel_size=1, model_parallel_size=4, data_parallel_size=8 - ) - rpc_cost = estimate_rpc_time_cost( - train, - p1, - gradient_checkpointing=True, - num_gen_tokens=896, - bs=bs, - seq_len=seq_len, - n_ppo_minibatches=4, - ) - mem_cost, static_mem = estimate_rpc_memory_cost(rollout, p1, bs, seq_len) - print(f"{p1} rpc cost {rpc_cost:.2f} seconds mem cost {mem_cost/(1024**3):.2f} GB") - - -if __name__ == "__main__": - parser = argparse.ArgumentParser(description="Run a profiling experiment.") - parser.add_argument( - "-s", - "--model_size", - type=int, - default=7, - ) - args = parser.parse_args() - - example(args) diff --git a/realhf/search_engine/layers.py b/realhf/search_engine/layers.py deleted file mode 100644 index e1ab877..0000000 --- a/realhf/search_engine/layers.py +++ /dev/null @@ -1,310 +0,0 @@ -# Copyright 2025 Ant Group Inc. -# Copyright 2024 Wei Fu & Zhiyu Mei -# Licensed under the Apache License, Version 2.0 (the "License"). - -import os -import pickle -import time -from collections import defaultdict -from typing import List, Optional, Union - -import pandas as pd -import torch -import torch.distributed as dist -import transformers - -import realhf.api.core.model_api as model_api -import realhf.api.core.system_api as config_package -import realhf.base.constants as constants -import realhf.base.logging as logging -from realhf.api.core.model_api import ReaLModelConfig -from realhf.impl.model.utils.padding import unpad_input - -logger = logging.getLogger("profile layers", "system") - - -def make_layers(config: ReaLModelConfig, dtype, device): - from realhf.impl.model.nn.real_llm_base import ( - OutputHead, - ReaLModelBlock, - VocabPositionEmbedding, - ) - - embedding_layer = VocabPositionEmbedding( - config, - dtype=dtype, - device=device, - ) - real_model_blocks = [ - ReaLModelBlock( - config, - layer_index=i, - output_layernorm=(i == 1), - dtype=dtype, - device=device, - ) - for i in range(1) - ] - head = OutputHead( - config.hidden_dim, - 1 if config.is_critic else config.vocab_size, - bias=False, - device=device, - dtype=dtype, - ) - - layer_names = ["embedding_layer", "block_0", "head"] - return [embedding_layer] + real_model_blocks + [head], layer_names - - -class ProfileLayers: - - def __init__( - self, - model_name: str, - config: ReaLModelConfig, - tokenizer: transformers.PreTrainedTokenizerFast = None, - dtype: Optional[torch.dtype] = None, - device: Optional[Union[str, torch.device]] = None, - ): - self.model_name = model_name - self.config = config - self.backend_config = config_package.ModelBackend( - type_="deepspeed", - args=dict( - optimizer_name="adam", - optimizer_config=dict(lr=1e-5, weight_decay=0.0, betas=(0.9, 0.95)), - warmup_steps_proportion=0.0, - min_lr_ratio=0.0, - zero_stage=1, - bf16=False, - ), - ) - - self.dtype = dtype - self.device = device - self.layers, self.layer_names = make_layers(config, dtype, device) - self.hidden_dim = config.hidden_dim - self.head_dim = config.head_dim - self.max_new_tokens = 128 # only useful in kv cache memory alloc - self.min_new_tokens = 128 - - self.stats = defaultdict(list) - self.num_layers = len(self.layers) - - self.layers = [ - model_api.Model(name, layer, tokenizer, device=device, dtype=dtype) - for layer, name in zip(self.layers, self.layer_names) - ] - self.backend = model_api.make_backend(self.backend_config) - ft_spec = 
model_api.FinetuneSpec(10, 100, 10) - self.layers = [self.backend.initialize(layer, ft_spec) for layer in self.layers] - self.stats = defaultdict(list) - - def reset_stats(self): - self.stats = defaultdict(list) - - def insert_data_point(self, layer_name, name, bs, seq_len, time_ns): - self.stats["layer_name"].append(layer_name) - self.stats["op_name"].append(name) - self.stats["bs"].append(bs) - self.stats["seq_len"].append(seq_len) - self.stats["time_ns"].append(time_ns) - - @torch.no_grad() - def fwd_gen(self, bs, seq_len): - from realhf.impl.model.nn.real_llm_base import PipeCacheData, PipeTransferData - - input_ids = torch.randint( - 0, - self.config.vocab_size, - (bs, seq_len), - dtype=torch.long, - device=self.device, - ) - attention_mask = torch.ones_like(input_ids, device=self.device) - # fwd_gen_0 - packed_input_ids, _, cu_seqlens, max_seqlen = unpad_input( - input_ids, attention_mask - ) - cu_seqlens = cu_seqlens.to(device=self.device) - packed_input_ids = packed_input_ids.to(device=self.device) - x = PipeTransferData( - cu_seqlens=cu_seqlens, - max_seqlen=int(max_seqlen), - store_kv_cache=True, - ) - ys = [PipeCacheData() for _ in range(self.num_layers)] - ys[0].packed_input_ids = packed_input_ids - - for layer_name, layer, y in zip(self.layer_names, self.layers, ys): - st = time.monotonic_ns() - x: PipeTransferData = layer.module(x, y) - x.pp_input = x.pp_output - torch.cuda.synchronize() - self.insert_data_point( - layer_name, "fwd_gen_0", bs, seq_len, time.monotonic_ns() - st - ) - - prompt_logits = x.pp_output - logits = prompt_logits[cu_seqlens[1:] - 1] - input_lens = cu_seqlens[1:] - cu_seqlens[:-1] - cache_seqlens = input_lens.clone().to(dtype=torch.int32) - layer_indices = range(len(ys)) - - for y, layer_idx in zip(ys[1:-1], layer_indices[1:-1]): - assert ( - y.k_cache is not None - and y.v_cache is not None - and y.cache_seqlens is not None - ) - kvcache_seqlen = max( - max_seqlen + self.max_new_tokens, - self.hidden_dim // self.head_dim + 10, - ) - # fix of a flash attention bug - k_cache = torch.zeros( - (bs, kvcache_seqlen, *y.k_cache.shape[1:]), - dtype=y.k_cache.dtype, - device=self.device, - ) - v_cache = torch.zeros_like(k_cache) - indices = ( - torch.arange( - kvcache_seqlen, - device=constants.current_device(), - dtype=torch.long, - )[None, :] - < input_lens[:, None] - ) - k_cache[indices] = y.k_cache - v_cache[indices] = y.v_cache - y.k_cache = k_cache - y.v_cache = v_cache - y.cache_seqlens = cache_seqlens - x = PipeTransferData(store_kv_cache=True) - ys[0].cache_seqlens = cache_seqlens - - # fwd_gen_1 - new_tokens = torch.randint( - 0, - self.config.vocab_size, - (bs,), - dtype=torch.long, - device=self.device, - ) - ys[0].packed_input_ids = new_tokens - ys[0].packed_position_ids = None - x.cu_seqlens = torch.arange(bs + 1, dtype=torch.int32, device=self.device) - x.max_seqlen = 1 - for layer_name, layer, y in zip(self.layer_names, self.layers, ys): - st = time.monotonic_ns() - x = layer.module(x, y) - x.pp_input = x.pp_output - torch.cuda.synchronize() - self.insert_data_point( - layer_name, "fwd_gen_1", bs, seq_len, time.monotonic_ns() - st - ) - - def fwd_bwd_opt(self, bs, seq_len): - from realhf.impl.model.nn.real_llm_base import PipeCacheData, PipeTransferData - - input_ids = torch.randint( - 0, - self.config.vocab_size, - (bs, seq_len), - dtype=torch.long, - device=self.device, - ) - attention_mask = torch.ones_like(input_ids, device=self.device) - packed_input_ids, _, cu_seqlens, max_seqlen = unpad_input( - input_ids, attention_mask - ) - cu_seqlens 
= cu_seqlens.to(device=self.device) - packed_input_ids = packed_input_ids.to(device=self.device) - x = PipeTransferData( - cu_seqlens=cu_seqlens, max_seqlen=max_seqlen, store_kv_cache=False - ) - ys = [PipeCacheData() for _ in range(self.num_layers)] - ys[0].packed_input_ids = packed_input_ids - - for layer_name, layer, y in zip(self.layer_names, self.layers, ys): - # fwd - st = time.monotonic_ns() - x: PipeTransferData = layer.module(x, y) - torch.cuda.synchronize() - self.insert_data_point( - layer_name, "fwd", bs, seq_len, time.monotonic_ns() - st - ) - # bwd - r = torch.rand( - *x.pp_output.shape, - device=x.pp_output.device, - dtype=x.pp_output.dtype, - ) - loss = torch.max(x.pp_output * r) - st = time.monotonic_ns() - layer.module.backward(loss) - torch.cuda.synchronize() - self.insert_data_point( - layer_name, "bwd", bs, seq_len, time.monotonic_ns() - st - ) - # opt - st = time.monotonic_ns() - layer.module.step() - torch.cuda.synchronize() - self.insert_data_point( - layer_name, "opt", bs, seq_len, time.monotonic_ns() - st - ) - x.pp_input = x.pp_output.clone().detach() - - def make_dataframe_and_print(self): - df = pd.DataFrame(self.stats) - logger.info(f"Current Stats: \nstr{df}") - - def dump_stats(self, world_size): - rank = dist.get_rank() - # dump full stats - dump_dir = os.path.join( - constants.PROFILER_CACHE_PATH, - "layer_stats", - ) - dump_path = os.path.join( - dump_dir, self.model_name, f"layer-stats_{world_size}_{rank}.pkl" - ) - os.makedirs(os.path.dirname(dump_path), exist_ok=True) - - with open(dump_path, "wb") as f: - df = pd.DataFrame(self.stats) - pickle.dump(df, f) - - -def make_profile_layers( - device: torch.device, - model_path: str, - model_name: str, - dtype: Optional[str] = None, - hf_model_type: str = "llama", -): - from realhf.impl.model.nn.real_llm_api import ReaLModel - - if dtype == "fp16" or dtype == None: - dtype = torch.float16 - elif dtype == "bf16": - dtype = torch.bfloat16 - elif dtype == "fp32": - dtype == torch.float32 - else: - raise NotImplementedError(f"Unsupported dtype {dtype}") - tokenizer = None - config: ReaLModelConfig = getattr(ReaLModel, f"config_from_{hf_model_type}")( - model_path=model_path, - ) - if tokenizer is None: - tokenizer = model_api.load_hf_tokenizer(model_path) - - profile_layers = ProfileLayers( - model_name, config, tokenizer=tokenizer, dtype=dtype, device=device - ) - - return profile_layers diff --git a/realhf/search_engine/param_realloc.py b/realhf/search_engine/param_realloc.py deleted file mode 100644 index 6e495e9..0000000 --- a/realhf/search_engine/param_realloc.py +++ /dev/null @@ -1,272 +0,0 @@ -# Copyright 2025 Ant Group Inc. -# Copyright 2024 Wei Fu & Zhiyu Mei -# Licensed under the Apache License, Version 2.0 (the "License"). 
- -import dataclasses -import itertools -import json -import time -from collections import defaultdict -from typing import * - -import torch -import torch.distributed - -import realhf.api.core.system_api as system_api -import realhf.base.constants as constants -import realhf.base.topology as topology -from realhf.api.core.config import ModelFamily, ModelName -from realhf.api.core.model_api import ReaLModelConfig -from realhf.base.topology import decompose_to_three_factors - - -def bcast_cost( - param_size: float, bw: float, src: int, dsts: List[int], n_nodes_per_gpu=8 -): - src_node = src // n_nodes_per_gpu - dst_nodes = [dst // n_nodes_per_gpu for dst in dsts] - if src_node == dst_nodes[0] and all( - dst_node == dst_nodes[0] for dst_node in dst_nodes - ): - return param_size * 2 * 8 / (1800 * 1024**3) - # return 0.0 - else: - # param size is in float16, bw is in Gbps - return param_size * 2 * 8 / (bw * 1024**3) * len(set(dst_nodes)) - - -def compute_cost( - world_size: int, - from_model_name: ModelName, - to_model_name: ModelName, - from_topo: topology.ProcessTopology, - to_topo: topology.ProcessTopology, - model_config: ReaLModelConfig, - bw: float, # Gbps - set_interval_cost: float, -) -> int: - from realhf.impl.model.comm.param_realloc import ( - ParamReallocInfo, - ReparallelizeReceiverStep, - ReparallelizeSenderStep, - _create_param_realloc_groups, - _derive_reparallelize_comm_plan, - ) - - param_sync_groups = {} - param_sync_src_ranks = {} - param_sync_dst_ranks = {} - msid2mwid = {} - for i in range(from_topo.world_size()): - msid2mwid[ - system_api.ModelShardID.from_parallelism_rank(from_model_name, from_topo, i) - ] = i - for i in range(to_topo.world_size()): - msid2mwid[ - system_api.ModelShardID.from_parallelism_rank(to_model_name, to_topo, i) - ] = (i + world_size - to_topo.world_size()) - _create_param_realloc_groups( - from_topo, - to_topo, - from_model_name, - to_model_name, - msid2mwid, - param_sync_groups, - param_sync_src_ranks, - param_sync_dst_ranks, - ) - pg_info = ParamReallocInfo( - param_sync_groups, - param_sync_src_ranks, - param_sync_dst_ranks, - ) - comm_plan = _derive_reparallelize_comm_plan( - from_model_name, - to_model_name, - from_topo, - to_topo, - model_config, - model_config, - pg_info, - ) - - # Run boradcast! 
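The deleted `bcast_cost` above prices a cross-node parameter broadcast as `param_size * 2 * 8 / (bw * 1024**3)` seconds, multiplied by the number of distinct destination nodes, with `bw` in Gbps and the parameter count in fp16 elements. A back-of-the-envelope evaluation with assumed numbers:

```python
# Assumed: 7e9 fp16 parameters, 200 Gbps inter-node bandwidth, 2 destination nodes.
param_size, bw_gbps, n_dst_nodes = 7e9, 200, 2
cost_s = param_size * 2 * 8 / (bw_gbps * 1024**3) * n_dst_nodes
print(f"{cost_s:.2f} s")  # ~1.04 s for one cross-node reallocation broadcast
```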
- max_cost = max_comm_volume = max_bcast_cnt = 0 - for _rank in range(world_size): - cost = comm_volume = bcast_cnt = 0 - for step in comm_plan: - if isinstance(step, ReparallelizeReceiverStep) and step.rank == _rank: - if step.rank != step.src: - cost += bcast_cost(step.param_size, bw, step.src, step.dst_ranks) - comm_volume += step.param_size - bcast_cnt += 1 - cost += set_interval_cost - if isinstance(step, ReparallelizeSenderStep) and step.rank == _rank: - if step.group is not None: - cost += bcast_cost(step.param_size, bw, step.rank, step.dst_ranks) - bcast_cnt += 1 - max_cost = max(max_cost, cost) - max_comm_volume = max(max_comm_volume, comm_volume) - max_bcast_cnt = max(max_bcast_cnt, bcast_cnt) - - return max_cost - - -def dump_table( - n_nodes: int, - model_family: ModelFamily, - model_path: str, - rank: int = 0, - parallel: int = 1, -): - from_model_name = ModelName("actor", 0) - to_model_name = ModelName("actor", 1) - - def hash_tuple_into_str(t) -> str: - return ",".join([str(i) for i in t]) - - import tqdm - - res = {} - device_mesh_sizes = [4] + [8 * i for i in range(1, n_nodes + 1)] - space = list(itertools.product(device_mesh_sizes, device_mesh_sizes)) - sub_space = space[rank::parallel] - # for a, b in set(small_spaces + large_spaces): - for a, b in sub_space: - mtik = time.perf_counter() - all_configs = list( - itertools.product( - decompose_to_three_factors(a), decompose_to_three_factors(b) - ) - ) - all_configs = list(filter(lambda x: x[0][1] <= 8 and x[1][1] <= 8, all_configs)) - all_configs = list(filter(lambda x: x[0][2] <= 8 and x[1][2] <= 8, all_configs)) - all_configs = list( - filter( - lambda x: x[0][1] in [1, 2, 4, 8] and x[1][1] in [1, 2, 4, 8], - all_configs, - ) - ) - all_configs = list( - filter(lambda x: x[0][0] <= 16 and x[1][0] <= 16, all_configs) - ) - all_configs = list( - filter( - lambda x: x[0][1] % x[1][1] == 0 or x[1][1] % x[0][1] == 0, - all_configs, - ) - ) - for config_id, (from_pp_mp_dp, to_pp_mp_dp) in tqdm.tqdm( - enumerate(all_configs) - ): - world_size = max(a, b) - - from_topo = topology.PipeDataModelParallelTopology( - *from_pp_mp_dp, False, False - ) - to_topo = topology.PipeDataModelParallelTopology(*to_pp_mp_dp, False, False) - assert world_size >= from_topo.world_size() - assert world_size >= to_topo.world_size() - - from realhf.search_engine.utils import load_model_config - - mconfig = load_model_config(model_family._class, model_path) - - cost = compute_cost( - world_size, - from_model_name, - to_model_name, - from_topo, - to_topo, - mconfig, - bw=200.0, - set_interval_cost=0.03, - ) - res[ - hash_tuple_into_str((model_family.size, *from_pp_mp_dp, *to_pp_mp_dp)) - ] = int(cost * 1000 * 1000) - print( - f"Time for model size {model_family.size} {a} -> {b} {rank}/{parallel}: " - f"{time.perf_counter() - mtik:.4f}, num res entries {len(res)}" - ) - - print(f"Rank {rank} of model {model_family} finished, res size {len(res)}.") - - import os - import pickle - - dump_path = os.path.join(constants.PROFILER_CACHE_PATH, "param_realloc") - fn = f"prtc_{model_family}_n{n_nodes}_{rank}_{parallel}.pkl" - if not os.path.exists(dump_path): - os.makedirs(dump_path, exist_ok=True) - with open(os.path.join(dump_path, fn), "wb") as f: - pickle.dump(res, f) - print(f"dumped table with {len(res)} entries to {model_family}-{rank}-{parallel}.") - - -def dump_table_parallel( - n_nodes: int, - model_family_to_path: Dict[ModelFamily, str], - parallel: int = 4, -): - import torch.multiprocessing as mp - - mp.set_start_method("spawn", force=True) - - rq = 
mp.Queue() - ps = [] - for model_family, model_path in model_family_to_path.items(): - for rank in range(parallel): - ps.append( - mp.Process( - target=dump_table, - args=(n_nodes, model_family, model_path, rank, parallel), - ) - ) - - for p in ps: - p.start() - - for p in ps: - p.join() - - -def merge_tables( - n_nodes: int, - model_family_to_path: Dict[ModelFamily, str], - parallel: int = 4, -): - import os - import pickle - - res_path = os.path.join(constants.PROFILER_CACHE_PATH, "param_realloc") - - for model_family in model_family_to_path.keys(): - prefix = f"prtc_{model_family}_n{n_nodes}" - - r = {} - counter = 0 - for fn in os.listdir(res_path): - if fn.endswith(".pkl") and fn.startswith(prefix): - counter += 1 - path = os.path.join(res_path, fn) - with open(path, "rb") as f: - r.update(pickle.load(f)) - os.remove(path) - if counter < parallel: - raise RuntimeError( - "missing sub-tables, probably some sub-processes failed " - "during param realloc time cost estimation." - ) - with open(os.path.join(res_path, f"{prefix}.pkl"), "wb") as f: - pickle.dump(r, f) - print(f"merged parallel tables into {prefix}.pkl, total entries {len(r)}") - - -def estimate_param_realloc_time_cost( - n_nodes: int, - model_family_to_path: Dict[ModelFamily, str], - parallel: int = 4, -): - dump_table_parallel(n_nodes, model_family_to_path, parallel) - merge_tables(n_nodes, model_family_to_path, parallel) diff --git a/realhf/search_engine/search.py b/realhf/search_engine/search.py deleted file mode 100644 index 3f9d2f0..0000000 --- a/realhf/search_engine/search.py +++ /dev/null @@ -1,232 +0,0 @@ -# Copyright 2025 Ant Group Inc. -# Copyright 2024 Wei Fu & Zhiyu Mei -# Licensed under the Apache License, Version 2.0 (the "License"). - -import argparse -import functools -import json -import os -import pickle -import pprint -import re -from typing import Any, Dict, List, Literal, Optional - -import numpy as np - -try: - import realhf._C.mdm_search as mdm_search -except ModuleNotFoundError: - mdm_search = None - -import realhf.base.constants as constants -from realhf.api.cli_args import ModelTrainEvalConfig, ParallelismConfig -from realhf.api.core.config import ModelInterfaceType -from realhf.api.core.dfg import MFCDef -from realhf.api.quickstart.device_mesh import DeviceMesh, RPCAllocation -from realhf.api.quickstart.search import RPCExecution - - -def search_rpc_allocations( - device_mesh: DeviceMesh, - rpcs: List[MFCDef], - num_gen_tokens: int = 256, - n_ppo_minibatches: int = 1, - seq_len: int = 256, - gradient_checkpointing: bool = True, - use_cache: bool = False, -) -> List[RPCAllocation]: - from realhf.search_engine.enumerate import build_graph - from realhf.search_engine.estimate import get_param_realloc_stats - - from_file = os.environ.get("REAL_IS_REMOTE", "0") == "1" - dump_dir = os.path.join( - constants.LOG_ROOT, - constants.experiment_name(), - constants.trial_name(), - "device_mapping.pkl", - ) - log_dir = os.path.join( - constants.LOG_ROOT, - constants.experiment_name(), - constants.trial_name(), - "device_mapping", - ) - rs_dir = os.path.join( - constants.LOG_ROOT, - constants.experiment_name(), - constants.trial_name(), - "raw_search_result", - ) - rpc_exe_dir = os.path.join( - constants.LOG_ROOT, - constants.experiment_name(), - constants.trial_name(), - "rpc_exe_info", - ) - - if from_file or (use_cache and os.path.exists(dump_dir)): - with open(dump_dir, "r") as f: - s = json.load(f) - rpc_allocs = [RPCAllocation.from_dict(d) for d in s] - return rpc_allocs - else: - 
os.makedirs(os.path.dirname(dump_dir), exist_ok=True) - - n_nodes = device_mesh.n_nodes - table = {} - for rpc in rpcs: - print(f"Getting param realloc stats for {rpc.model_type} at {rpc.model_path}") - t = get_param_realloc_stats(rpc.model_type, rpc.model_path, n_nodes, True) - table.update(t) - - rpc_exe_list = make_rpc_exe_list( - rpcs, - device_mesh, - num_gen_tokens=num_gen_tokens, - n_ppo_minibatches=n_ppo_minibatches, - seq_len=seq_len, - gradient_checkpointing=gradient_checkpointing, - log_dir=rpc_exe_dir, - if_print=False, - ) - graph = build_graph(rpcs, 5, 1, if_print=False) - model_size_dict = make_model_size_dict(rpcs, if_print=False) - - n_nodes = device_mesh.n_nodes - search_time = 120 - - rs: List[Dict[str, List]] = mdm_search.multi_mcmc_search( - rpcs, - rpc_exe_list, - graph, - table, - model_size_dict, - 0.001, # beta min - 0.002, # beta max - 0.001, # beta step - search_time, # time limit for each search - 1, # repeat - ) - if not from_file: - with open(rs_dir, "w") as f: - - pprint.pprint(rs, stream=f) - - r: Dict[str, Dict[str, Any]] = rs[-1] - pprint.pprint(r) - - rpc_name_to_rpcs = {rpc.name: rpc for rpc in rpcs} - rpc_allocs = [] - for rpc_name, alloc_info in r.items(): - if rpc_name in ["end_time", "mem_cost"]: - continue - # rpc = rpc_dict[rpc_name] - rpc = rpc_name_to_rpcs[rpc_name] - parallel = ParallelismConfig( - pipeline_parallel_size=alloc_info["num_pp"], - data_parallel_size=alloc_info["num_dp"], - model_parallel_size=alloc_info["num_mp"], - use_sequence_parallel=( - alloc_info["num_mp"] > 1 - and rpc.interface_type == ModelInterfaceType.TRAIN_STEP - ), - ) - sub_device_mesh = DeviceMesh( - n_nodes=device_mesh.n_nodes, - n_gpus_per_node=device_mesh.n_gpus_per_node, - mapping=alloc_info["device_mesh_mapping"], - name=alloc_info["device_mesh_name"], - global_mesh_name=device_mesh.global_mesh_name, - ) - rpc_alloc = RPCAllocation( - rpc=rpc, - device_mesh=sub_device_mesh, - parallel=parallel, - ) - rpc_allocs.append(rpc_alloc) - - if not from_file: - with open(dump_dir, "w") as f: - json.dump([rpc_alloc.to_dict() for rpc_alloc in rpc_allocs], f, indent=4) - with open(log_dir, "w") as f: - - pprint.pprint(rpc_allocs, stream=f) - - return rpc_allocs - - -def make_rpc_exe_list( - rpcs: List[MFCDef], - device_mesh: DeviceMesh, - num_gen_tokens: int, - n_ppo_minibatches: int, - seq_len: int, - gradient_checkpointing: bool, - if_print: bool = False, - log_dir: Optional[str] = None, -) -> List[RPCExecution]: - from realhf.search_engine.enumerate import enumerate_rpc_executions - - rpc_exe_list = [] - log_flag = False - for rpc in rpcs: - # real_model_config = load_model_config(rpc) - feasible = enumerate_rpc_executions( - rpc, - device_mesh, - seq_len=seq_len, - num_gen_tokens=num_gen_tokens, - n_ppo_minibatches=n_ppo_minibatches, - gradient_checkpointing=gradient_checkpointing, - ) - rpc_exe_list.extend(feasible) - - if log_dir is not None: - mode = "w" if not log_flag else "a" - with open(log_dir, mode) as f: - f.write(f"{rpc.name} feasible: {len(feasible)}\n") - feasible.sort(key=lambda x: x.time_cost) - # feasible = feasible[:30] - for i, rpc_exe in enumerate(feasible): - f.write( - f"{i}: time_cost: {rpc_exe.time_cost} ms, {rpc_exe.time_cost} " - f"sub_device_mesh: {rpc_exe.device_mesh}, " - f"parallel_strategy: {rpc_exe.parallel_strategy}, " - f"mem_cost: {rpc_exe.mem/(1024*1024*1024):02f} GB, " - f"static_mem_cost: {rpc_exe.static_mem/(1024*1024*1024):02f} GB\n" - ) - f.write("\n") - log_flag = True - - if if_print: - print(f"{rpc.name} feasible: 
{len(feasible)}") - feasible.sort(key=lambda x: x.time_cost) - # feasible = feasible[:10] - for i, rpc_exe in enumerate(feasible): - print( - f"{i}: time_cost: {rpc_exe.time_cost} ms, " - f"sub_device_mesh: {rpc_exe.device_mesh}, " - f"parallel_strategy: {rpc_exe.parallel_strategy}, " - f"mem_cost: {rpc_exe.mem/(1024*1024*1024):02f} GB, " - f"static_mem_cost: {rpc_exe.static_mem/(1024*1024*1024):02f} GB" - ) - - return rpc_exe_list - - -def make_model_size_dict(rpcs: List[MFCDef], if_print: bool = False) -> Dict[str, int]: - model_size_dict = {} - - for rpc in rpcs: - if rpc.model_name.role in model_size_dict: - continue - # model_configs = load_model_config(rpc) - # model_size_dict[rpc.model_name.role] = estimate_model_size(real_model_config) - model_size_dict[rpc.model_name.role] = rpc.model_type.size - - if if_print: - print( - f"model_name: {rpc.model_name.role}, " - f"model_size: {rpc.model_type.size}" - ) - return model_size_dict diff --git a/realhf/search_engine/utils.py b/realhf/search_engine/utils.py deleted file mode 100644 index b946ec5..0000000 --- a/realhf/search_engine/utils.py +++ /dev/null @@ -1,28 +0,0 @@ -# Copyright 2025 Ant Group Inc. -# Copyright 2024 Wei Fu & Zhiyu Mei -# Licensed under the Apache License, Version 2.0 (the "License"). - -from realhf.api.core.model_api import ReaLModelConfig - - -def find_factors(n): - factors = [] - for i in range(1, n + 1): - if n % i == 0: - factors.append(i) - return factors - - -def make_stats_key(rpc_name, bs, seq_len): - return f"{rpc_name}|{bs}|{seq_len}" - - -def parse_stats_key(key): - rpc_name, bs, seq_len = key.split("|") - return rpc_name, int(bs), int(seq_len) - - -def load_model_config(model_class: str, model_path: str) -> ReaLModelConfig: - from realhf.impl.model.nn.real_llm_api import ReaLModel - - return getattr(ReaLModel, f"config_from_{model_class}")(model_path=model_path) diff --git a/realhf/system/controller.py b/realhf/system/controller.py index bdd5942..69b0eca 100644 --- a/realhf/system/controller.py +++ b/realhf/system/controller.py @@ -478,9 +478,11 @@ def run_ray_worker( # NOTE: Importing these will initialize DeepSpeed/CUDA devices. # profiler.import_profiler_registers() - import realhf.impl.dataset - import realhf.impl.model - import realhf.system + if worker_type != "master_worker": + # For master_worker, there could be errors while importing and it is not necessary. + import realhf.impl.dataset + import realhf.impl.model + import realhf.system worker_name = f"{worker_type}/{idx}" server = worker_control.make_server( diff --git a/realhf/system/data_manager.py b/realhf/system/data_manager.py index 2f063a1..092e241 100644 --- a/realhf/system/data_manager.py +++ b/realhf/system/data_manager.py @@ -24,6 +24,17 @@ SCATTER_GROUPS = {} logger = logging.getLogger("data_manager", "system") +def find_minimal_superset(A: List[Set[int]], B: Set[int]) -> Set[int] | None: + min_size = float("inf") + result = None + for S in A: + if B.issubset(S): + if len(S) < min_size: + min_size = len(S) + result = S + return result + + class DataManager: def __init__( @@ -52,7 +63,7 @@ class DataManager: mw_ranks: Dict[ModelName, List[int]] = {} - # Stores the dp_head (i.e., mp_rank=0, pp_rank=-1) ranks given a model_name. + # Stores the dp_head (i.e., tp_rank=0, pp_rank=-1) ranks given a model_name. 
mw_dp_head_ranks: Dict[ModelName, List[int]] = defaultdict(list) assert msid2mwid is not None @@ -67,7 +78,7 @@ class DataManager: topo, msid2mwid, pipe=topo.get_dim("pipe") - 1, - model=0, + tensor=0, ) dp_size = topo.get_dim("data") for dp_i in range(dp_size): @@ -87,11 +98,12 @@ class DataManager: list(ranks), backend="nccl" if constants.use_cuda() else "gloo" ) - scatter_ranks = tuple(sorted(set([ranks[0]] + mw_ranks[dst]))) - SCATTER_GROUPS[scatter_ranks] = new_or_get_group( - list(scatter_ranks), - backend="nccl" if constants.use_cuda() else "gloo", - ) + for rank in ranks: + scatter_ranks = tuple(sorted(set([rank] + mw_ranks[dst]))) + SCATTER_GROUPS[scatter_ranks] = new_or_get_group( + list(scatter_ranks), + backend="nccl" if constants.use_cuda() else "gloo", + ) # Construct all src-dst pairs, from any src dp rank to any dst dp rank. # Note that a dp rank corresponds to multiple parameter shards (TP+PP), @@ -228,7 +240,20 @@ class DataManager: def _run_gather( self, step: RedistribStep, data_infos: Dict[Hashable, SequenceSample] ): - if dist.get_rank() not in step.srcs: + # It's possible that some DP rank is not involved. + # Create dummpy data to make the gather happy. + gather_ranks = find_minimal_superset( + [set(k) for k in GATHER_GROUPS.keys()], set(step.srcs) + ) + assert gather_ranks is not None, ( + set(step.srcs), + [set(k) for k in GATHER_GROUPS.keys()], + ) + gather_ranks = sorted(list(gather_ranks)) + + pgroup = GATHER_GROUPS[tuple(gather_ranks)] + + if dist.get_rank() not in gather_ranks: return maxlen = 0 @@ -249,36 +274,47 @@ class DataManager: torch.empty( maxlen, device=constants.current_device(), dtype=torch.float32 ) - for _ in range(len(step.srcs)) + for _ in gather_ranks ] + is_valid_gather = [i in step.srcs for i in gather_ranks] else: gather_list = None - local_gather_idx = step.srcs.index(dist.get_rank()) - ids = step.ids[local_gather_idx] - for i in ids: - self.storage[i].to_device(constants.current_device()) - samples = [self.storage[i] for i in ids] - data = torch.cat( - [ - sample.data[key].float().flatten() - for sample in samples - for key in step.keys - ] - ) - data = self._pad_data(data, maxlen) + if dist.get_rank() in step.srcs: + local_gather_idx = step.srcs.index(dist.get_rank()) + ids = step.ids[local_gather_idx] + for i in ids: + self.storage[i].to_device(constants.current_device()) + samples = [self.storage[i] for i in ids] + data = torch.cat( + [ + sample.data[key].float().flatten() + for sample in samples + for key in step.keys + ] + ) + data = self._pad_data(data, maxlen) + else: + data = torch.empty( + maxlen, device=constants.current_device(), dtype=torch.float32 + ) dist.gather( data, gather_list, dst=step.root, - group=GATHER_GROUPS[tuple(sorted(step.srcs))], + group=pgroup, ) if dist.get_rank() != step.root: + del data return - for ids, buf in zip(step.ids, gather_list): + cnt = 0 + for is_valid, buf in zip(is_valid_gather, gather_list): + if not is_valid: + continue + ids = step.ids[cnt] offset = 0 for i in ids: for key in step.keys: @@ -302,6 +338,9 @@ class DataManager: self.storage[i].update_(s) else: self.storage[i] = s + cnt += 1 + assert cnt == len(step.srcs) == len(step.ids) + del data def _run_scatter( self, step: RedistribStep, data_infos: Dict[Hashable, SequenceSample] diff --git a/realhf/system/function_executor.py b/realhf/system/function_executor.py index e934680..ddbc2aa 100644 --- a/realhf/system/function_executor.py +++ b/realhf/system/function_executor.py @@ -27,7 +27,7 @@ class FunctionExecutor: rpcs: List[MFCDef], 
msid2mwid: Dict[ModelShardID, int], stream: NameResolvingRequestClient, - buffer: AsyncIOSequenceBuffer, + buffers: List[AsyncIOSequenceBuffer], model_topos: Dict[str, ProcessTopology], model_configs: Dict[str, None | ReaLModelConfig], ctrl: RPCCorountineControl, @@ -58,14 +58,15 @@ class FunctionExecutor: model_topos=model_topos, model_configs=model_configs, ctrl=ctrl, - buffer=buffer, + buffers=buffers, redistrib_planner=self.redistrib_planner, summary_writer=summary_writer, ) self.func_calls[rpc.name] = func_call self.stream = stream - self.buffer = buffer + self.buffers = buffers + self.buffer_id = 0 self.data_loading_dp_idx = -1 self.shuffle_dataset = shuffle_dataset @@ -111,18 +112,17 @@ class FunctionExecutor: self.ctrl.ids_to_clear.clear() - async def load_data(self): - buffer = self.buffer + async def load_data(self, buffer_id: int): + buffer = self.buffers[buffer_id] ctrl = self.ctrl received_ids = set() - while self.buffer.size < max(rpc.n_seqs for rpc in self.rpcs): - + while buffer.size < max(rpc.n_seqs for rpc in self.rpcs): resps = await self.stream.call_async( handlers=[f"__data{dp_idx}__" for dp_idx in range(self.src_dp_size)], handle_type="fetch", - datas=[None for _ in range(self.src_dp_size)], + datas=[buffer_id for _ in range(self.src_dp_size)], verbose=False, ) @@ -182,10 +182,13 @@ class FunctionExecutor: logger.info("Waiting for the finish of the execution graph.") loop = asyncio.get_event_loop() - tasks = [loop.create_task(fc.run()) for fc in self.func_calls.values()] + [ + tasks = [ + loop.create_task(fc.run(self.buffer_id)) for fc in self.func_calls.values() + ] + [ loop.create_task(self.flush_calls()), - loop.create_task(self.load_data()), + loop.create_task(self.load_data(self.buffer_id)), loop.create_task(self.finish_traverse()), ] loop.run_until_complete(asyncio.gather(*tasks)) + self.buffer_id = (self.buffer_id + 1) % len(self.buffers) diff --git a/realhf/system/generation_server.py b/realhf/system/generation_server.py index ad2928e..ad1f889 100644 --- a/realhf/system/generation_server.py +++ b/realhf/system/generation_server.py @@ -1,14 +1,111 @@ import os +import subprocess +import sys import time +from pathlib import Path + +import requests from realhf.api.cli_args import SGLangConfig from realhf.api.core.system_api import GenerationServer as GenerationServerConfig -from realhf.base import gpu_utils, logging, name_resolve, names, network, seeding +from realhf.base import ( + gpu_utils, + logging, + name_resolve, + names, + network, + pkg_version, + seeding, +) +from realhf.base.cluster import spec as cluster_spec from realhf.system.worker_base import PollResult, Worker logger = logging.getLogger(__name__) +def execute_shell_command(command: str) -> subprocess.Popen: + """ + Execute a shell command and return its process handle. + """ + # Replace newline continuations and split the command string. + command = command.replace("\\\n", " ").replace("\\", " ") + parts = command.split() + return subprocess.Popen(parts, text=True, stderr=subprocess.STDOUT) + + +def launch_server_cmd(command: str, port: int = 30000): + """ + Launch the server using the given command. + If no port is specified, a free port is reserved. 
+ """ + p = Path(os.path.dirname(__file__)) + patch_path = str( + p.parent.parent + / "patch" + / "sglang" + / f"v{pkg_version.get_version('sglang')}.patch" + ) + + target_path = "" + sglang_meta = subprocess.check_output( + "python3 -m pip show sglang", shell=True + ).decode("ascii") + for line in sglang_meta.split("\n"): + line = line.strip() + if line.startswith("Editable project location: "): + target_path = str(Path(line.split(": ")[1]).parent) + + if target_path: + proc = subprocess.Popen( + ["git", "apply", patch_path], + cwd=target_path, + stderr=sys.stdout, + stdout=sys.stdout, + ) + proc.wait() + logger.info(f"Applied SGLang patch at {target_path}") + assert port is not None + full_command = f"{command} --port {port}" + process = execute_shell_command(full_command) + return process, port + + +def terminate_process(process, port=None): + """ + Terminate the process and, if a port was reserved, release it. + """ + from sglang.srt.utils import kill_process_tree + + kill_process_tree(process.pid) + + +def wait_for_server(base_url: str, timeout: int = None) -> None: + """Wait for the server to be ready by polling the /v1/models endpoint. + + Args: + base_url: The base URL of the server + timeout: Maximum time to wait in seconds. None means wait forever. + """ + start_time = time.time() + while True: + try: + response = requests.get( + f"{base_url}/v1/models", + headers={"Authorization": "Bearer None"}, + ) + if response.status_code == 200: + time.sleep(5) + break + + if timeout and time.time() - start_time > timeout: + raise TimeoutError("Server did not become ready within timeout period") + except requests.exceptions.RequestException: + time.sleep(1) + + +PORT_CLEARANCE_PERIOD = 90 + + class GenerationServer(Worker): def _configure(self, config: GenerationServerConfig): self.config = config @@ -36,20 +133,37 @@ class GenerationServer(Worker): config = self.config assert config.backend_type == "sglang" + + host_ip = network.gethostip() + host = "localhost" if not config.backend_args.enable_metrics else host_ip + + # NOTE: Ports returned by `find_multiple_free_ports` are unique, + # but SGLang servers still encounter conflicts. + # Use a clearance period to hack over this issue. 
+ servers_per_node = cluster_spec.n_gpus_per_node // self.config.tp_size + idx_on_this_node = self.worker_index % servers_per_node + time.sleep(idx_on_this_node * PORT_CLEARANCE_PERIOD / servers_per_node) + + ports = network.find_multiple_free_ports( + 2, + low=10000, + high=60000, + experiment_name=self.experiment_name, + trial_name=self.trial_name, + ) + server_port = ports[0] + nccl_port = ports[1] + cmd = SGLangConfig.build_cmd( config.backend_args, config.model_path, tp_size=config.tp_size, server_index=self.worker_index, base_gpu_id=self.base_gpu_id, + dist_init_addr=f"{host}:{nccl_port}", ) - from sglang.utils import launch_server_cmd, wait_for_server - host_ip = network.gethostip() - host = "localhost" if not config.backend_args.enable_metrics else host_ip - - # TODO: handle launching error and retry - self.server_process, self.server_port = launch_server_cmd(cmd) + self.server_process, self.server_port = launch_server_cmd(cmd, port=server_port) self.server_addr = f"http://{host}:{self.server_port}" wait_for_server(self.server_addr) @@ -80,6 +194,5 @@ class GenerationServer(Worker): def _exit_hook(self, exit_status): if self.server_process is not None and self.config.backend_type == "sglang": - from sglang.utils import terminate_process terminate_process(self.server_process) diff --git a/realhf/system/gserver_manager.py b/realhf/system/gserver_manager.py index 39d1f07..aaa25f3 100644 --- a/realhf/system/gserver_manager.py +++ b/realhf/system/gserver_manager.py @@ -6,20 +6,46 @@ import shutil import threading import time from collections import defaultdict +from dataclasses import dataclass from typing import List import aiohttp +import numpy as np from realhf.api.core.model_api import GenReqMeta, GenRespMeta, ModelVersionReq from realhf.api.core.system_api import GserverManager as GserverManagerConfig from realhf.base import constants, logging, name_resolve, names, network, recover -from realhf.system.worker_base import AsyncWorker, PollResult, Worker +from realhf.system.worker_base import PollResult, Worker -logger = logging.getLogger("Generation Manager", "colored") +logger = logging.getLogger("Generation Manager", "system") STALENESS_WARNED = defaultdict(lambda: False) +@dataclass +class RolloutStat: + submit: int = 0 + accepted: int = 0 + running: int = 0 + + def inc(self): + self.submit += 1 + self.accepted += 1 + self.running += 1 + + def accept(self): + self.running -= 1 + + def reject(self): + self.running -= 1 + self.accepted -= 1 + + +@dataclass +class AllocateRolloutInput: + qid: str + + class GserverManager(Worker): """This worker has the following functionalities: 1. 
As a router, it schedules generation requests and returns the @@ -37,23 +63,30 @@ class GserverManager(Worker): assert self.config.worker_info.worker_count == 1 - self.async_lock = asyncio.Lock() self.threading_lock = threading.Lock() - self.n_total_rollouts = 0 - self.n_running_rollouts = 0 - self.accepted_rollouts = 0 + self.rollout_stat = RolloutStat() self.schedule_policy = config.schedule_policy self._last_param_realloc_step = 0 + self._qid_to_server_url = {} + + self._server_token_usage = defaultdict(float) + self._server_request_counts = defaultdict(int) + + self._last_thpt_output_time = time.time() + self._gen_tokens = 0 + self.experiment_name = config.worker_info.experiment_name self.trial_name = config.worker_info.trial_name # manager server - self.server = None + self.manager_http_server = None self.thread = None + self.server_urls = [] + # recover info self.__recover_run, self.__recover_info = recover.load_recover_info() if self.__recover_run: @@ -67,10 +100,12 @@ class GserverManager(Worker): name_resolve.add(name, self.__recover_info.last_step_info.global_step) self._loaded_recover_weights = False - self.n_total_rollouts = self.accepted_rollouts = ( + hist_rollouts = ( self.config.train_batch_size * self.__recover_info.last_step_info.global_step ) + self.rollout_stat.submit = hist_rollouts + self.rollout_stat.accepted = hist_rollouts return config.worker_info @@ -84,7 +119,7 @@ class GserverManager(Worker): if cnt >= timeout: raise TimeoutError("Waiting generation servers timeout.") urls = name_resolve.get_subtree(name) - assert len(set(urls)) == len(urls), urls + assert len(set(urls)) == len(urls), (len(urls), len(set(urls)), urls) return urls def _get_recover_ckpt_path(self, role: str): @@ -146,49 +181,34 @@ class GserverManager(Worker): async def flush_requests_and_update_weights( self, server_url, new_param_path, update_weights_retries=5 ): - # HACK: urls are designed for SGLang server_index = self.server_urls.index(server_url) - async with aiohttp.ClientSession(server_url) as session: - running_requests = None - tik = time.perf_counter() - while running_requests is None or running_requests > 0: - if time.perf_counter() - tik > self.config.flush_request_timeout: - raise RuntimeError( - f"Waiting for flush requests failed. {running_requests} requests " - f"remain after {self.config.flush_request_timeout} secs waiting. " - f"Please try to reduce `new_tokens_per_chunk`." - ) - if running_requests is not None and running_requests > 0: - logger.info( - f"Waiting for {running_requests} requests on gen server {server_index}... 
" - f"Time taken so far: {time.perf_counter() - tik:.4f}s" - ) - await asyncio.sleep(0.5) - async with session.get(f"/metrics") as resp: - resp.raise_for_status() - text = await resp.text() - for line in text.split("\n"): - if line.startswith("sglang:num_running_reqs"): - running_requests = float(line.split(" ")[1]) - break - - success = False - for _ in range(update_weights_retries): + success = False + for _ in range(update_weights_retries): + async with aiohttp.ClientSession( + server_url, + timeout=aiohttp.ClientTimeout( + total=self.config.flush_request_timeout, sock_connect=30 + ), + ) as session: async with session.post( f"/update_weights_from_disk", - json=dict(model_path=new_param_path), + json=dict(model_path=new_param_path, allow_interrupt=True), ) as resp: if resp.status == 200: res = await resp.json() success = res["success"] if success: + logger.info( + f"{res['num_paused_requests']} requests are interrupted " + f"during updateing weights for server {server_index}: {server_url}" + ) return logger.warning( f"Update weights failed: {res['message']}. Retrying." ) logger.warning(f"Update weights failed: {resp.reason}. Retrying.") time.sleep(0.1) - raise RuntimeError("Update weights failed.") + raise RuntimeError("Update weights failed.") def _round_robin_schedule(self, req_meta: GenReqMeta) -> int: if not hasattr(self, "round_robin_idx"): @@ -198,6 +218,16 @@ class GserverManager(Worker): self.round_robin_idx %= self.config.n_servers return r + def _least_requests_schedule(self, req_meta: GenReqMeta) -> int: + counts = [ + self._server_request_counts[server_url] for server_url in self.server_urls + ] + return int(np.argmin(counts)) + + def _least_token_usage_schedule(self, req_meta: GenReqMeta) -> int: + url = min(self.server_urls, key=lambda k: self._server_token_usage[k]) + return self.server_urls.index(url) + def _poll(self): if not self.thread: # Find addresses of generation servers @@ -228,6 +258,23 @@ class GserverManager(Worker): loop.run_until_complete(asyncio.gather(*tasks)) logger.info(f"Generaion server updated weights from: {new_param_path}") + tasks = [ + self._get_server_token_usage(server_url) for server_url in self.server_urls + ] + loop = asyncio.get_event_loop() + token_usages = loop.run_until_complete(asyncio.gather(*tasks)) + with self.threading_lock: + for server_url, token_usage in zip(self.server_urls, token_usages): + self._server_token_usage[server_url] = token_usage + + if time.time() - self._last_thpt_output_time > 30: + interval = time.time() - self._last_thpt_output_time + logger.info( + f"Generation throughput: {self._gen_tokens / interval:.2f} tokens/s" + ) + self._last_thpt_output_time = time.time() + self._gen_tokens = 0 + # clear old weights realloc_root = os.path.join( constants.PARAM_REALLOC_PATH, @@ -237,9 +284,11 @@ class GserverManager(Worker): ) if os.path.exists(realloc_root): for realloc_version in os.listdir(realloc_root): + # Lock-free is safe here. + # Remain one checkpoint for recover. if ( os.path.isdir(os.path.join(realloc_root, realloc_version)) - and int(realloc_version) < self._last_param_realloc_step + and int(realloc_version) < self._last_param_realloc_step - 1 ): shutil.rmtree(os.path.join(realloc_root, realloc_version)) logger.info( @@ -247,29 +296,56 @@ class GserverManager(Worker): f"checkpoint: {os.path.join(realloc_root, realloc_version)}" ) - # TODO: we may want to update server status - # in the main thread. 
- - time.sleep(1) + time.sleep(5) return PollResult(0, 0) - async def is_staled(self): - global_sample_cnt = self.n_total_rollouts - expected_version = global_sample_cnt // self.config.train_batch_size - staled = ( - expected_version - > self.config.max_head_offpolicyness + self._last_param_realloc_step + async def _get_server_token_usage(self, server_url): + async with aiohttp.ClientSession( + server_url, + timeout=aiohttp.ClientTimeout( + total=self.config.flush_request_timeout, sock_connect=30 + ), + ) as session: + async with session.get("/metrics") as resp: + resp.raise_for_status() + text = await resp.text() + for l in text.split("\n"): + if l.startswith("sglang:num_used_tokens"): + return float(l.split(" ")[1]) + raise RuntimeError(f"Failed to get token usage metrics from {server_url}") + + async def _get_server_num_running_requests(self, server_url): + async with aiohttp.ClientSession( + server_url, + timeout=aiohttp.ClientTimeout( + total=self.config.flush_request_timeout, sock_connect=30 + ), + ) as session: + async with session.get(f"/metrics") as resp: + resp.raise_for_status() + text = await resp.text() + for line in text.split("\n"): + if line.startswith("sglang:num_running_reqs"): + return float(line.split(" ")[1]) + raise RuntimeError( + f"Failed to get num running requests metrics from {server_url}" ) + + def is_staled(self): + global_sample_cnt = self.rollout_stat.accepted + expected_version = global_sample_cnt // self.config.train_batch_size + version = self._last_param_realloc_step + staled = expected_version > self.config.max_head_offpolicyness + version global STALENESS_WARNED - if staled and not STALENESS_WARNED[self._last_param_realloc_step]: + if staled and not STALENESS_WARNED[version]: logger.warning( f"expected version ({expected_version}) = " f"global sample cnt ({global_sample_cnt}) // batch size ({self.config.train_batch_size}), " - f"current version {self._last_param_realloc_step}, " + f"current latest version {version}, " f"offpolicyness {self.config.max_head_offpolicyness}. Staled? 
{staled}" ) - STALENESS_WARNED[self._last_param_realloc_step] = True + STALENESS_WARNED[version] = True return staled def _run_routing_service(self): @@ -282,60 +358,94 @@ class GserverManager(Worker): @self.app.post("/schedule_request") async def schedule_request(req_meta: GenReqMeta): with self.threading_lock: - async with self.async_lock: - version = self._last_param_realloc_step - # FIXME: We only implement a round-robin scheduler that - # ignores server status and request metadata + if ( + req_meta.previous_server_url + and req_meta.previous_version == self._last_param_realloc_step + ): + return dict( + url=req_meta.previous_server_url, + version=req_meta.previous_version, + ) + + if self.schedule_policy == "round_robin": server_idx = self._round_robin_schedule(req_meta) - return dict(url=self.server_urls[server_idx], version=max(0, version)) + elif self.schedule_policy == "least_token_usage": + server_idx = self._least_token_usage_schedule(req_meta) + elif self.schedule_policy == "least_requests": + server_idx = self._least_requests_schedule(req_meta) + else: + raise NotImplementedError( + f"Unknown schedule policy {self.schedule_policy}" + ) + + server_url = self.server_urls[server_idx] + # qid prompt (n samples) use the same dst server + self._qid_to_server_url[req_meta.qid] = server_url + self._server_request_counts[server_url] += 1 + self._server_token_usage[server_url] += ( + req_meta.prompt_len + + req_meta.new_token_budget * req_meta.group_size * 0.4 + ) + + version = self._last_param_realloc_step + return dict(url=server_url, version=version) @self.app.post("/get_model_version") async def get_model_version(req: ModelVersionReq): with self.threading_lock: - async with self.async_lock: - # FIXME: we may have different versions for different servers - version = self._last_param_realloc_step + # FIXME: we may have different versions for different servers + version = self._last_param_realloc_step return dict(version=version) - @self.app.get("/allocate_rollout") - async def allocate_rollout(): + @self.app.post("/allocate_rollout") + async def allocate_rollout(req: AllocateRolloutInput): with self.threading_lock: - async with self.async_lock: - has_capacity = ( - self.n_running_rollouts < self.config.max_concurrent_rollouts - ) - is_staled = await self.is_staled() - reason = "" - if has_capacity and not is_staled: - self.n_running_rollouts += 1 - self.n_total_rollouts += 1 - return dict(success=True, reason=reason) - else: - if not has_capacity: - reason += f"capacity: {self.n_running_rollouts} >= {self.config.max_concurrent_rollouts}" - if is_staled: - global_sample_cnt = self.n_total_rollouts - expected_version = ( - global_sample_cnt // self.config.train_batch_size - ) - reason += ( - f" and staled: expected version ({expected_version}) = " - f"global sample cnt ({global_sample_cnt}) // batch size ({self.config.train_batch_size}), " - f"current version {self._last_param_realloc_step}, " - f"offpolicyness {self.config.max_head_offpolicyness}." 
- ) - return dict(success=False, reason=reason) + has_capacity = ( + self.rollout_stat.running < self.config.max_concurrent_rollouts + ) + is_staled = self.is_staled() + reason = "" + if has_capacity and not is_staled: + self.rollout_stat.inc() + return dict(success=True, reason=reason) + else: + if not has_capacity: + reason += f"capacity: {self.rollout_stat.running} >= {self.config.max_concurrent_rollouts}" + if is_staled: + global_sample_cnt = self.rollout_stat.accepted + expected_version = ( + global_sample_cnt // self.config.train_batch_size + ) + version = self._last_param_realloc_step + reason += ( + f" and staled: expected version ({expected_version}) = " + f"global sample cnt ({global_sample_cnt}) // batch size ({self.config.train_batch_size}), " + f"current latest version {version}, " + f"offpolicyness {self.config.max_head_offpolicyness}." + ) + return dict(success=False, reason=reason) @self.app.post("/finish_rollout") async def finish_rollout(resp_meta: GenRespMeta): with self.threading_lock: - async with self.async_lock: - self.n_running_rollouts -= 1 - if resp_meta.accepted: - self.accepted_rollouts += 1 - return dict(success=True) + server_url = self._qid_to_server_url[resp_meta.qid] + self._server_request_counts[server_url] -= 1 + assert ( + self._server_request_counts[server_url] >= 0 + ), "server request count < 0" + self._qid_to_server_url.pop(resp_meta.qid) + self._gen_tokens += resp_meta.n_tokens + if resp_meta.accepted: + self.rollout_stat.accept() + else: + self.rollout_stat.reject() + return dict(success=True) - self.manager_addr = f"{network.gethostip()}:{network.find_free_port()}" + port = network.find_free_port( + experiment_name=self.experiment_name, + trial_name=self.trial_name, + ) + self.manager_addr = f"{network.gethostip()}:{port}" config = uvicorn.Config( self.app, @@ -343,12 +453,12 @@ class GserverManager(Worker): port=int(self.manager_addr.split(":")[1]), log_level="warning", ) - self.server = uvicorn.Server(config) - self.server.run() + self.manager_http_server = uvicorn.Server(config) + self.manager_http_server.run() def _exit_hook(self, exit_status): - if self.server: - self.server.should_exit = True + if self.manager_http_server: + self.manager_http_server.should_exit = True if self.thread: self.thread.join(timeout=3) logger.info("Server stopped") diff --git a/realhf/system/master_worker.py b/realhf/system/master_worker.py index 71383e2..7d27bde 100644 --- a/realhf/system/master_worker.py +++ b/realhf/system/master_worker.py @@ -145,6 +145,7 @@ class MasterWorker(worker_base.Worker): # for benchmark self.e2e_time_history = [] self.__benchmark_steps = config.exp_ctrl.benchmark_steps + self.__benchmark_n_seqs = config.exp_ctrl.benchmark_n_seqs return config.worker_info @@ -210,13 +211,14 @@ class MasterWorker(worker_base.Worker): src_rpc_dp_size = src_rpc_topo.get_dim("data") # Request training specification from data workers. 
- self._dataset_size = sum( - self.__stream.call( - handlers=[f"__data{i}__" for i in range(src_rpc_dp_size)], - datas=[None for i in range(src_rpc_dp_size)], - handle_type="spec", - ), + specs = self.__stream.call( + handlers=[f"__data{i}__" for i in range(src_rpc_dp_size)], + datas=[None for i in range(src_rpc_dp_size)], + handle_type="spec", ) + assert all(x["n_datasets"] == specs[0]["n_datasets"] for x in specs), specs + self._dataset_size = sum(x["dataset_size"] for x in specs) + self._n_datasets = specs[0]["n_datasets"] self._steps_per_epoch = self._dataset_size // src_rpc.n_seqs @@ -239,7 +241,7 @@ class MasterWorker(worker_base.Worker): src_rpc_dp_size = src_rpc_topo.get_dim("data") src_rpc_pp_size = src_rpc_topo.get_dim("pipe") for i in range(src_rpc_dp_size): - rank = src_rpc_topo.get_rank(data=i, pipe=src_rpc_pp_size - 1, model=0) + rank = src_rpc_topo.get_rank(data=i, pipe=src_rpc_pp_size - 1, tensor=0) handler_routing[f"__data{i}__"] = self.config.msid2mwid[ config_pkg.ModelShardID.from_parallelism_rank( model_name=src_rpc.model_name, @@ -263,10 +265,13 @@ class MasterWorker(worker_base.Worker): self.initialize_models() - self.__seqbuffer = AsyncIOSequenceBuffer( - self.__model_rpcs, - max_size=int(os.getenv("REAL_MASTER_BUFFER_SIZE", str(int(1e7)))), - ) + self.__seqbuffers = [ + AsyncIOSequenceBuffer( + self.__model_rpcs, + max_size=int(os.getenv("REAL_MASTER_BUFFER_SIZE", str(int(1e7)))), + ) + for _ in range(self._n_datasets) + ] # wandb init, connect to remote wandb host wandb.login() @@ -300,7 +305,7 @@ class MasterWorker(worker_base.Worker): rpcs=self.__model_rpcs, msid2mwid=self.config.msid2mwid, stream=self.__stream, - buffer=self.__seqbuffer, + buffers=self.__seqbuffers, model_topos=self.__model_topos, model_configs=self.__model_configs, ctrl=self.__rpc_ctrl, @@ -395,20 +400,33 @@ class MasterWorker(worker_base.Worker): # Pause the worker if experiment or system-wise benchmark completes. if ( - self.__benchmark_steps is not None - and self.__rpc_ctrl.step_info.global_step >= self.__benchmark_steps - ) or ( - self.__rpc_ctrl.step_info.global_step * self.__src_rpc.n_seqs - >= self.__total_train_epochs * self._dataset_size + ( + self.__benchmark_steps is not None + and self.__rpc_ctrl.step_info.global_step >= self.__benchmark_steps + ) + or ( + self.__rpc_ctrl.step_info.global_step * self.__src_rpc.n_seqs + >= self.__total_train_epochs * self._dataset_size + ) + or ( + self.__benchmark_n_seqs is not None + and self.__rpc_ctrl.step_info.global_step + * self._ft_spec.train_batch_size + >= self.__benchmark_n_seqs + ) ): # We don't know whether it is the last step of the current epoch, # so we exit at the first step of the next epoch. - if self.__benchmark_steps is not None: + if ( + self.__benchmark_steps is not None + or self.__benchmark_n_seqs is not None + ): logger.info( f"Finished benchmark {self.__benchmark_steps}. " f"Time consumption of this setup: {time_since_configure:.3f}" ) logger.info(f"avg #e2e# time *{np.mean(self.e2e_time_history):.3f}*") + # TODO: inform generation workers to exit return self.experiment_complete_exit() return worker_base.PollResult(sample_count=1, batch_count=1) @@ -439,9 +457,7 @@ class MasterWorker(worker_base.Worker): s += f"(global step {global_step}) finishes. " s += f"#End to end# execution time: *{e2e_time:.3f}*s. " s += f"Total time consumption: {time_since_configure:.3f}s. 
" - logging.log_wandb_tensorboard( - {"timeperf/e2e": e2e_time}, step=self.__rpc_ctrl.step_info.global_step - ) + logging.log_wandb_tensorboard({"timeperf/e2e": e2e_time}) if len(self.e2e_time_history) > 2: remaining_steps = self._steps_per_epoch - epoch_step remaining_epochs = self.__total_train_epochs - epoch diff --git a/realhf/system/model_function_call.py b/realhf/system/model_function_call.py index 9391fb7..35181a5 100644 --- a/realhf/system/model_function_call.py +++ b/realhf/system/model_function_call.py @@ -63,7 +63,7 @@ class ModelFunctionCall: model_topos: Dict[str, topology.ProcessTopology], model_configs: Dict[str, None | ReaLModelConfig], ctrl: RPCCorountineControl, - buffer: AsyncIOSequenceBuffer, + buffers: List[AsyncIOSequenceBuffer], redistrib_planner: RedistribPlanner, summary_writer: SummaryWriter | None, ): @@ -89,7 +89,7 @@ class ModelFunctionCall: ) self.rpc_ctrl = ctrl - self.buffer = buffer + self.buffers = buffers self.redistrib_planner = redistrib_planner self.summary_writer = summary_writer @@ -306,7 +306,7 @@ class ModelFunctionCall: ).partitions return buf_indices, sample, partitions - async def run_step(self, buf_indices, sample): + async def run_step(self, buf_indices, sample, buffer_id: int): rpc = self.rpc topo = self.model_topos[rpc.model_name] ctrl = self.rpc_ctrl @@ -317,7 +317,7 @@ class ModelFunctionCall: ] dp_head_indices = [ - topo.get_rank(data=i, pipe=topo.get_dim("pipe") - 1, model=0) + topo.get_rank(data=i, pipe=topo.get_dim("pipe") - 1, tensor=0) for i in range(self.dp_size) ] @@ -348,12 +348,7 @@ class ModelFunctionCall: if i not in dests: dests[i] = [] - # NOTE: The data loaded from the dataset may be unevenly distributed across DP ranks. - # Only bcast works in this case. - if rpc.is_src: - pattern = "bcast" - else: - pattern = "gather-scatter" + pattern = "gather-scatter" data_transfer_plan = self.redistrib_planner.derive_plan( dests, keys=rpc.input_keys, @@ -362,14 +357,14 @@ class ModelFunctionCall: blogger.info(f"Data tranfer plan for `{rpc.name}`: {data_transfer_plan}.") # Update storage tracker for transferred data. - if rpc.is_src: + if pattern == "bcast": # NOTE: since the data we loaded may be unevenly distributed across DP ranks, # we should change the owner of the data to the src RPC. for i in range(topo.world_size()): h = ModelShardID.from_parallelism_rank( model_name=rpc.model_name, topo=topo, parallelism_rank=i ) - is_dp_head = h.mp_rank == 0 and h.pp_rank == topo.get_dim("pipe") - 1 + is_dp_head = h.tp_rank == 0 and h.pp_rank == topo.get_dim("pipe") - 1 gpu_id = self.msid2mwid[h] for key in rpc.input_keys: await self.redistrib_planner.storage_tracker.add_data( @@ -414,13 +409,13 @@ class ModelFunctionCall: responses, time_records = list(zip(*[responses[i] for i in dp_head_indices])) # If the returned data is a SequenceSample, it is the data returned by - # model function calls. The data shoulbe be amended into buffer. + # model function calls. The data should be amended into buffer. # Otherwise, it's the train statistics and should be reduced and logged. if isinstance(responses[-1], data_api.SequenceSample): # Update storage tracker for generated data. 
for dp_rank, x in enumerate(responses): pp_size = topo.get_dim("pipe") - ranks = topo.filter_match(data=dp_rank, pipe=pp_size - 1, model=0) + ranks = topo.filter_match(data=dp_rank, pipe=pp_size - 1, tensor=0) for rank in ranks: h = config_pkg.ModelShardID.from_parallelism_rank( model_name=rpc.model_name, topo=topo, parallelism_rank=rank @@ -434,8 +429,14 @@ class ModelFunctionCall: is_owner=True, ) res = data_api.SequenceSample.gather(responses) - else: + elif isinstance(responses[0], dict): res = data_api.gather_stat(responses) + else: + assert isinstance(responses[0], list) + res = [ + data_api.gather_stat([r[i] for r in responses]) + for i in range(len(responses[0])) + ] if rpc.log_return_value: if isinstance(res, dict): @@ -447,6 +448,17 @@ class ModelFunctionCall: step=ctrl.step_info.global_step, summary_writer=self.summary_writer, ) + elif isinstance(res, list): + for j, r in enumerate(res): + logger.info( + f"RPC name {rpc.name} returns ({j}/{len(res)})\n{data_api.tabulate_stats(r)}" + ) + offset = len(res) * ctrl.step_info.global_step + logging.log_wandb_tensorboard( + r, + step=offset + j, + summary_writer=self.summary_writer, + ) else: logger.info(f"RPC name {rpc.name} returns\n{res}") @@ -456,7 +468,6 @@ class ModelFunctionCall: time_stats = stats_tracker.export() logging.log_wandb_tensorboard( time_stats, - step=ctrl.step_info.global_step, summary_writer=self.summary_writer, ) @@ -475,7 +486,7 @@ class ModelFunctionCall: await ctrl.train_count.put(1) else: logger.info(f"Amending RPC {rpc.name} output keys: {res.keys}") - await self.buffer.amend_batch(buf_indices, res.unpack()) + await self.buffers[buffer_id].amend_batch(buf_indices, res.unpack()) # Wait for all side-effect requests to finish. # Side-effect or empty requests are required for data transfer @@ -483,20 +494,20 @@ class ModelFunctionCall: # Wait them after the main request to log the oorrect MFC time. await self.stream.gather_async(other_req_ids) - async def run(self): + async def run(self, buffer_id: int): rpc = self.rpc topo = self.model_topos[rpc.model_name] logger.info( f"Running Model RPC, interface_type=#{rpc.interface_type}# " - f"(dp,mp,pp) = *({topo.get_dim('data')},{topo.get_dim('model')},{topo.get_dim('pipe')})*" + f"(dp,tp,pp) = *({topo.get_dim('data')},{topo.get_dim('tensor')},{topo.get_dim('pipe')})*" ) consumed = 0 while True: - buf_indices, sample = await self.buffer.get_batch_for_rpc(rpc) + buf_indices, sample = await self.buffers[buffer_id].get_batch_for_rpc(rpc) - await self.run_step(buf_indices, sample) + await self.run_step(buf_indices, sample, buffer_id) consumed += sample.bs # Ensure that parent RPCs will not be over-consumed. 
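Taken together, the buffer changes above thread a `buffer_id` through the whole execution path: the master builds one `AsyncIOSequenceBuffer` per dataset, `FunctionExecutor` rotates the id after each graph traversal, and every MFC loads and amends batches only on `self.buffers[buffer_id]`. The model-worker diff below mirrors this with one dataloader per dataset, selected by the `dataset_id` carried in each fetch request. A condensed sketch of the rotation, with simplified control flow and hypothetical names (`run_one_traversal`, `n_datasets`, `total_steps`):

    # One buffer per dataset; each traversal of the dataflow graph drains one of them.
    buffers = [AsyncIOSequenceBuffer(rpcs, max_size=int(1e7)) for _ in range(n_datasets)]
    buffer_id = 0
    for step in range(total_steps):
        run_one_traversal(buffers[buffer_id])       # load_data(buffer_id) + fc.run(buffer_id)
        buffer_id = (buffer_id + 1) % len(buffers)  # the next step consumes the next dataset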
diff --git a/realhf/system/model_worker.py b/realhf/system/model_worker.py index 5c909b4..d93ef36 100644 --- a/realhf/system/model_worker.py +++ b/realhf/system/model_worker.py @@ -153,7 +153,7 @@ class ModelWorker(worker_base.Worker): ] for s in self.config.shards: _pp_size = s.id.topo.get_dim("pipe") - if not (s.id.mp_rank == 0 and s.id.pp_rank == _pp_size - 1): + if not (s.id.tp_rank == 0 and s.id.pp_rank == _pp_size - 1): continue if src_rpc.model_name == s.id.model_name: self.__has_dataset = True @@ -195,8 +195,8 @@ class ModelWorker(worker_base.Worker): return None @property - def _mp_rank(self) -> int: - return constants.model_parallel_rank() + def _tp_rank(self) -> int: + return constants.tensor_parallel_rank() @property def _pp_rank(self) -> int: @@ -211,8 +211,8 @@ class ModelWorker(worker_base.Worker): return constants.pipe_parallel_world_size() @property - def _mp_size(self) -> int: - return constants.model_parallel_world_size() + def _tp_size(self) -> int: + return constants.tensor_parallel_world_size() @property def _dp_size(self) -> int: @@ -220,7 +220,7 @@ class ModelWorker(worker_base.Worker): @property def _is_dp_head(self) -> bool: - return self._mp_rank == 0 and self._pp_rank == self._pp_size - 1 + return self._tp_rank == 0 and self._pp_rank == self._pp_size - 1 @property def _model(self) -> model_api.Model: @@ -302,6 +302,7 @@ class ModelWorker(worker_base.Worker): constants.set_grid(model_name_, grid) # Set up training dataset for source RPCs. + self.__datasets = [] if self.__has_dataset: datasets = [ data_api.make_dataset( @@ -321,31 +322,34 @@ class ModelWorker(worker_base.Worker): ) for d in self.config.datasets ] - if len(self.config.datasets) == 1: - self.__dataset = datasets[0] - else: - self.__dataset = torch.utils.data.ConcatDataset(datasets) + self.__datasets = datasets - g = torch.Generator() - g.manual_seed(seeding.get_seed()) - dataloader_kwargs = dict( - shuffle=self.config.shuffle_dataset, - generator=g, - ) - if not isinstance(self.__dataset, PullerStreamDataset): - dataloader_kwargs["collate_fn"] = data_api.SequenceSample.gather - # NOTE: This is *NOT* the actual batch size for training. - # It is just a proper size to load data to workers. - dataloader_kwargs["batch_size"] = 10240 - else: - dataloader_kwargs["batch_size"] = None - self.__dataloader = torch.utils.data.DataLoader( - self.__dataset, **dataloader_kwargs - ) + self.__dataloaders: List[ + torch.utils.data.DataLoader[data_api.SequenceSample] + ] = [] + for i, d in enumerate(self.__datasets): + g = torch.Generator() + g.manual_seed( + self.config.base_seed + seeding._seed_from_key(f"__dataloader{i}__") + ) + dataloader_kwargs = dict( + shuffle=self.config.shuffle_dataset, + generator=g, + ) + if not isinstance(d, PullerStreamDataset): + dataloader_kwargs["collate_fn"] = data_api.SequenceSample.gather + # NOTE: This is *NOT* the actual batch size for training. + # It is just a proper size to load data to workers. 
+ dataloader_kwargs["batch_size"] = 10240 + else: + dataloader_kwargs["batch_size"] = None + self.__dataloaders.append( + torch.utils.data.DataLoader(d, **dataloader_kwargs) + ) - self.dataset_size = len(self.__dataset) + self.dataset_size = sum(len(d) for d in self.__datasets) - self.__data_generator = enumerate(self.__dataloader) + self.__data_generators = [enumerate(d) for d in self.__dataloaders] self.__models: Dict[ModelName, model_api.Model] = dict() self.__model_is_handle: Dict[ModelName, bool] = dict() @@ -377,25 +381,26 @@ class ModelWorker(worker_base.Worker): ) # Recover indices for dynamic dataset - if ( - s.id.model_name == self.src_rpc.model_name - and self.__has_dataset - and hasattr(self.__dataset, "filter") - ): - dataset_indices_path = os.path.join( - constants.MODEL_SAVE_ROOT, - constants.experiment_name(), - constants.trial_name(), - "dataset_indices", - f"{self._dp_rank}.npy", - ) - if os.path.exists(dataset_indices_path): - indices = np.load(dataset_indices_path).tolist() - logger.info( - f"DP rank {self._dp_rank} updating dataset indices upon recover, " - f"size {len(self.__dataset.active_indices)} -> {len(indices)}" + for i, d in enumerate(self.__datasets): + if ( + s.id.model_name == self.src_rpc.model_name + and self.__has_dataset + and hasattr(d, "filter") + ): + dataset_indices_path = os.path.join( + constants.MODEL_SAVE_ROOT, + constants.experiment_name(), + constants.trial_name(), + "dataset_indices", + f"{self._dp_rank}_{i}.npy", ) - self.__dataset.active_indices = indices + if os.path.exists(dataset_indices_path): + indices = np.load(dataset_indices_path).tolist() + logger.info( + f"DP rank {self._dp_rank} updating dataset indices upon recover, " + f"size {len(d.active_indices)} -> {len(indices)}" + ) + d.active_indices = indices if constants.parallelism_rank() == 0: self.logger.info( @@ -537,9 +542,13 @@ class ModelWorker(worker_base.Worker): cache = [] while True: try: - request, data, handled, res, time_record = ( - self.__request_queue.get_nowait() - ) + ( + request, + data, + handled, + res, + time_record, + ) = self.__request_queue.get_nowait() request: request_reply_stream.Payload if not handled: while len(request.pre_hooks) > 0: @@ -582,9 +591,13 @@ class ModelWorker(worker_base.Worker): elif request.handle_name == "fetch": dp_rank = int(re.search(r"__data(\d+)__", request.handler).group(1)) assert self.__has_dataset + assert isinstance(request.data, int), request.data + dataset_id = request.data # Fetch. try: - self.__dataset_batch_counter, cur_sample = next(self.__data_generator) + self.__dataset_batch_counter, cur_sample = next( + self.__data_generators[dataset_id] + ) except StopIteration: # Upon the first fetch request, filter dataset and create dataloader. eval_scores_path = os.path.join( @@ -598,39 +611,43 @@ class ModelWorker(worker_base.Worker): constants.experiment_name(), constants.trial_name(), "dataset_indices", - f"{dp_rank}.npy", + f"{dp_rank}_{dataset_id}.npy", ) os.makedirs(os.path.dirname(dataset_indices_path), exist_ok=True) - if hasattr(self.__dataset, "filter") and os.path.exists( + if hasattr(self.__datasets[dataset_id], "filter") and os.path.exists( eval_scores_path ): # Don't filter dataset on the first poll after recover. 
with open(eval_scores_path, "r", encoding="utf-8") as f: dataset_eval_scores = json.load(f) - self.__dataset.filter(dataset_eval_scores) + self.__datasets[dataset_id].filter(dataset_eval_scores) # Save the dataset indices after filtering np.save( dataset_indices_path, - self.__dataset.active_indices, + self.__datasets[dataset_id].active_indices, ) g = torch.Generator() - g = g.set_state(self.__dataloader.generator.get_state()) + g = g.set_state(self.__dataloaders[dataset_id].generator.get_state()) dataloader_kwargs = dict( shuffle=self.config.shuffle_dataset, generator=g, ) - if not isinstance(self.__dataset, PullerStreamDataset): + if not isinstance(self.__datasets[dataset_id], PullerStreamDataset): dataloader_kwargs["collate_fn"] = data_api.SequenceSample.gather # NOTE: This is *NOT* the actual batch size for training. # It is just a proper size to load data to workers. dataloader_kwargs["batch_size"] = 10240 else: dataloader_kwargs["batch_size"] = None - self.__dataloader = torch.utils.data.DataLoader( - self.__dataset, **dataloader_kwargs + self.__dataloaders[dataset_id] = torch.utils.data.DataLoader( + self.__datasets[dataset_id], **dataloader_kwargs + ) + self.__data_generators[dataset_id] = enumerate( + self.__dataloaders[dataset_id] + ) + self.__dataset_batch_counter, cur_sample = next( + self.__data_generators[dataset_id] ) - self.__data_generator = enumerate(self.__dataloader) - self.__dataset_batch_counter, cur_sample = next(self.__data_generator) if isinstance(cur_sample, data_api.SequenceSample): samples = cur_sample.unpack() @@ -663,7 +680,10 @@ class ModelWorker(worker_base.Worker): ) elif request.handle_name == "spec": # Raw dataset without filtering. - res = self.dataset_size + res = { + "n_datasets": len(self.__datasets), + "dataset_size": self.dataset_size, + } elif request.handle_name == "clear_data_cache": with cuda_tmarked("clear_data_cache", CUDATimeMarkType.misc): ids = request.data @@ -772,8 +792,10 @@ class ModelWorker(worker_base.Worker): if hook == "evaluate": assert request.handle_name == "train_step", request.handle_name assert isinstance(ret, dict), ret - assert isinstance(res, dict), res - res.update(ret) + if isinstance(res, dict): + res.update(ret) + else: + res[0].update(ret) time_record[ f"timeperf/{request.handler.model_name.role}_{request.handle_name}/post-{hook}" ] += (time.perf_counter() - tik) @@ -803,13 +825,7 @@ class ModelWorker(worker_base.Worker): with constants.model_scope(model_name): dist.barrier(group=constants.cpu_parallelism_group()) if constants.parallelism_rank() == 0: - name_resolve.add( - name, - str(global_step), - delete_on_exit=False, - keepalive_ttl=30, - replace=True, - ) + name_resolve.add(name, str(global_step), replace=True) time_record[ f"timeperf/{request.handler.model_name.role}_{request.handle_name}/param-sync-save" ] += (time.perf_counter() - tik) @@ -867,7 +883,7 @@ class ModelWorker(worker_base.Worker): if len(self.__performance_recorder) == 0: self.__performance_recorder["info"] = { "pipeline_size": self._pp_size, - "model_size": self._mp_size, + "model_size": self._tp_size, "data_size": self._dp_size, "rank": constants.parallelism_rank(), "sequence_parallel_enabled": constants.sequence_parallel(), @@ -1374,9 +1390,13 @@ class ModelWorker(worker_base.Worker): rescheduled_requests = [] other_requests = [] for _ in range(self.__request_queue.qsize()): - request, data, handled, res, time_record = ( - self.__request_queue.get_nowait() - ) + ( + request, + data, + handled, + res, + time_record, + ) = 
self.__request_queue.get_nowait() if request.handle_name not in ["inference", "generate", "train_step"]: other_requests.append((request, data, handled, res, time_record)) else: @@ -1399,9 +1419,13 @@ class ModelWorker(worker_base.Worker): # we can correctly log the time consumption in the master worker. while True: try: - request, data, handled, res, time_record = ( - self.__request_queue.get_nowait() - ) + ( + request, + data, + handled, + res, + time_record, + ) = self.__request_queue.get_nowait() self.handle_blocking_request( request, data, handled, res, time_record ) diff --git a/realhf/system/partial_rollout.py b/realhf/system/partial_rollout.py index 59b4d13..9483075 100644 --- a/realhf/system/partial_rollout.py +++ b/realhf/system/partial_rollout.py @@ -103,9 +103,14 @@ class PartialRolloutManager: ): from realhf.impl.model.backend.sglang import SGLangAPIClient + max_new_tokens = min(raw_gconfig.max_new_tokens, self.new_tokens_per_chunk) + max_new_tokens = min( + max_new_tokens, + raw_gconfig.max_new_tokens - len(input_ids) + len(prompt_ids), + ) gconfig = raw_gconfig.new( n=1, - max_new_tokens=min(raw_gconfig.max_new_tokens, self.new_tokens_per_chunk), + max_new_tokens=max_new_tokens, ) assert self.tokenizer.pad_token_id is not None assert self.tokenizer.eos_token_id is not None @@ -130,6 +135,7 @@ class PartialRolloutManager: group_idx=group_idx, raw_gconfig=raw_gconfig, server_url=url, + version=cur_server_version, ), ), stream=False, @@ -190,6 +196,7 @@ class PartialRolloutManager: s: APIGenerateOutput = await task group_idx = s.metadata["group_idx"] raw_gconfig = s.metadata["raw_gconfig"] + previous_version = s.metadata["version"] assert s.group_size == 1 no_eos = s.no_eos[0] @@ -202,20 +209,27 @@ class PartialRolloutManager: if no_eos and gen_len < raw_gconfig.max_new_tokens: # Unfinished request due to chunked generation. # Send it back to continue. 
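The tightened `max_new_tokens` above keeps a chunked request from overshooting the overall generation budget: `len(input_ids) - len(prompt_ids)` is the number of tokens already produced by earlier chunks, so the second `min` caps the chunk at whatever budget remains. A quick numeric check with assumed values (total budget 1024, chunk size 256, prompt of 100 tokens, 900 tokens already generated, so input_ids has 1000 tokens):

    max_new_tokens = min(1024, 256)                          # 256, the usual per-chunk cap
    max_new_tokens = min(max_new_tokens, 1024 - 1000 + 100)  # only 124 tokens of budget remain
    # The final chunk therefore asks the server for 124 new tokens instead of 256.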
- async with aiohttp.ClientSession() as session: - async with session.post( - f"http://{self.gserver_manager_addr}/get_model_version", - json=dict(server_url=s.metadata["server_url"]), - timeout=ClientTimeout(total=self.timeout, sock_connect=30), - ) as resp: - resp.raise_for_status() - cur_version = (await resp.json())["version"] + req_meta = GenReqMeta( + qid=s.qid, + prompt_len=s.prompt_len, + group_size=raw_gconfig.n, + new_token_budget=raw_gconfig.max_new_tokens, + predicted_new_tokens=None, + previous_server_url=s.metadata["server_url"], + previous_version=previous_version, + ) + info = await self._schedule_request(req_meta) + cur_version = info["version"] + server_url = info["url"] + if len(s.output_logprobs) > 0: prev_logprobs = s.prev_logprobs + s.output_logprobs[0] else: - prev_logprobs = [] + prev_logprobs = s.prev_logprobs + if prev_logprobs is None: + prev_logprobs = [] await self._issue_generation( - s.metadata["server_url"], + server_url, s.qid, group_idx, s.prompt_ids, @@ -240,9 +254,10 @@ class PartialRolloutManager: try: qid, prompt_token_ids, gconfig = self.request_queue.get_nowait() req_meta = GenReqMeta( + qid=qid, prompt_len=len(prompt_token_ids), group_size=gconfig.n, - new_token_budget=self.new_tokens_per_chunk, + new_token_budget=gconfig.max_new_tokens, predicted_new_tokens=None, ) dst_server_info = await self._schedule_request(req_meta) diff --git a/realhf/system/push_pull_stream.py b/realhf/system/push_pull_stream.py index 3b4b253..106dab6 100644 --- a/realhf/system/push_pull_stream.py +++ b/realhf/system/push_pull_stream.py @@ -171,7 +171,9 @@ class NameResolvingZmqPuller(ZMQJsonPuller): name = names.push_pull_stream( experiment_name, trial_name, stream_name=f"puller{puller_index}" ) - host, port = network.gethostip(), network.find_free_port() + host, port = network.gethostip(), network.find_free_port( + experiment_name=experiment_name, trial_name=trial_name + ) addr = f"{host}:{port}" name_resolve.add(name, addr) super().__init__(host, port, **kwargs) diff --git a/realhf/system/rollout_worker.py b/realhf/system/rollout_worker.py index ecee758..765521c 100644 --- a/realhf/system/rollout_worker.py +++ b/realhf/system/rollout_worker.py @@ -189,10 +189,11 @@ class RolloutWorker(AsyncWorker): assert data_id not in self.rollout_tasks return cur_sample - async def allocate_new_rollout(self) -> bool: + async def allocate_new_rollout(self, qid) -> bool: async with aiohttp.ClientSession() as session: - async with session.get( + async with session.post( f"http://{self.gserver_manager_addr}/allocate_rollout", + json=dict(qid=qid), timeout=ClientTimeout( total=self.config.rollout_request_timeout, sock_connect=30 ), @@ -231,10 +232,10 @@ class RolloutWorker(AsyncWorker): self._cur_data = self.load_next_data() if self._cur_data is not None: - can_rollout = await self.allocate_new_rollout() + data = self._cur_data + qid = data.ids[0] + can_rollout = await self.allocate_new_rollout(qid) if can_rollout: - data = self._cur_data - qid = data.ids[0] self.act_queues[qid] = asyncio.Queue(1024) task = asyncio.create_task(self.rollout_task(qid, data)) @@ -265,7 +266,11 @@ class RolloutWorker(AsyncWorker): accepted = True self.push_stream.push([traj.as_json_compatible() for traj in trajs]) - info = dict(qid=qid, accepted=accepted) + n_tokens = 0 + for traj in trajs: + seqlens = [sum(datapack.flat2d(ss)) for ss in traj.seqlens.values()] + n_tokens += max(seqlens) + info = dict(qid=qid, accepted=accepted, n_tokens=n_tokens) async with aiohttp.ClientSession( 
f"http://{self.gserver_manager_addr}" ) as session: diff --git a/requirements.txt b/requirements.txt index bc95f23..71baf79 100644 --- a/requirements.txt +++ b/requirements.txt @@ -36,7 +36,6 @@ ray redis scipy seaborn -setuptools>=61.0 tqdm networkx==3.3 matplotlib @@ -59,3 +58,5 @@ protobuf<3.21 rich orjson>=3.10.16 flask +setuptools>=62.3.0,<75.9 +func_timeout diff --git a/setup.py b/setup.py index 525c3a0..165bfc3 100644 --- a/setup.py +++ b/setup.py @@ -246,30 +246,6 @@ if not no_ext and _is_cuda(): ext_modules.append(interval_op_cuda) if not no_ext: - search_extension = setuptools.Extension( - name="realhf._C.mdm_search", - sources=[ - "csrc/search/search.cpp", - "csrc/search/rpc.cpp", - "csrc/search/device_mesh.cpp", - "csrc/search/simulate.cpp", - ], - language="c++", - extra_compile_args=[ - "-O3", - "-Wall", - "-shared", - "-std=c++11", - "-fPIC", - "-std=c++17", - ], - include_dirs=[ - os.path.join(os.path.abspath(os.path.dirname(__file__)), "csrc", "search"), - get_pybind11_include_path(), - ], - ) - ext_modules.append(search_extension) - interval_extension = setuptools.Extension( name="realhf._C.interval_op", sources=[ diff --git a/tests/comm/test_data_transfer.py b/tests/comm/test_data_transfer.py index e741066..c718266 100644 --- a/tests/comm/test_data_transfer.py +++ b/tests/comm/test_data_transfer.py @@ -25,28 +25,28 @@ from realhf.system.redistributor import GlobalStorageTracker, RedistribPlanner def get_data_manager( from_model_name, to_model_name, - from_pp_dp_mp, - to_pp_dp_mp, + from_pp_dp_tp, + to_pp_dp_tp, ): - from_num_pp, from_num_dp, from_num_mp = from_pp_dp_mp - to_num_pp, to_num_dp, to_num_mp = to_pp_dp_mp + from_num_pp, from_num_dp, from_num_tp = from_pp_dp_tp + to_num_pp, to_num_dp, to_num_tp = to_pp_dp_tp - from_world_size = from_num_dp * from_num_mp * from_num_pp - to_world_size = to_num_dp * to_num_mp * to_num_pp + from_world_size = from_num_dp * from_num_tp * from_num_pp + to_world_size = to_num_dp * to_num_tp * to_num_pp - from_topo = topology.PipeDataModelParallelTopology( + from_topo = topology.PipeDataTensorParallelTopology( num_dp=from_num_dp, - num_mp=from_num_mp, + num_tp=from_num_tp, num_pp=from_num_pp, sequence_parallel=False, gradient_checkpointing=False, max_prompt_len=None, gradient_accumulation_fusion=False, ) - to_topo = topology.PipeDataModelParallelTopology( + to_topo = topology.PipeDataTensorParallelTopology( num_dp=to_num_dp, - num_mp=to_num_mp, + num_tp=to_num_tp, num_pp=to_num_pp, sequence_parallel=False, gradient_checkpointing=False, @@ -80,7 +80,7 @@ def get_data_manager( k = ModelShardID( _model_name, dp_rank=coord.data, - mp_rank=coord.model, + tp_rank=coord.tensor, pp_rank=coord.pipe, topo=model_topos[_model_name], ) @@ -88,7 +88,7 @@ def get_data_manager( init_global_constants( num_dp=from_num_dp, - num_mp=from_num_mp, + num_tp=from_num_tp, num_pp=from_num_pp, topo=from_topo, model_name=from_model_name, @@ -98,7 +98,7 @@ def get_data_manager( init_global_constants( num_dp=to_num_dp, - num_mp=to_num_mp, + num_tp=to_num_tp, num_pp=to_num_pp, model_name=to_model_name, sequence_parallel=False, @@ -134,24 +134,24 @@ def recursive_assert_equal(x1, x2): def _test_data_transfer( tmp_path, - from_pp_dp_mp: Tuple, - to_pp_dp_mp: Tuple, + from_pp_dp_tp: Tuple, + to_pp_dp_tp: Tuple, ): from_model_name = ModelName("data_transfer_test", 0) - from_topo = topology.PipeDataModelParallelTopology( - num_pp=from_pp_dp_mp[0], - num_mp=from_pp_dp_mp[-1], - num_dp=from_pp_dp_mp[1], + from_topo = topology.PipeDataTensorParallelTopology( + 
num_pp=from_pp_dp_tp[0], + num_tp=from_pp_dp_tp[-1], + num_dp=from_pp_dp_tp[1], sequence_parallel=True, gradient_checkpointing=True, gradient_accumulation_fusion=True, ) to_model_name = ModelName("data_transfer_test", 1) - to_topo = topology.PipeDataModelParallelTopology( - num_pp=to_pp_dp_mp[0], - num_mp=to_pp_dp_mp[-1], - num_dp=to_pp_dp_mp[1], + to_topo = topology.PipeDataTensorParallelTopology( + num_pp=to_pp_dp_tp[0], + num_tp=to_pp_dp_tp[-1], + num_dp=to_pp_dp_tp[1], sequence_parallel=True, gradient_checkpointing=True, gradient_accumulation_fusion=True, @@ -160,8 +160,8 @@ def _test_data_transfer( data_manager = get_data_manager( from_model_name, to_model_name, - from_pp_dp_mp, - to_pp_dp_mp, + from_pp_dp_tp, + to_pp_dp_tp, ) data_manager.setup_process_groups() @@ -172,13 +172,13 @@ def _test_data_transfer( world_size = dist.get_world_size() samples = [] - for dp_rank in range(from_pp_dp_mp[1]): + for dp_rank in range(from_pp_dp_tp[1]): gpu_id = data_manager.msid2mwid[ ModelShardID( from_model_name, dp_rank=dp_rank, - mp_rank=0, - pp_rank=from_pp_dp_mp[0] - 1, + tp_rank=0, + pp_rank=from_pp_dp_tp[0] - 1, topo=from_topo, ) ] @@ -230,7 +230,7 @@ def _test_data_transfer( ModelShardID( to_model_name, dp_rank=coord.data, - mp_rank=coord.model, + tp_rank=coord.tensor, pp_rank=coord.pipe, topo=to_topo, ) @@ -260,13 +260,13 @@ parallelism = [(1, 4, 2), (1, 8, 1)] os.cpu_count() < 32 or testing.get_free_mem_gb() < 50, reason="The parameter reallocation test requires at least 32 CPUs and 50GB memory.", ) -@pytest.mark.parametrize("from_pp_dp_mp", [(1, 4, 2)]) -@pytest.mark.parametrize("to_pp_dp_mp", [(1, 8, 1)]) +@pytest.mark.parametrize("from_pp_dp_tp", [(1, 4, 2)]) +@pytest.mark.parametrize("to_pp_dp_tp", [(1, 8, 1)]) @pytest.mark.distributed def test_data_transfer( tmp_path, - from_pp_dp_mp: Tuple, - to_pp_dp_mp: Tuple, + from_pp_dp_tp: Tuple, + to_pp_dp_tp: Tuple, ): expr_name = uuid.uuid4() trial_name = uuid.uuid4() @@ -278,7 +278,7 @@ def test_data_transfer( trial_name=trial_name, timeout_secs=300, tmp_path=tmp_path, - from_pp_dp_mp=from_pp_dp_mp, - to_pp_dp_mp=to_pp_dp_mp, + from_pp_dp_tp=from_pp_dp_tp, + to_pp_dp_tp=to_pp_dp_tp, ) test_impl.launch() diff --git a/tests/comm/test_param_realloc.py b/tests/comm/test_param_realloc.py index 0131e29..557bf4e 100644 --- a/tests/comm/test_param_realloc.py +++ b/tests/comm/test_param_realloc.py @@ -128,29 +128,29 @@ def build_engine(module, model_name, trainable) -> "MockTrainEngine": def setup_constants_and_param_realloc( from_model_name, to_model_name, - from_pp_dp_mp, - to_pp_dp_mp, + from_pp_dp_tp, + to_pp_dp_tp, ): from realhf.impl.model.comm.param_realloc import setup_param_realloc - from_num_pp, from_num_dp, from_num_mp = from_pp_dp_mp - to_num_pp, to_num_dp, to_num_mp = to_pp_dp_mp + from_num_pp, from_num_dp, from_num_tp = from_pp_dp_tp + to_num_pp, to_num_dp, to_num_tp = to_pp_dp_tp - from_world_size = from_num_dp * from_num_mp * from_num_pp - to_world_size = to_num_dp * to_num_mp * to_num_pp + from_world_size = from_num_dp * from_num_tp * from_num_pp + to_world_size = to_num_dp * to_num_tp * to_num_pp - from_topo = topology.PipeDataModelParallelTopology( + from_topo = topology.PipeDataTensorParallelTopology( num_dp=from_num_dp, - num_mp=from_num_mp, + num_tp=from_num_tp, num_pp=from_num_pp, sequence_parallel=False, gradient_checkpointing=False, max_prompt_len=None, gradient_accumulation_fusion=False, ) - to_topo = topology.PipeDataModelParallelTopology( + to_topo = topology.PipeDataTensorParallelTopology( num_dp=to_num_dp, - 
num_mp=to_num_mp, + num_tp=to_num_tp, num_pp=to_num_pp, sequence_parallel=False, gradient_checkpointing=False, @@ -184,7 +184,7 @@ def setup_constants_and_param_realloc( k = ModelShardID( _model_name, dp_rank=coord.data, - mp_rank=coord.model, + tp_rank=coord.tensor, pp_rank=coord.pipe, topo=model_topos[_model_name], ) @@ -192,7 +192,7 @@ def setup_constants_and_param_realloc( init_global_constants( num_dp=from_num_dp, - num_mp=from_num_mp, + num_tp=from_num_tp, num_pp=from_num_pp, topo=from_topo, model_name=from_model_name, @@ -202,7 +202,7 @@ def setup_constants_and_param_realloc( init_global_constants( num_dp=to_num_dp, - num_mp=to_num_mp, + num_tp=to_num_tp, num_pp=to_num_pp, model_name=to_model_name, sequence_parallel=False, @@ -320,8 +320,8 @@ def _test_para_realloc( tmp_path: pathlib.Path, model_family_name: str, is_critic: bool, - from_pp_dp_mp: Tuple, - to_pp_dp_mp: Tuple, + from_pp_dp_tp: Tuple, + to_pp_dp_tp: Tuple, n_iterations: int, skip_saveload: bool, ): @@ -339,12 +339,12 @@ def _test_para_realloc( pg_info = setup_constants_and_param_realloc( from_model_name, to_model_name, - from_pp_dp_mp, - to_pp_dp_mp, + from_pp_dp_tp, + to_pp_dp_tp, ) # Create model 1 - if dist.get_rank() < from_pp_dp_mp[0] * from_pp_dp_mp[1] * from_pp_dp_mp[2]: + if dist.get_rank() < from_pp_dp_tp[0] * from_pp_dp_tp[1] * from_pp_dp_tp[2]: from_model = create_model( tmp_dir=tmp_path, model_family_name=model_family_name, @@ -357,7 +357,7 @@ def _test_para_realloc( # Creat model 2 if ( dist.get_rank() - >= dist.get_world_size() - to_pp_dp_mp[0] * to_pp_dp_mp[1] * to_pp_dp_mp[2] + >= dist.get_world_size() - to_pp_dp_tp[0] * to_pp_dp_tp[1] * to_pp_dp_tp[2] ): to_model = create_model( tmp_dir=tmp_path, @@ -532,19 +532,19 @@ parallelism = [(4, 1, 1), (2, 2, 2), (1, 8, 1), (3, 2, 1), (2, 1, 2), (1, 2, 2)] ) @pytest.mark.parametrize("model_family_name", ["gpt2", "llama"]) @pytest.mark.parametrize("is_critic", [False, True]) -@pytest.mark.parametrize("from_pp_dp_mp", parallelism) -@pytest.mark.parametrize("to_pp_dp_mp", parallelism) +@pytest.mark.parametrize("from_pp_dp_tp", parallelism) +@pytest.mark.parametrize("to_pp_dp_tp", parallelism) @pytest.mark.parametrize("skip_saveload", [False]) @pytest.mark.distributed def test_param_realloc( tmp_path: pathlib.Path, model_family_name: str, is_critic: bool, - from_pp_dp_mp: Tuple, - to_pp_dp_mp: Tuple, + from_pp_dp_tp: Tuple, + to_pp_dp_tp: Tuple, skip_saveload: bool, ): - if model_family_name == "gpt2" and (from_pp_dp_mp[-1] > 1 or to_pp_dp_mp[-1] > 1): + if model_family_name == "gpt2" and (from_pp_dp_tp[-1] > 1 or to_pp_dp_tp[-1] > 1): # Since the vocabulary size of gpt2 is odd, # it does not support tensor model parallelism. return @@ -560,8 +560,8 @@ def test_param_realloc( tmp_path=tmp_path, model_family_name=model_family_name, is_critic=is_critic, - from_pp_dp_mp=from_pp_dp_mp, - to_pp_dp_mp=to_pp_dp_mp, + from_pp_dp_tp=from_pp_dp_tp, + to_pp_dp_tp=to_pp_dp_tp, n_iterations=3, skip_saveload=skip_saveload, ) diff --git a/tests/cpp_extensions/test_grouped_gemm.py b/tests/cpp_extensions/test_grouped_gemm.py index 2487981..d1bfce5 100644 --- a/tests/cpp_extensions/test_grouped_gemm.py +++ b/tests/cpp_extensions/test_grouped_gemm.py @@ -14,7 +14,7 @@ import realhf.base.testing as testing # This is a test for grouped_gemm experts implementation of MoE. 
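# --- Editor's sketch (illustrative only, not part of the patch) --------------
# The hunks in this file follow the repository-wide rename from "model
# parallel" (mp/MP) to "tensor parallel" (tp/TP). Using only names visible in
# this diff (`testing.init_global_constants`, `constants.tensor_parallel_rank`),
# the new-style test setup looks roughly like the following; argument values
# are example values, and other keyword arguments are elided:
#
#     testing.init_global_constants(
#         num_dp=1,
#         num_tp=2,              # formerly num_mp
#         num_pp=1,
#         sequence_parallel=False,
#         max_prompt_len=128,
#     )
#     rank = constants.tensor_parallel_rank()   # formerly model_parallel_rank()
# -----------------------------------------------------------------------------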
@torch.no_grad() -def run_grouped_mlp(num_tokens, mp_size, token_dispatch_strategy, seed=1): +def run_grouped_mlp(num_tokens, tp_size, token_dispatch_strategy, seed=1): # inline import to avoid torch re-initialize from realhf.api.core.model_api import ReaLModelConfig from realhf.impl.model.modules.moe.experts import GroupedMLP, SequentialMLP @@ -29,7 +29,7 @@ def run_grouped_mlp(num_tokens, mp_size, token_dispatch_strategy, seed=1): testing.init_global_constants( num_dp=1, - num_mp=mp_size, + num_tp=tp_size, num_pp=1, sequence_parallel=False, # grouped gemm does not support sequence parallel max_prompt_len=128, # useless value in this test @@ -85,14 +85,14 @@ def run_grouped_mlp(num_tokens, mp_size, token_dispatch_strategy, seed=1): t2 = time.perf_counter() - st print( - f"rank {constants.model_parallel_rank()}: " + f"rank {constants.tensor_parallel_rank()}: " f"{token_dispatch_strategy} diff: {(o1 - o2).abs().max()}: time {t1:.4f} {t2:.4f}" ) # NOTE: With some input shapes, there are possibility that # GroupedMLP and SequentialMLP produce results of around 2% difference # due to grouped_gemm implementation assert torch.allclose(o1, o2, rtol=0.02), ( - constants.model_parallel_rank(), + constants.tensor_parallel_rank(), token_dispatch_strategy, (o1 - o2).abs().max(), o1.abs().max(), @@ -104,20 +104,20 @@ def run_grouped_mlp(num_tokens, mp_size, token_dispatch_strategy, seed=1): reason="This test requires GPU to run", ) @pytest.mark.parametrize("num_tokens", [200]) -@pytest.mark.parametrize("mp_size", [1, 2]) +@pytest.mark.parametrize("tp_size", [1, 2]) @pytest.mark.parametrize("token_dispatch_strategy", ["random"]) @pytest.mark.gpu @pytest.mark.distributed def test_grouped_mlp( num_tokens, - mp_size, + tp_size, token_dispatch_strategy, ): test = testing.LocalMultiProcessTest( - mp_size, + tp_size, run_grouped_mlp, num_tokens, - mp_size, + tp_size, token_dispatch_strategy, ) test.launch() diff --git a/tests/distributed/test_find_port.py b/tests/distributed/test_find_port.py new file mode 100644 index 0000000..e07efac --- /dev/null +++ b/tests/distributed/test_find_port.py @@ -0,0 +1,54 @@ +import multiprocessing + +from realhf.base import name_resolve, names, testing +from realhf.base.network import find_multiple_free_ports, gethostip + + +def _worker_process(result_queue, count, low, high, experiment, trial): + """Helper function for multi-process testing.""" + ports = find_multiple_free_ports( + count, + low=low, + high=high, + experiment_name=experiment, + trial_name=trial, + ) + for port in ports: + result_queue.put(port) + + +def test_find_free_port_multiprocess(): + """Test that multiple processes get different ports.""" + num_processes = 100 + experiment = "multi_port_test" + trial = "trial1" + + testing.clear_name_resolve(experiment, trial) + + result_queue = multiprocessing.Queue() + count = 2 + processes = [] + + for _ in range(num_processes): + p = multiprocessing.Process( + target=_worker_process, + args=(result_queue, count, 10000, 60000, experiment, trial), + ) + processes.append(p) + p.start() + + for p in processes: + p.join() + assert p.exitcode == 0 + + results = [] + while not result_queue.empty(): + results.append(result_queue.get()) + + assert len(results) == num_processes * count + assert len(set(results)) == num_processes * count # All ports are unique + + # Verify all ports are registered in name_resolve + ports_name = names.used_ports(experiment, trial, gethostip()) + used_ports = list(map(int, name_resolve.get_subtree(ports_name))) + assert 
set(results).issubset(set(used_ports)) diff --git a/tests/experiments/test_buffer_recover.py b/tests/experiments/test_buffer_recover.py index e93aac5..b7a5a35 100644 --- a/tests/experiments/test_buffer_recover.py +++ b/tests/experiments/test_buffer_recover.py @@ -80,7 +80,7 @@ def test_buffer_recover( inf=MFCConfig( device_mesh="NODE01:0,1,2,3,4,5,6,7", parallel=ParallelismConfig( - model_parallel_size=2, + tensor_parallel_size=2, pipeline_parallel_size=2, data_parallel_size=dp // 2, ), @@ -88,7 +88,7 @@ def test_buffer_recover( train=MFCConfig( device_mesh="NODE01:8,9,10,11,12,13,14,15", parallel=ParallelismConfig( - model_parallel_size=2, + tensor_parallel_size=2, pipeline_parallel_size=2, data_parallel_size=dp // 2, ), diff --git a/tests/experiments/test_math_ppo.py b/tests/experiments/test_math_ppo.py index 823e724..e15617e 100644 --- a/tests/experiments/test_math_ppo.py +++ b/tests/experiments/test_math_ppo.py @@ -196,14 +196,14 @@ def test_ppo_global_reshard( actor_gen=MFCConfig( parallel=ParallelismConfig( data_parallel_size=actor_gen[0], - model_parallel_size=actor_gen[1], + tensor_parallel_size=actor_gen[1], pipeline_parallel_size=actor_gen[2], ) ), actor_train=MFCConfig( parallel=ParallelismConfig( data_parallel_size=actor_train[0], - model_parallel_size=actor_train[1], + tensor_parallel_size=actor_train[1], pipeline_parallel_size=actor_train[2], ), ), @@ -211,7 +211,7 @@ def test_ppo_global_reshard( mb_spec=MicroBatchSpec(max_tokens_per_mb=32), parallel=ParallelismConfig( data_parallel_size=critic_inf[0], - model_parallel_size=critic_inf[1], + tensor_parallel_size=critic_inf[1], pipeline_parallel_size=critic_inf[2], ), ), @@ -219,7 +219,7 @@ def test_ppo_global_reshard( mb_spec=MicroBatchSpec(max_tokens_per_mb=128), parallel=ParallelismConfig( data_parallel_size=rew_inf[0], - model_parallel_size=rew_inf[1], + tensor_parallel_size=rew_inf[1], pipeline_parallel_size=rew_inf[2], ), ), @@ -227,14 +227,14 @@ def test_ppo_global_reshard( mb_spec=MicroBatchSpec(max_tokens_per_mb=256), parallel=ParallelismConfig( data_parallel_size=ref_inf[0], - model_parallel_size=ref_inf[1], + tensor_parallel_size=ref_inf[1], pipeline_parallel_size=ref_inf[2], ), ), critic_train=MFCConfig( parallel=ParallelismConfig( data_parallel_size=critic_train[0], - model_parallel_size=critic_train[1], + tensor_parallel_size=critic_train[1], pipeline_parallel_size=critic_train[2], ), ), @@ -309,7 +309,7 @@ def test_ppo_param_realloc_sub_device_mesh( device_mesh="NODE01:0,1,2,3", parallel=ParallelismConfig( data_parallel_size=actor_gen[0], - model_parallel_size=actor_gen[1], + tensor_parallel_size=actor_gen[1], pipeline_parallel_size=actor_gen[2], ), ), @@ -317,7 +317,7 @@ def test_ppo_param_realloc_sub_device_mesh( device_mesh="NODE01:4,5,6,7", parallel=ParallelismConfig( data_parallel_size=4, - model_parallel_size=1, + tensor_parallel_size=1, pipeline_parallel_size=1, ), ), @@ -325,7 +325,7 @@ def test_ppo_param_realloc_sub_device_mesh( device_mesh="NODE01:4,5,6,7", parallel=ParallelismConfig( data_parallel_size=critic_inf[0], - model_parallel_size=critic_inf[1], + tensor_parallel_size=critic_inf[1], pipeline_parallel_size=critic_inf[2], ), ), @@ -333,7 +333,7 @@ def test_ppo_param_realloc_sub_device_mesh( device_mesh="NODE01:4,5,6,7", parallel=ParallelismConfig( data_parallel_size=4, - model_parallel_size=1, + tensor_parallel_size=1, pipeline_parallel_size=1, ), ), @@ -341,7 +341,7 @@ def test_ppo_param_realloc_sub_device_mesh( device_mesh="NODE01:4,5,6,7", parallel=ParallelismConfig( data_parallel_size=1, - 
model_parallel_size=2, + tensor_parallel_size=2, pipeline_parallel_size=2, ), ), @@ -349,7 +349,7 @@ def test_ppo_param_realloc_sub_device_mesh( device_mesh="NODE01:4,5,6,7", parallel=ParallelismConfig( data_parallel_size=2, - model_parallel_size=1, + tensor_parallel_size=1, pipeline_parallel_size=2, ), ), @@ -431,7 +431,7 @@ def test_ppo_save( actor_gen=MFCConfig( parallel=ParallelismConfig( data_parallel_size=1, - model_parallel_size=2, + tensor_parallel_size=2, pipeline_parallel_size=1, ) ), @@ -439,28 +439,28 @@ def test_ppo_save( device_mesh="NODE01:0", parallel=ParallelismConfig( data_parallel_size=1, - model_parallel_size=1, + tensor_parallel_size=1, pipeline_parallel_size=1, ), ), critic_inf=MFCConfig( parallel=ParallelismConfig( data_parallel_size=2, - model_parallel_size=1, + tensor_parallel_size=1, pipeline_parallel_size=1, ) ), rew_inf=MFCConfig( parallel=ParallelismConfig( data_parallel_size=2, - model_parallel_size=1, + tensor_parallel_size=1, pipeline_parallel_size=1, ) ), ref_inf=MFCConfig( parallel=ParallelismConfig( data_parallel_size=2, - model_parallel_size=1, + tensor_parallel_size=1, pipeline_parallel_size=1, ) ), @@ -468,7 +468,7 @@ def test_ppo_save( device_mesh="NODE01:1", parallel=ParallelismConfig( data_parallel_size=1, - model_parallel_size=1, + tensor_parallel_size=1, pipeline_parallel_size=1, ), ), diff --git a/tests/fixtures.py b/tests/fixtures.py index fb96bb5..6abb0f6 100644 --- a/tests/fixtures.py +++ b/tests/fixtures.py @@ -108,7 +108,7 @@ def maybe_prepare_cpu_env(max_prompt_len: int): ) testing.init_global_constants( num_dp=1, - num_mp=1, + num_tp=1, num_pp=1, sequence_parallel=False, max_prompt_len=max_prompt_len, diff --git a/tests/legacy/test_sglang_tp.py b/tests/legacy/test_sglang_tp.py index a96f2d1..1fdb99d 100644 --- a/tests/legacy/test_sglang_tp.py +++ b/tests/legacy/test_sglang_tp.py @@ -82,7 +82,7 @@ def test_fn( constants.set_experiment_trial_names("slang-test", str(uuid.uuid4())) init_global_constants( num_dp=dp, - num_mp=tp, + num_tp=tp, num_pp=pp, sequence_parallel=False, model_name=model_name, @@ -181,7 +181,7 @@ def test_fn( tokenizer=tokenizer, gconfig=gconfig, ) - if constants.model_parallel_rank() == 0: + if constants.tensor_parallel_rank() == 0: # The outputs are Nones for tp_rank > 1 in SGLang _, _, token_match_percent, seq_match_percent = ( check_sequences_consistency(gen_tokens1, gen_tokens2) diff --git a/tests/legacy/test_vllm_tp.py b/tests/legacy/test_vllm_tp.py index 191582a..9c9d191 100644 --- a/tests/legacy/test_vllm_tp.py +++ b/tests/legacy/test_vllm_tp.py @@ -85,7 +85,7 @@ def test_fn( model_name = ModelName("default", 0) init_global_constants( num_dp=dp, - num_mp=tp, + num_tp=tp, num_pp=pp, sequence_parallel=False, model_name=model_name, diff --git a/tests/model/test_cpu_inference.py b/tests/model/test_cpu_inference.py index 6e96d4a..3ded97c 100644 --- a/tests/model/test_cpu_inference.py +++ b/tests/model/test_cpu_inference.py @@ -18,7 +18,7 @@ logger = logging.getLogger("tests.test_cpu") # NOTE: To run test for a new model class, please implement and register `real_config_maker` # in realhf.api.from_hf. and add the model class name to the # `model_class` fixture in this file. 
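# --- Editor's sketch (illustrative only, not part of the patch) --------------
# Per the NOTE above, enabling a new model class in this CPU inference test
# only requires appending its family name to the fixture below, once the
# corresponding `real_config_maker` has been registered in realhf.api.from_hf.
# For example, for a hypothetical "my_model" family (name is a placeholder):
#
#     @pytest.fixture(params=["llama", "gpt2", "qwen2", "mistral", "mixtral",
#                             "qwen3", "my_model"])
#     def model_class(request):
#         return request.param
# -----------------------------------------------------------------------------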
-@pytest.fixture(params=["llama", "gpt2", "qwen2", "gemma", "mistral", "mixtral"]) +@pytest.fixture(params=["llama", "gpt2", "qwen2", "mistral", "mixtral", "qwen3"]) def model_class(request): return request.param diff --git a/tests/model/test_distributed_load_hf.py b/tests/model/test_distributed_load_hf.py index 5d049bf..4a3f2c6 100644 --- a/tests/model/test_distributed_load_hf.py +++ b/tests/model/test_distributed_load_hf.py @@ -14,8 +14,7 @@ import torch import torch.distributed as dist import transformers -from realhf.api.cli_args import ModelFamily -from realhf.api.core.model_api import HF_MODEL_FAMILY_REGISTRY, ReaLModelConfig +from realhf.api.core.model_api import ReaLModelConfig from realhf.base import constants, logging from realhf.base.testing import ( LocalMultiProcessTest, @@ -43,7 +42,7 @@ def _save_then_load( model_family_name: str, is_critic: bool, init_critic_from_actor: bool, - pp_dp_mp: Tuple, + pp_dp_tp: Tuple, device: torch.device, ): # NOTE: import here to avoid initializing CUDA context in the main process @@ -52,10 +51,10 @@ def _save_then_load( # os.environ["REAL_SAVE_MAX_SHARD_SIZE_BYTE"] = str(int(1e6)) model_name = f"saveload_test_{model_family_name}" - num_pp, num_dp, num_mp = pp_dp_mp + num_pp, num_dp, num_tp = pp_dp_tp init_global_constants( num_dp=num_dp, - num_mp=num_mp, + num_tp=num_tp, num_pp=num_pp, model_name=model_name, ) @@ -71,7 +70,7 @@ def _save_then_load( ReaLModel, f"make_{model_family_name}_config" )() mconfig.is_critic = is_critic - if mconfig.n_kv_heads % num_mp != 0: + if mconfig.n_kv_heads % num_tp != 0: return # load from hf model or create a new critic model @@ -146,20 +145,20 @@ def _save_then_load( @pytest.mark.parametrize( "model_family_name", - ["gemma", "gpt2", "llama", "qwen2", "mistral", "mixtral"], + ["gemma", "gpt2", "llama", "qwen2", "mistral", "mixtral", "qwen3"], ) @pytest.mark.parametrize("is_critic", [True, False]) @pytest.mark.parametrize("init_critic_from_actor", [True, False]) -@pytest.mark.parametrize("pp_dp_mp", [(4, 2, 1), (2, 2, 2), (1, 2, 4), (1, 8, 1)]) +@pytest.mark.parametrize("pp_dp_tp", [(4, 2, 1), (2, 2, 2), (1, 2, 4), (1, 8, 1)]) @pytest.mark.distributed def test_save_then_load( tmp_path: pathlib.Path, model_family_name: str, is_critic: bool, init_critic_from_actor: bool, - pp_dp_mp: Tuple, + pp_dp_tp: Tuple, ): - if model_family_name == "gpt2" and pp_dp_mp[-1] > 1: + if model_family_name == "gpt2" and pp_dp_tp[-1] > 1: # GPT-2 has an odd vocabulary size, so it doesn't work # with tensor-model parallelism. return @@ -176,7 +175,7 @@ def test_save_then_load( model_family_name=model_family_name, is_critic=is_critic, init_critic_from_actor=init_critic_from_actor, - pp_dp_mp=pp_dp_mp, + pp_dp_tp=pp_dp_tp, tmp_path=tmp_path, device="cpu", ) diff --git a/tests/system/test_gserver_manager.py b/tests/system/test_gserver_manager.py index 7354f70..174b5e0 100644 --- a/tests/system/test_gserver_manager.py +++ b/tests/system/test_gserver_manager.py @@ -157,6 +157,7 @@ async def test_schedule_policy(gserver_manager): from realhf.api.core.model_api import GenReqMeta req_meta = GenReqMeta( + "1", prompt_len=100, group_size=2, new_token_budget=1024, @@ -187,6 +188,7 @@ async def test_weight_update(gserver_manager): UPDATE_WEIGHTS_CALL_COUNT.clear() req_meta = GenReqMeta( + "2", prompt_len=100, group_size=2, new_token_budget=1024, @@ -233,6 +235,7 @@ async def test_http_server_endpoints(gserver_manager): # Test schedule_request endpoint req_meta = GenReqMeta( + "3", prompt_len=100, group_size=2, new_token_budget=1024,