Update HCCL to support multiple NPUs

zjp_shadow 2025-02-10 16:18:34 +08:00
parent 8289c61138
commit a5c37bd58e
21 changed files with 1789 additions and 1201 deletions

View File

@@ -600,6 +600,26 @@ def setup_nccl():
    nccl_ops = nccl.ops
    LOG.vv("Get nccl_ops: "+str(dir(nccl_ops)))

+def setup_hccl():
+    global hccl_ops
+
+    hccl_src_dir = os.path.join(jittor_path, "extern", "acl", "hccl")
+    hccl_src_files = []
+    for r, _, f in os.walk(hccl_src_dir):
+        for fname in f:
+            hccl_src_files.append(os.path.join(r, fname))
+
+    hccl_include_path = os.path.join(os.environ.get("ASCEND_TOOLKIT_HOME"), "aarch64-linux/include/hccl")
+    hccl_lib_name = os.path.join(os.environ.get("ASCEND_TOOLKIT_HOME"), "aarch64-linux/lib64/libhccl.so")
+    ctypes.CDLL(hccl_lib_name, dlopen_flags)
+
+    hccl = compile_custom_ops(hccl_src_files,
+                              extra_flags=f" -I\"{hccl_include_path}\" {mpi_compile_flags} ",
+                              return_module=True, dlopen_flags=os.RTLD_GLOBAL | os.RTLD_NOW,
+                              gen_name_="jittor_hccl_core")
+    hccl_ops = hccl.ops
+    LOG.vv("Get hccl_ops: "+str(dir(hccl_ops)))
+
def manual_link(flags):
    lib_dirs = []
    libs = []

@@ -697,10 +717,12 @@ cudnn = cublas = curand = cufft = None
setup_mpi()
rank = mpi.world_rank() if in_mpi else 0
world_size = mpi.world_size() if in_mpi else 1
-setup_nccl()
-setup_cutt()
-setup_cutlass()
+if has_acl:
+    setup_hccl()
+elif has_cuda:
+    setup_nccl()
+    setup_cutt()
+    setup_cutlass()

# try:
setup_mkl()
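
A minimal sketch of how this new setup path would be exercised on an Ascend build (a hedged example, not part of the commit: it assumes ASCEND_TOOLKIT_HOME is set, the build detects has_acl, the job is launched under mpirun, and that the in_mpi/rank names seen above are re-exported as jt.in_mpi/jt.rank):

    # check_hccl.py -- launch as: mpirun -np 2 python check_hccl.py
    import os
    import jittor as jt
    from jittor import compile_extern

    # setup_hccl() builds its include and library paths from this variable
    assert "ASCEND_TOOLKIT_HOME" in os.environ

    if jt.in_mpi:
        # on an ACL build, setup_hccl() should have populated compile_extern.hccl_ops
        print(jt.rank, dir(getattr(compile_extern, "hccl_ops", None)))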

File diff suppressed because it is too large.

View File

@@ -38,7 +38,6 @@ namespace jittor
static void *acl_jittor_process_callback(void *)
{
    acl_jittor_thread_running = 1;
-    int deviceId = 0;
    while (acl_jittor_thread_running)
    {

View File

@@ -1,496 +1,502 @@
// ***************************************************************
// Copyright (c) 2023 Jittor. All Rights Reserved.
// Maintainers: Dun Liang <randonlang@gmail.com>.
// This file is subject to the terms and conditions defined in
// file 'LICENSE.txt', which is part of this source code package.
// ***************************************************************
#include <acl/acl.h>
#include <acl/acl_op_compiler.h>
#include <Python.h>
#include <pystate.h>
#include <algorithm>
#include <queue>
#include <set>
#include "common.h"
#include "op.h"
#include "acl_jittor.h"
#include "ops/random_op.h"
#include "ops/reduce_op.h"
#include "ops/binary_op.h"
#include "ops/broadcast_to_op.h"
#include "ops/transpose_op.h"
#include "ops/array_op.h"
#include "ops/code_op.h"
#include "fused_op.h"
#include "ops/unary_op.h"
#include "ops/ternary_op.h"
#include "executor.h"
#include "misc/cuda_flags.h"
#include "mem/allocator.h"
#include "op_compiler.h"
#include "ops/op_register.h"
#include "opt/tuner_manager.h"
#include "utils/str_utils.h"
#include "aclnn/aclnn.h"
#include "aclops/aclops.h"
namespace jittor
{
void free_var_mem(Var *v);
unordered_map<uint32, string> opname_map = {
    // unary op
    {ns_cast, "Cast"},
    {ns_negative, "Neg"},
    {ns_abs, "Abs"},
    {ns_exp, "Exp"},
    {ns_log, "Log"},
    {ns_sqrt, "Sqrt"},
    {ns_ceil, "Ceil"},
    {ns_floor, "Floor"},
    {ns_round, "Round"},
    // m(round_int)
    // m(floor_int)
    // m(ceil_int)
    {ns_sin, "Sin"},
    {ns_cos, "Cos"},
    {ns_tan, "Tan"},
    {ns_asin, "Asin"},
    {ns_acos, "Acos"},
    {ns_atan, "Atan"},
    {ns_sinh, "Sinh"},
    {ns_cosh, "Cosh"},
    {ns_tanh, "Tanh"},
    {ns_asinh, "Asinh"},
    {ns_acosh, "Acosh"},
    {ns_atanh, "Atanh"},
    {ns_sigmoid, "Sigmoid"},
    {ns_erf, "Erf"},
    {ns_erfinv, "Erfinv"},
    {ns_logical_not, "LogicalNot"},
    {ns_bitwise_not, "BitwiseNot"},
    // binary op
    {ns_pow, "Pow"},
    {ns_maximum, "Maximum"},
    {ns_minimum, "Minimum"},
    {ns_add, "Add"},
    {ns_subtract, "Sub"},
    {ns_multiply, "Mul"},
    {ns_divide, "RealDiv"},
    {ns_floor_divide, "FloorDiv"},
    {ns_mod, "Mod"},
    {ns_less, "Less"},
    {ns_less_equal, "LessEqual"},
    {ns_greater, "Greater"},
    {ns_greater_equal, "GreaterEqual"},
    {ns_equal, "Equal"},
    {ns_not_equal, "NotEqual"},
    {ns_left_shift, "LeftShift"},
    {ns_right_shift, "RightShift"},
    {ns_logical_and, "LogicalAnd"},
    {ns_logical_or, "LogicalOr"},
    {ns_logical_xor, "LogicalXor"},
    {ns_bitwise_and, "BitwiseAnd"},
    {ns_bitwise_or, "BitwiseOr"},
    {ns_bitwise_xor, "BitwiseXor"},
};
void fallback_cpu(Op *op)
{
    LOGy << "!!! fallback_cpu " << op;
    use_cuda = 0;
    for (auto v : op->inputs())
    {
        if (v->mem_ptr && v->allocator->is_cuda())
        {
            migrate_to_cpu(v, exe.allocator);
        }
    }
    for (auto v : op->outputs())
    {
        if (v->mem_ptr && v->allocator->is_cuda())
        {
            migrate_to_cpu(v, exe.allocator);
        }
    }
    op->flags.set(NodeFlags::_cpu);
    op->flags.set(NodeFlags::_cuda, 0);
    if (op->name() == string("fused"))
    {
        auto fop = (FusedOp *)op;
        for (auto op : fop->ops)
        {
            op->flags.set(NodeFlags::_cpu);
            op->flags.set(NodeFlags::_cuda, 0);
        }
    }
    op->do_run();
    use_cuda = 1;
}
/*
check compile
if compiled: exec
else: compile
    check is fused
        check is relay
        else
            compile func = try exec
            if failed: fallback_cpu
    else
        try compile
        if failed: fallback_cpu
*/
extern jit_op_entry_t (*do_compile_hook)(Op *);
jit_op_entry_t do_compile_inner(Op *op);
void try_exec_and_fallback_cpu(Op *op)
{
+   aclrtSynchronizeStream(aclstream);
    auto fop = (FusedOp *)op;

    std::set<Var *> new_alloced;
    map<Op *, int> op_indeg;
    map<Var *, int> var_outdeg;
    std::queue<Op *> queue;

    for (Op *op : fop->ops)
        op_indeg[op] = 0;

    map<Op *, vector<Op *>> out_map;
    map<Var *, vector<Op *>> from;

    int len = 0;
    for (Op *v : fop->ops)
    {
        for (auto in : v->inputs())
            from[in].push_back(v);
        ++len;
    }
    for (Op *u : fop->ops)
    {
        for (auto out : u->outputs())
        {
            if (from.find(out) != from.end())
            {
                for (auto v : from[out])
                {
                    ++op_indeg[v];
                    ++var_outdeg[out];
                    out_map[u].push_back(v);
                }
            }
        }
    }
    for (Op *op : fop->ops)
    {
        if (op_indeg[op] == 0)
            queue.push(op);
    }

    int total = 0;
    int fallback = 0;
    try
    {
        while (!queue.empty())
        {
            total++;

            for (auto in : op->inputs())
            {
                ASSERT(in->mem_ptr);
            }
            auto op = queue.front();
            queue.pop();
            for (auto out : op->outputs())
            {
                if (out->mem_ptr)
                    continue;
                out->alloc(exe.allocator);
                new_alloced.insert(out);
            }
            for (auto out : out_map[op])
            {
                --op_indeg[out];
                if (op_indeg[out] == 0)
                    queue.push(out);
            }
            if (op->name() == string("unary"))
            {
                auto uop = (UnaryOp *)op;
                UnaryOpRunner op;
                op.add(uop->x, true);
                op.add(uop->y, false);
                auto iter = opname_map.find(uop->ns);
                ASSERT(iter != opname_map.end()) << "op " << uop->ns << " not found";
                op.name = iter->second;
                op.jt_name = uop->name();
                op.run();
            }
            else if (op->name() == string("binary"))
            {
                auto bop = (BinaryOp *)op;
                BinaryOpRunner op;
                op.add(bop->x, true);
                op.add(bop->y, true);
                op.add(bop->z, false);
                auto iter = opname_map.find(bop->ns);
                ASSERT(iter != opname_map.end()) << "op " << bop->ns << " not found";
                op.name = iter->second;
                op.jt_name = bop->name();

                if (bop->x->dtype() == ns_bool and bop->y->dtype() == ns_bool)
                {
                    // BitwiseOr, BitwiseAnd, BitwiseXor -> LogicalOr, LogicalAnd, LogicalXor
                    if (bop->ns == ns_bitwise_or)
                    {
                        op.name = "LogicalOr";
                    }
                    else if (bop->ns == ns_bitwise_and)
                    {
                        op.name = "LogicalAnd";
                    }
                    else if (bop->ns == ns_bitwise_xor)
                    {
                        op.name = "LogicalXor";
                    }
                }
                op.run();
            }
            else if (op->name() == string("ternary"))
            {
                auto top = (TernaryOp *)op;
                TernaryOpRunner op;
                op.add(top->cond, true);
                op.add(top->x, true);
                op.add(top->y, true);
                op.add(top->z, false);
                op.run();
            }
            else if (op->name() == string("array"))
            {
                auto aop = (ArrayOp *)op;
                aclrtMemcpy(aop->output->mem_ptr, aop->output->size, aop->ptr<void>(), aop->output->size, ACL_MEMCPY_HOST_TO_DEVICE);
            }
            else if (op->name() == string("reduce"))
            {
                auto rop = (ReduceOp *)op;
                ReduceOpRunner op;
                if (rop->ns == ns_add)
                    op.op_idx = 9;
                else if (rop->ns == ns_multiply)
                    // TODO unsupported the multi dim
                    op.op_idx = 999;
                else if (rop->ns == ns_maximum)
                    op.op_idx = 11;
                else if (rop->ns == ns_minimum)
                    op.op_idx = 12;
                else if (rop->ns == ns_mean)
                    op.op_idx = 10;
                else
                    LOGf << "op " << rop->ns << " not supported";
                op.add(rop->x, true);

                ReduceAttr *attr = new ReduceAttr();
                for (int i = 0; i < rop->x->shape.size(); i++)
                    if (rop->reduce_mask & (1 << i))
                        attr->axes.push_back(i);
                if (rop->x->shape.size() == rop->y->shape.size())
                    attr->keepdims = true;
                else
                    attr->keepdims = false;

                op.op_attr.reset(attr);
                op.add(rop->y, false);
                op.run();
                aclrtSynchronizeStream(aclstream);
            }
            else if (op->name() == string("broadcast_to"))
            {
                auto bop = (BroadcastToOp *)op;
                ExpandOpRunner op;
                op.jt_name = "expand";
                NanoVector xshape, xshape_bk = bop->x->shape;
                NanoVector zshape = bop->z->shape;

                for (int i = 0; i < zshape.size(); i++)
                {
                    if (bop->bcast_mask & (1 << i))
                    {
                        xshape.push_back(1);
                    }
                    else
                    {
                        xshape.push_back(zshape[i]);
                    }
                }
                bop->x->shape = xshape;
                op.add(bop->x, true);
                // bop->x->shape = xshape_bk;
                op.add(bop->z, false);
                op.run();
                bop->x->shape = xshape_bk;
                aclrtSynchronizeStream(aclstream);
            }
            else if (op->name() == string("fuse_transpose"))
            {
                // replace fuse_transpose with transpose
                auto top = (TransposeOp *)op;
                TransposeOpRunner op;
                op.add(top->x, true);
                op.add(top->y, false);
                op.jt_name = "transpose";

                ReduceAttr *attr = new ReduceAttr();
                for (int i = 0; i < top->axes.size(); i++)
                    attr->axes.push_back(top->axes[i]);
                op.op_attr.reset(attr);

                op.run();
            }
            else
            {
                LOGf << "op " << op->name() << " not supported";
            }

            for (auto in : op->inputs())
            {
                --var_outdeg[in];
                if (var_outdeg[in] == 0)
                {
                    if (new_alloced.find(in) != new_alloced.end())
                    {
                        free_var_mem(in);
                        new_alloced.erase(in);
                    }
                }
            }
        }
    }
    catch (std::exception &e)
    {
        fallback = 1;
        LOGir << "fallback cpu" << e.what();
    }
    for (auto v : new_alloced)
    {
        free_var_mem(v);
    }
    if (fallback)
    {
        fallback_cpu(op);
    }
}

extern int current_seed;
extern int64 current_offset;

static unordered_map<string, std::function<void(Op *)>> acl_ops = {
    {"curand_random", [&current_seed, &current_offset](Op *op)
     {
         auto _op = (RandomOp *)op;
         RandomOpRunner runner(_op->type == ns_uniform ? "RandomUniform" : "RandomNormal");
         auto out = op->output(0);
         RandomAttr *attr = new RandomAttr();
         attr->seed = current_seed;
         attr->offset = current_offset;
         runner.jt_name = "random";
         runner.op_attr.reset(attr);

         runner.add(out, false);
         runner.run();
         current_offset += out->numel();
     }},
};

static void exec_mapped_acl_ops(Op *op)
{
    auto iter = acl_ops.find(op->name());
    if (iter != acl_ops.end())
    {
        LOGv << "exec acl op " << op->name() << op;
        iter->second(op);
    }
    else
    {
        LOGf << "op " << op->name() << " not supported";
    }
}

static jit_op_entry_t acl_do_compile(Op *op)
{
    LOGv << "compile" << op;
    OpCompiler oc(op);
    string *src = &oc.src;
    for (auto op_type : op_types)
        op_type->post_pass(&oc);
    string src_after_passes;
    // if is fused op
    if (oc.op)
    {
        TunerManager tm(&oc);
        src_after_passes = tm.tune();
        src = &src_after_passes;
    }
    op->compile_optimize(*src);
    if (!op->flags.get(NodeFlags::_cuda))
    {
        LOGv << "compile cpu";
        return oc.compile(op->get_jit_key(get_jk()), *src);
    }
    if (op->name() == string("fused"))
    {
        FusedOp *fop = (FusedOp *)op;
        // if is a relayed op
        if (fop->context->vrm.relay_groups.size())
        {
            LOGv << "relay fused op";
            return oc.compile(op->get_jit_key(get_jk()), *src);
        }
        else
        {
            return &try_exec_and_fallback_cpu;
        }
    }
    else if (op->name() == string("code"))
    {
        CodeOp *cop = (CodeOp *)op;
        if (cop->cuda_src.find("acl") != string::npos)
        {
            LOGv << "compile acl op";
            return oc.compile(op->get_jit_key(get_jk()), *src);
        }
        else
        {
            return &exec_mapped_acl_ops;
        }
    }
+   else if (strncmp(op->name(), "hccl", 4) == 0)
+   {
+       LOGv << "Compiling HCCL op: " << op->name();
+       return oc.compile(op->get_jit_key(get_jk()), *src);
+   }
    else
    {
        LOGv << "compile finish" << op;
        return &exec_mapped_acl_ops;
    }
    return do_compile_inner(op);
}

// from op_register.cc
extern unordered_map<string, OpInfo> op_info_map;

void init_acl_ops()
{
    do_compile_hook = acl_do_compile;
    vector<string> to_erase;
    for (auto &kv : op_info_map)
    {
        if (startswith(kv.first, "cu") && acl_ops.count(kv.first) == 0)
        {
            to_erase.push_back(kv.first);
        }
    }
    for (auto &k : to_erase)
    {
        LOGv << "op not supported: " << k << ", erase it.";
        op_info_map.erase(k);
    }
}

} // jittor

View File

@@ -21,7 +21,7 @@ void PrintOutResult(std::vector<int64_t> &shape, void** deviceAddr) {
    }
}

-int Init(int32_t deviceId) {
+/*int Init(int32_t deviceId) {
    // Boilerplate AscendCL initialization
    auto ret = aclInit(nullptr);
    CHECK_RET(ret == ACL_SUCCESS or ret == 100002, LOG_PRINT("aclInit failed. ERROR: %d\n", ret); return ret);

@@ -30,7 +30,7 @@ int Init(int32_t deviceId) {
    //ret = aclrtCreateStream(stream);
    //CHECK_RET(ret == ACL_SUCCESS, LOG_PRINT("aclrtCreateStream failed. ERROR: %d\n", ret); return ret);
    return 0;
-}
+}*/

/*
template <typename T>

View File

@@ -125,7 +125,7 @@ int64_t GetShapeSize(const std::vector<int64_t> &shape);

void PrintOutResult(std::vector<int64_t> &shape, void **deviceAddr);

-int Init(int32_t deviceId);
+//int Init(int32_t deviceId);

/*
template <typename T>

View File

@@ -0,0 +1,38 @@
// ***************************************************************
// Copyright (c) 2025 Jittor.
// All Rights Reserved.
// Maintainers:
//     Jiapeng Zhang <zjp24@mails.tsinghua.edu.cn>.
//
// This file is subject to the terms and conditions defined in
// file 'LICENSE.txt', which is part of this source code package.
// ***************************************************************
#pragma once

#include "mpi_wrapper.h"

#define ACLCHECK(ret) do {\
    if(ret != ACL_SUCCESS)\
    {\
        LOGe << "retcode: " << ret;\
        return;\
    }\
} while(0)\

#define HCCLCHECK(ret) do {\
    if(ret != HCCL_SUCCESS)\
    {\
        LOGe << HcclGetErrorString(ret) << " retcode: " << ret;\
        return;\
    }\
} while(0)

#include <hccl.h>

namespace jittor {

EXTERN_LIB HcclRootInfo root_info;
EXTERN_LIB HcclComm comm;
EXTERN_LIB uint32_t hccl_device_id;

} // jittor

View File

@@ -0,0 +1,70 @@
// ***************************************************************
// Copyright (c) 2025 Jittor.
// All Rights Reserved.
// Maintainers:
//     Jiapeng Zhang <zjp24@mails.tsinghua.edu.cn>.
//
// This file is subject to the terms and conditions defined in
// file 'LICENSE.txt', which is part of this source code package.
// ***************************************************************
#include "var.h"
#include "hccl_all_gather_op.h"
#include "ops/op_register.h"
#include "utils/str_utils.h"
#include "hccl_wrapper.h"

namespace jittor {

#ifndef JIT

static auto hccl_all_gather =
    get_op_info("hccl_all_gather").get_constructor<VarPtr, Var*>();

HcclAllGatherOp::HcclAllGatherOp(Var* x) : x(x) {
    flags.set(NodeFlags::_cpu, 0);
    flags.set(NodeFlags::_cuda, 1);
    y = create_output(nullptr, x->dtype());
}

void HcclAllGatherOp::infer_shape() {
    NanoVector yshape;
    yshape.push_back(mpi_world_size * x->shape[0]);
    for (int i=1; i<x->shape.size(); i++)
        yshape.push_back(x->shape[i]);
    y->set_shape(yshape);
}

VarPtr HcclAllGatherOp::grad(Var* out, Var* dout, Var* v, int v_index) {
    LOGf << "not implemented";
    return nullptr;
}

void HcclAllGatherOp::jit_prepare(JK& jk) {
    jk << "«Tx:" << x->dtype();
}

#else // JIT

void HcclAllGatherOp::jit_run() {
    LOGir << "HcclAllGatherOp::jit_run";
    @define(T_HCCL,
        @if(@strcmp(@Tx,float)==0 || @strcmp(@Tx,float32)==0, HcclDataType::HCCL_DATA_TYPE_FP32)
        @if(@strcmp(@Tx,int)==0 || @strcmp(@Tx,int32)==0, HcclDataType::HCCL_DATA_TYPE_INT32)
        @if(@strcmp(@Tx,float64)==0, HcclDataType::HCCL_DATA_TYPE_FP64)
        @if(@strcmp(@Tx,int64)==0, HcclDataType::HCCL_DATA_TYPE_INT64)
        @if(@strcmp(@Tx,uint8)==0, HcclDataType::HCCL_DATA_TYPE_UINT8)
        @if(@strcmp(@Tx,float16)==0, HcclDataType::HCCL_DATA_TYPE_FP16)
    )
    auto* __restrict__ xp = x->ptr<Tx>();
    auto* __restrict__ yp = y->ptr<Tx>();
    ACLCHECK(aclrtSynchronizeDevice());
    ACLCHECK(aclrtSynchronizeStream(aclstream));
    HCCLCHECK(HcclAllGather(xp, yp, (uint64_t)x->num, @T_HCCL, comm, aclstream));
    ACLCHECK(aclrtSynchronizeDevice());
    ACLCHECK(aclrtSynchronizeStream(aclstream));
}

#endif // JIT

} // jittor

View File

@@ -0,0 +1,26 @@
// ***************************************************************
// Copyright (c) 2025 Jittor.
// All Rights Reserved.
// Maintainers:
//     Jiapeng Zhang <zjp24@mails.tsinghua.edu.cn>.
//
// This file is subject to the terms and conditions defined in
// file 'LICENSE.txt', which is part of this source code package.
// ***************************************************************
#pragma once
#include "op.h"

namespace jittor {

struct HcclAllGatherOp : Op {
    Var* x, * y;
    HcclAllGatherOp(Var* x);
    void infer_shape() override;
    const char* name() const override { return "hccl_all_gather"; }
    VarPtr grad(Var* out, Var* dout, Var* v, int v_index) override;
    DECLARE_jit_run;
};

} // jittor
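
As a usage sketch (not part of the commit), the shape contract implied by infer_shape above: every rank contributes its local x and receives the concatenation along dim 0, so on 4 NPUs a per-rank (2, 3) input yields an (8, 3) output. Assuming the op is reachable through compile_extern.hccl_ops as registered by setup_hccl():

    # hypothetical: run under mpirun with 4 processes, one NPU each
    import jittor as jt
    from jittor import compile_extern

    x = jt.ones((2, 3))                              # per-rank input
    y = compile_extern.hccl_ops.hccl_all_gather(x)   # concatenate along dim 0
    assert list(y.shape) == [4 * 2, 3]               # mpi_world_size * x.shape[0], other dims unchanged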

View File

@@ -0,0 +1,62 @@
#include "var.h"
#include "hccl_all_reduce_op.h"
#include "ops/op_register.h"
#include "utils/str_utils.h"
#include "hccl_wrapper.h"

namespace jittor {

#ifndef JIT

static auto hccl_all_reduce =
    get_op_info("hccl_all_reduce").get_constructor<VarPtr, Var*, string>();

HcclAllReduceOp::HcclAllReduceOp(Var* x, string reduce_op) : x(x), reduce_op(reduce_op) {
    flags.set(NodeFlags::_cpu, 0);
    flags.set(NodeFlags::_cuda, 1);
    y = create_output(nullptr, x->dtype());
}

void HcclAllReduceOp::infer_shape() {
    y->set_shape(x->shape);
}

VarPtr HcclAllReduceOp::grad(Var* out, Var* dout, Var* v, int v_index) {
    return hccl_all_reduce(dout, reduce_op);
}

void HcclAllReduceOp::jit_prepare(JK& jk) {
    jk << "«Tx:" << x->dtype();
    jk << "«Op:" << reduce_op;
}

#else // JIT

void HcclAllReduceOp::jit_run() {
    //LOGir << "HcclAllReduceOp::jit_run";
    @define(T_HCCL,
        @if(@strcmp(@Tx,float)==0 || @strcmp(@Tx,float32)==0, HcclDataType::HCCL_DATA_TYPE_FP32)
        @if(@strcmp(@Tx,int)==0 || @strcmp(@Tx,int32)==0, HcclDataType::HCCL_DATA_TYPE_INT32)
        @if(@strcmp(@Tx,float64)==0, HcclDataType::HCCL_DATA_TYPE_FP64)
        @if(@strcmp(@Tx,int64)==0, HcclDataType::HCCL_DATA_TYPE_INT64)
        @if(@strcmp(@Tx,uint8)==0, HcclDataType::HCCL_DATA_TYPE_UINT8)
        @if(@strcmp(@Tx,float16)==0, HcclDataType::HCCL_DATA_TYPE_FP16)
    )
    @define(REDUCE_OP,
        @if(@strcmp(@Op,sum)==0, HcclReduceOp::HCCL_REDUCE_SUM)
        @if(@strcmp(@Op,prod)==0, HcclReduceOp::HCCL_REDUCE_PROD)
        @if(@strcmp(@Op,max)==0, HcclReduceOp::HCCL_REDUCE_MAX)
        @if(@strcmp(@Op,min)==0, HcclReduceOp::HCCL_REDUCE_MIN)
    )
    auto* __restrict__ xp = x->ptr<Tx>();
    auto* __restrict__ yp = y->ptr<Tx>();
    ACLCHECK(aclrtSynchronizeDevice());
    ACLCHECK(aclrtSynchronizeStream(aclstream));
    HCCLCHECK(HcclAllReduce(xp, yp, (uint64_t)x->num, @T_HCCL, @REDUCE_OP, comm, aclstream));
    ACLCHECK(aclrtSynchronizeDevice());
    ACLCHECK(aclrtSynchronizeStream(aclstream));
}

#endif // JIT

} // jittor

View File

@@ -0,0 +1,18 @@
#pragma once
#include "op.h"

namespace jittor {

struct HcclAllReduceOp : Op {
    Var* x, * y;
    string reduce_op;
    HcclAllReduceOp(Var* x, string reduce_op="sum");
    void infer_shape() override;
    const char* name() const override { return "hccl_all_reduce"; }
    VarPtr grad(Var* out, Var* dout, Var* v, int v_index) override;
    DECLARE_jit_run;
};

} // jittor

View File

@@ -0,0 +1,63 @@
#include "var.h"
#include "hccl_broadcast_op.h"
#include "ops/op_register.h"
#include "utils/str_utils.h"
#include "hccl_wrapper.h"
#include <cassert>

namespace jittor {

#ifndef JIT

static auto hccl_broadcast =
    get_op_info("hccl_broadcast").get_constructor<VarPtr, Var*, int>();

HcclBroadcastOp::HcclBroadcastOp(Var* x, int root) : x(x), root(root) {
    flags.set(NodeFlags::_cpu, 0);
    flags.set(NodeFlags::_cuda, 1);
    y = create_output(nullptr, x->dtype());
}

void HcclBroadcastOp::infer_shape() {
    y->set_shape(x->shape);
}

VarPtr HcclBroadcastOp::grad(Var* out, Var* dout, Var* v, int v_index) {
    return hccl_broadcast(dout, root);
}

void HcclBroadcastOp::jit_prepare(JK& jk) {
    jk << "«Tx:" << x->dtype();
    jk << "«Root:" << root;
}

#else // JIT

void HcclBroadcastOp::jit_run() {
    //LOGir << "HcclBroadcastOp::jit_run";
    @define(T_HCCL,
        @if(@strcmp(@Tx,float)==0 || @strcmp(@Tx,float32)==0, HcclDataType::HCCL_DATA_TYPE_FP32)
        @if(@strcmp(@Tx,int)==0 || @strcmp(@Tx,int32)==0, HcclDataType::HCCL_DATA_TYPE_INT32)
        @if(@strcmp(@Tx,float64)==0, HcclDataType::HCCL_DATA_TYPE_FP64)
        @if(@strcmp(@Tx,int64)==0, HcclDataType::HCCL_DATA_TYPE_INT64)
        @if(@strcmp(@Tx,uint8)==0, HcclDataType::HCCL_DATA_TYPE_UINT8)
        @if(@strcmp(@Tx,float16)==0, HcclDataType::HCCL_DATA_TYPE_FP16)
    )
    auto* __restrict__ xp = x->ptr<Tx>();
    auto* __restrict__ yp = y->ptr<Tx>();
    //LOGir << "HcclBroadcastOp::jit_run " << @Root << " " << hccl_device_id << " " << xp << " " << yp;
    //ACLCHECK(aclrtSynchronizeStream(aclstream));
    ACLCHECK(aclrtSynchronizeDevice());
    ACLCHECK(aclrtSynchronizeStream(aclstream));
    HCCLCHECK(HcclBroadcast(@Root == hccl_device_id ? xp : yp, (uint64_t)x->num, @T_HCCL, @Root, comm, aclstream));
    if (@Root == hccl_device_id) {
        ACLCHECK(aclrtMemcpy(yp, x->num * sizeof(Tx), xp, x->num * sizeof(Tx), ACL_MEMCPY_DEVICE_TO_DEVICE));
        ACLCHECK(aclrtSynchronizeDevice());
    }
    ACLCHECK(aclrtSynchronizeDevice());
    ACLCHECK(aclrtSynchronizeStream(aclstream));
}

#endif // JIT

} // jittor

View File

@@ -0,0 +1,18 @@
#pragma once
#include "op.h"

namespace jittor {

struct HcclBroadcastOp : Op {
    Var* x, * y;
    int root;
    HcclBroadcastOp(Var* x, int root=0);
    void infer_shape() override;
    const char* name() const override { return "hccl_broadcast"; }
    VarPtr grad(Var* out, Var* dout, Var* v, int v_index) override;
    DECLARE_jit_run;
};

} // jittor

View File

@@ -0,0 +1,63 @@
#include "var.h"
#include "hccl_reduce_op.h"
#include "ops/op_register.h"
#include "utils/str_utils.h"
#include "hccl_wrapper.h"

namespace jittor {

#ifndef JIT

HcclReduceOp::HcclReduceOp(Var* x, string reduce_op, int root) : x(x), reduce_op(reduce_op), root(root) {
    flags.set(NodeFlags::_cpu, 0);
    flags.set(NodeFlags::_cuda, 1);
    y = create_output(nullptr, x->dtype());
}

void HcclReduceOp::infer_shape() {
    y->set_shape(x->shape);
}

VarPtr HcclReduceOp::grad(Var* out, Var* dout, Var* v, int v_index) {
    static auto hccl_broadcast =
        get_op_info("hccl_broadcast").get_constructor<VarPtr, Var*, int>();
    return hccl_broadcast(dout, root);
}

void HcclReduceOp::jit_prepare(JK& jk) {
    jk << "«Tx:" << x->dtype();
    jk << "«Op:" << reduce_op;
    jk << "«Root:" << root;
}

#else // JIT

void HcclReduceOp::jit_run() {
    LOGir << "HcclReduceOp::jit_run";
    @define(T_HCCL,
        @if(@strcmp(@Tx,float)==0 || @strcmp(@Tx,float32)==0, HcclDataType::HCCL_DATA_TYPE_FP32)
        @if(@strcmp(@Tx,int)==0 || @strcmp(@Tx,int32)==0, HcclDataType::HCCL_DATA_TYPE_INT32)
        @if(@strcmp(@Tx,float64)==0, HcclDataType::HCCL_DATA_TYPE_FP64)
        @if(@strcmp(@Tx,int64)==0, HcclDataType::HCCL_DATA_TYPE_INT64)
        @if(@strcmp(@Tx,uint8)==0, HcclDataType::HCCL_DATA_TYPE_UINT8)
        @if(@strcmp(@Tx,float16)==0, HcclDataType::HCCL_DATA_TYPE_FP16)
    )
    @define(REDUCE_OP,
        @if(@strcmp(@Op,sum)==0, HcclReduceOp::HCCL_REDUCE_SUM)
        @if(@strcmp(@Op,prod)==0, HcclReduceOp::HCCL_REDUCE_PROD)
        @if(@strcmp(@Op,max)==0, HcclReduceOp::HCCL_REDUCE_MAX)
        @if(@strcmp(@Op,min)==0, HcclReduceOp::HCCL_REDUCE_MIN)
    )
    auto* __restrict__ xp = x->ptr<Tx>();
    auto* __restrict__ yp = y->ptr<Tx>();
    ACLCHECK(aclrtSynchronizeDevice());
    ACLCHECK(aclrtSynchronizeStream(aclstream));
    HCCLCHECK(HcclReduce(xp, yp, (uint64_t)x->num, @T_HCCL, @REDUCE_OP, @Root, comm, aclstream));
    ACLCHECK(aclrtSynchronizeDevice());
    ACLCHECK(aclrtSynchronizeStream(aclstream));
}

#endif // JIT

} // jittor

View File

@@ -0,0 +1,19 @@
#pragma once
#include "op.h"

namespace jittor {

struct HcclReduceOp : Op {
    Var* x, * y;
    string reduce_op;
    int root;
    HcclReduceOp(Var* x, string reduce_op="sum", int root=0);
    void infer_shape() override;
    const char* name() const override { return "hccl_reduce"; }
    VarPtr grad(Var* out, Var* dout, Var* v, int v_index) override;
    DECLARE_jit_run;
};

} // jittor

View File

@@ -0,0 +1,60 @@
// ***************************************************************
// Copyright (c) 2025 Jittor.
// All Rights Reserved.
// Maintainers:
//     Jiapeng Zhang <zjp24@mails.tsinghua.edu.cn>.
//
// This file is subject to the terms and conditions defined in
// file 'LICENSE.txt', which is part of this source code package.
// ***************************************************************
#include "hccl_wrapper.h"
#include "event_queue.h"
#include "acl_jittor.h"
#include <acl/acl.h>

namespace jittor {

HcclRootInfo root_info;
HcclComm comm;
uint32_t hccl_device_id = 0;

struct hccl_initer {
    uint32_t device_count = 0;

    hccl_initer() {
        ACLCHECK(aclrtGetDeviceCount(&device_count));
        if (!device_count) return;
        if (!inside_mpi) return;
        hccl_device_id = mpi_local_rank;
        if (mpi_local_rank >= device_count) {
            LOGw << "mpi_local_rank(">>mpi_local_rank>>") is larger than device_count("
                >>device_count>>")";
            hccl_device_id = hccl_device_id % device_count;
        }
        LOGv << "HCCL init in device" << hccl_device_id << "local_rank" << mpi_local_rank;
        //LOGir << aclstream;
        //event_queue.run_sync([]() {
        ACLCHECK(aclrtSetDevice(hccl_device_id));
        //});
        use_device_mpi = true;
        LOGir << "HCCL init in device" << hccl_device_id << "local_rank" << mpi_local_rank;
        if (mpi_world_rank == 0)
            HCCLCHECK(HcclGetRootInfo(&root_info));
        MPI_CHECK(MPI_Bcast(&root_info, HCCL_ROOT_INFO_BYTES, MPI_CHAR, 0, MPI_COMM_WORLD));
        //MPI_Barrier(MPI_COMM_WORLD);
        LOGir << "Count:" << device_count << "HCCL init in device" << hccl_device_id;
        HCCLCHECK(HcclCommInitRootInfo(device_count, &root_info, hccl_device_id, &comm));
        ACLCHECK(aclrtCreateStream(&aclstream));
        LOGi << "HCCL init success in device" << hccl_device_id;
    }

    ~hccl_initer() {
        if (!device_count) return;
        if (!inside_mpi) return;
        if (!use_device_mpi) return;
        HCCLCHECK(HcclCommDestroy(comm));
    }
};

static hccl_initer hccl_initer;

}
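
The initializer above assumes the usual one-process-per-NPU launch: mpi_local_rank picks the device, rank 0 creates the HCCL root info, MPI broadcasts it, and every rank then joins the communicator. Note that HcclCommInitRootInfo is called with device_count as the communicator size, which matches a single-node job where the MPI world size equals the number of local NPUs. A hedged sketch of the intended launch (script name, flags, and the jt.rank/jt.world_size names are illustrative assumptions):

    # launch as: mpirun -np 4 python train.py   (one process per NPU on a 4-NPU host)
    import jittor as jt

    # each process ends up pinned to NPU hccl_device_id == mpi_local_rank and shares
    # the HcclComm built from the root info broadcast over MPI at import time
    print("rank", jt.rank, "of", jt.world_size)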

View File

@@ -37,14 +37,23 @@ MpiAllReduceOp::MpiAllReduceOp(Var* x, NanoString op) : x(x), op(op) {
    }
    ASSERT(op == ns_add) << "Not supported MPI op" << op;
    #ifdef HAS_CUDA
    if (use_device_mpi && use_cuda) {
        static auto nccl_all_reduce = has_op("nccl_all_reduce")
            ? get_op_info("nccl_all_reduce").get_constructor<VarPtr, Var*>()
            : nullptr;
+       static auto hccl_all_reduce = has_op("hccl_all_reduce")
+           ? get_op_info("hccl_all_reduce").get_constructor<VarPtr, Var*, string>()
+           : nullptr;
        if (nccl_all_reduce) {
            auto var = nccl_all_reduce(x);
            forward(var);
            return;
+       } else if (hccl_all_reduce) {
+           auto var = hccl_all_reduce(x, "sum");
+           //exe.run_sync({var}, true);
+           forward(var);
+           return;
        }
    }
    #endif

View File

@@ -26,10 +26,18 @@ MpiBroadcastOp::MpiBroadcastOp(Var* x, int root) : x(x), root(root) {
        static auto nccl_broadcast = has_op("nccl_broadcast")
            ? get_op_info("nccl_broadcast").get_constructor<VarPtr, Var*, int>()
            : nullptr;
+       static auto hccl_broadcast = has_op("hccl_broadcast")
+           ? get_op_info("hccl_broadcast").get_constructor<VarPtr, Var*, int>()
+           : nullptr;
        if (nccl_broadcast) {
            auto var = nccl_broadcast(x, root);
            forward(var);
            return;
+       } else if (hccl_broadcast) {
+           auto var = hccl_broadcast(x, root);
+           //exe.run_sync({var}, true);
+           forward(var);
+           return;
        }
    }
    #endif

View File

@@ -41,10 +41,18 @@ MpiReduceOp::MpiReduceOp(Var* x, NanoString op, int root) : x(x), op(op), root(r
        static auto nccl_reduce = has_op("nccl_reduce")
            ? get_op_info("nccl_reduce").get_constructor<VarPtr, Var*, int>()
            : nullptr;
+       static auto hccl_reduce = has_op("hccl_reduce")
+           ? get_op_info("hccl_reduce").get_constructor<VarPtr, Var*, string, int>()
+           : nullptr;
        if (nccl_reduce) {
            auto var = nccl_reduce(x, root);
            forward(var);
            return;
+       } else if (hccl_reduce) {
+           auto var = hccl_reduce(x, "sum", root);
+           //exe.run_sync({var}, true);
+           forward(var);
+           return;
        }
    }
    #endif
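
With these three dispatch points in place, code that goes through the generic MPI ops should pick up the HCCL implementations transparently on Ascend builds. A hedged sketch (the mpi_ops module is the one setup_mpi() exposes; the Python-side argument spelling and defaults are assumptions):

    import jittor as jt
    from jittor import compile_extern

    x = jt.array([1.0, 2.0, 3.0])
    # on an ACL + MPI build these now forward to hccl_all_reduce / hccl_broadcast / hccl_reduce
    y = compile_extern.mpi_ops.mpi_all_reduce(x)        # sum across ranks
    z = compile_extern.mpi_ops.mpi_broadcast(x, 0)      # root = 0
    w = compile_extern.mpi_ops.mpi_reduce(x, "add", 0)  # reduce to root 0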

View File

@@ -7,7 +7,13 @@
// This file is subject to the terms and conditions defined in
// file 'LICENSE.txt', which is part of this source code package.
// ***************************************************************
+#include <cmath>
+#include <limits>
+#include <cstring>
+#if defined(__x86_64__) || defined(_M_X64)
#include <immintrin.h>
+#endif
#include <unistd.h>
#include <stdint.h>
#include <stdio.h>
@@ -31,8 +37,80 @@ namespace jittor {
MPI_Datatype MPI_HALF;
MPI_Op MPI_HALF_ADD;

+#if !defined(__x86_64__) && !defined(_M_X64)
+// Helper functions for FP16 <-> FP32 conversion on ARM
+static inline float fp16_to_fp32_value(uint16_t h) {
+    unsigned sign = ((h >> 15) & 1);
+    unsigned exponent = ((h >> 10) & 0x1f);
+    unsigned mantissa = ((h & 0x3ff) << 13);
+
+    if (exponent == 0) {
+        if (mantissa == 0) {
+            return sign ? -0.0f : 0.0f;
+        } else {
+            // subnormal number
+            while (!(mantissa & 0x400000)) {
+                mantissa <<= 1;
+                exponent -= 1;
+            }
+            exponent += 1;
+            mantissa &= ~0x400000;
+        }
+    } else if (exponent == 31) {
+        if (mantissa == 0) {
+            return sign ? -std::numeric_limits<float>::infinity() : std::numeric_limits<float>::infinity();
+        } else {
+            return std::numeric_limits<float>::quiet_NaN();
+        }
+    }
+
+    exponent += (127 - 15);
+    mantissa <<= 10;
+
+    unsigned int i = ((sign << 31) | (exponent << 23) | mantissa);
+    float f;
+    std::memcpy(&f, &i, sizeof(float));
+    return f;
+}
+
+static inline uint16_t fp32_to_fp16_value(float f) {
+    unsigned int i;
+    std::memcpy(&i, &f, sizeof(float));
+
+    unsigned sign = ((i >> 31) & 0x1);
+    unsigned exponent = ((i >> 23) & 0xff);
+    unsigned mantissa = (i & 0x7fffff);
+
+    unsigned short h = 0;
+
+    if (exponent == 0) {
+        // zero or subnormal
+        h = (sign << 15);
+    } else if (exponent == 0xff) {
+        // infinity or NaN
+        h = (sign << 15) | 0x7c00;
+        if (mantissa) h |= 0x200;
+    } else {
+        // normal number
+        int new_exp = exponent - 127 + 15;
+        if (new_exp < 0) {
+            // underflow to zero
+            h = (sign << 15);
+        } else if (new_exp > 30) {
+            // overflow to infinity
+            h = (sign << 15) | 0x7c00;
+        } else {
+            // normal conversion
+            h = (sign << 15) | (new_exp << 10) | (mantissa >> 13);
+        }
+    }
+
+    return h;
+}
+#endif

void HalfAdd(void* invec, void* inoutvec, int* len, MPI_Datatype* type) {
-    // return;
+#if defined(__x86_64__) || defined(_M_X64)
    short* in = (short*)invec;
    short* inout = (short*)inoutvec;
@@ -62,9 +140,27 @@ void HalfAdd(void* invec, void* inoutvec, int* len, MPI_Datatype* type) {
        // Convert the single-precision floats back to half precision and store the result
        *(inout + i) = _mm_cvtps_ph(out, _MM_FROUND_TO_NEAREST_INT | _MM_FROUND_NO_EXC)[0];
    }
+#else
+    // ARM implementation using plain half-precision arithmetic
+    uint16_t* in = (uint16_t*)invec;
+    uint16_t* inout = (uint16_t*)inoutvec;
+    int total = *len;
+
+    // Simple element-wise addition
+    for (int i = 0; i < total; i++) {
+        // Convert FP16 to FP32
+        float in_val = fp16_to_fp32_value(in[i]);
+        float inout_val = fp16_to_fp32_value(inout[i]);
+        // Perform the addition
+        float result = in_val + inout_val;
+        // Convert the result back to FP16
+        inout[i] = fp32_to_fp16_value(result);
+    }
+#endif
}

int mpi_world_size = 1;
int mpi_world_rank = 0;
int mpi_local_size = 1;
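
Since the ARM path reimplements the FP16 conversions by hand, a quick cross-check against numpy's IEEE float16 encoding is a cheap way to validate them (a standalone sketch, not part of the commit; fp32_to_fp16_bits is the reference that fp32_to_fp16_value should match):

    import numpy as np

    def fp32_to_fp16_bits(f):
        # reference encoding via numpy's float16
        return int(np.frombuffer(np.float16(f).tobytes(), dtype=np.uint16)[0])

    # e.g. 1.0 -> 0x3c00, -2.0 -> 0xc000, 65504.0 -> 0x7bff
    for v in [0.0, 1.0, -2.0, 65504.0]:
        print(v, hex(fp32_to_fp16_bits(v)))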

View File

@@ -31,15 +31,15 @@ cudaEvent_t event;
struct Init {
    Init() {
        if (!get_device_count()) return;
-        checkCudaErrors(cudaStreamCreateWithFlags(&stream, cudaStreamNonBlocking));
-        checkCudaErrors(cudaEventCreate(&event, cudaEventDisableTiming));
+        //checkCudaErrors(cudaStreamCreateWithFlags(&stream, cudaStreamNonBlocking));
+        //checkCudaErrors(cudaEventCreate(&event, cudaEventDisableTiming));
        stream = aclstream;
    }
    ~Init() {
        if (!get_device_count()) return;
-        peekCudaErrors(cudaDeviceSynchronize());
-        peekCudaErrors(cudaStreamDestroy(stream));
-        peekCudaErrors(cudaEventDestroy(event));
+        //peekCudaErrors(cudaDeviceSynchronize());
+        //peekCudaErrors(cudaStreamDestroy(stream));
+        //peekCudaErrors(cudaEventDestroy(event));
    }
} init;