Update HCCL to support multi npus

This commit is contained in:
zjp_shadow 2025-02-10 16:18:34 +08:00
parent 8289c61138
commit a5c37bd58e
21 changed files with 1789 additions and 1201 deletions

View File

@ -600,6 +600,26 @@ def setup_nccl():
nccl_ops = nccl.ops
LOG.vv("Get nccl_ops: "+str(dir(nccl_ops)))
def setup_hccl():
    """Compile and load Jittor's HCCL (Huawei Collective Communication
    Library) extension ops for multi-NPU communication.

    Mirrors setup_nccl(): collects the HCCL op sources, preloads
    libhccl.so so its symbols are resolvable, compiles the sources as a
    custom-op module, and publishes the resulting ops through the global
    ``hccl_ops``.
    """
    global hccl_ops
    hccl_src_dir = os.path.join(jittor_path, "extern", "acl", "hccl")
    hccl_src_files = []
    for r, _, f in os.walk(hccl_src_dir):
        for fname in f:
            hccl_src_files.append(os.path.join(r, fname))
    # Fail with a clear message when the Ascend toolkit env var is absent;
    # os.path.join(None, ...) would otherwise raise a confusing TypeError.
    ascend_home = os.environ.get("ASCEND_TOOLKIT_HOME")
    assert ascend_home, "environment variable ASCEND_TOOLKIT_HOME is not set"
    hccl_include_path = os.path.join(ascend_home, "aarch64-linux/include/hccl")
    hccl_lib_name = os.path.join(ascend_home, "aarch64-linux/lib64/libhccl.so")
    # Preload the HCCL runtime so the compiled module links against it.
    ctypes.CDLL(hccl_lib_name, dlopen_flags)
    hccl = compile_custom_ops(hccl_src_files,
                              extra_flags=f" -I\"{hccl_include_path}\" {mpi_compile_flags} ",
                              return_module=True, dlopen_flags=os.RTLD_GLOBAL | os.RTLD_NOW,
                              gen_name_="jittor_hccl_core")
    hccl_ops = hccl.ops
    LOG.vv("Get hccl_ops: "+str(dir(hccl_ops)))
def manual_link(flags):
lib_dirs = []
libs = []
@ -697,10 +717,12 @@ cudnn = cublas = curand = cufft = None
setup_mpi()
rank = mpi.world_rank() if in_mpi else 0
world_size = mpi.world_size() if in_mpi else 1
setup_nccl()
setup_cutt()
setup_cutlass()
if has_acl:
setup_hccl()
elif has_cuda:
setup_nccl()
setup_cutt()
setup_cutlass()
# try:
setup_mkl()

File diff suppressed because it is too large Load Diff

View File

@ -38,7 +38,6 @@ namespace jittor
static void *acl_jittor_process_callback(void *)
{
acl_jittor_thread_running = 1;
int deviceId = 0;
while (acl_jittor_thread_running)
{

View File

@ -1,496 +1,502 @@
// ***************************************************************
// Copyright (c) 2023 Jittor. All Rights Reserved.
// Maintainers: Dun Liang <randonlang@gmail.com>.
// This file is subject to the terms and conditions defined in
// file 'LICENSE.txt', which is part of this source code package.
// ***************************************************************
#include <acl/acl.h>
#include <acl/acl_op_compiler.h>
#include <Python.h>
#include <pystate.h>
#include <algorithm>
#include <queue>
#include <set>
#include "common.h"
#include "op.h"
#include "acl_jittor.h"
#include "ops/random_op.h"
#include "ops/reduce_op.h"
#include "ops/binary_op.h"
#include "ops/broadcast_to_op.h"
#include "ops/transpose_op.h"
#include "ops/array_op.h"
#include "ops/code_op.h"
#include "fused_op.h"
#include "ops/unary_op.h"
#include "ops/ternary_op.h"
#include "executor.h"
#include "misc/cuda_flags.h"
#include "mem/allocator.h"
#include "op_compiler.h"
#include "ops/op_register.h"
#include "opt/tuner_manager.h"
#include "utils/str_utils.h"
#include "aclnn/aclnn.h"
#include "aclops/aclops.h"
namespace jittor
{
void free_var_mem(Var *v);
unordered_map<uint32, string> opname_map = {
// unary op
{ns_cast, "Cast"},
{ns_negative, "Neg"},
{ns_abs, "Abs"},
{ns_exp, "Exp"},
{ns_log, "Log"},
{ns_sqrt, "Sqrt"},
{ns_ceil, "Ceil"},
{ns_floor, "Floor"},
{ns_round, "Round"},
// m(round_int)
// m(floor_int)
// m(ceil_int)
{ns_sin, "Sin"},
{ns_cos, "Cos"},
{ns_tan, "Tan"},
{ns_asin, "Asin"},
{ns_acos, "Acos"},
{ns_atan, "Atan"},
{ns_sinh, "Sinh"},
{ns_cosh, "Cosh"},
{ns_tanh, "Tanh"},
{ns_asinh, "Asinh"},
{ns_acosh, "Acosh"},
{ns_atanh, "Atanh"},
{ns_sigmoid, "Sigmoid"},
{ns_erf, "Erf"},
{ns_erfinv, "Erfinv"},
{ns_logical_not, "LogicalNot"},
{ns_bitwise_not, "BitwiseNot"},
// binary op
{ns_pow, "Pow"},
{ns_maximum, "Maximum"},
{ns_minimum, "Minimum"},
{ns_add, "Add"},
{ns_subtract, "Sub"},
{ns_multiply, "Mul"},
{ns_divide, "RealDiv"},
{ns_floor_divide, "FloorDiv"},
{ns_mod, "Mod"},
{ns_less, "Less"},
{ns_less_equal, "LessEqual"},
{ns_greater, "Greater"},
{ns_greater_equal, "GreaterEqual"},
{ns_equal, "Equal"},
{ns_not_equal, "NotEqual"},
{ns_left_shift, "LeftShift"},
{ns_right_shift, "RightShift"},
{ns_logical_and, "LogicalAnd"},
{ns_logical_or, "LogicalOr"},
{ns_logical_xor, "LogicalXor"},
{ns_bitwise_and, "BitwiseAnd"},
{ns_bitwise_or, "BitwiseOr"},
{ns_bitwise_xor, "BitwiseXor"},
};
void fallback_cpu(Op *op)
{
LOGy << "!!! fallback_cpu " << op;
use_cuda = 0;
for (auto v : op->inputs())
{
if (v->mem_ptr && v->allocator->is_cuda())
{
migrate_to_cpu(v, exe.allocator);
}
}
for (auto v : op->outputs())
{
if (v->mem_ptr && v->allocator->is_cuda())
{
migrate_to_cpu(v, exe.allocator);
}
}
op->flags.set(NodeFlags::_cpu);
op->flags.set(NodeFlags::_cuda, 0);
if (op->name() == string("fused"))
{
auto fop = (FusedOp *)op;
for (auto op : fop->ops)
{
op->flags.set(NodeFlags::_cpu);
op->flags.set(NodeFlags::_cuda, 0);
}
}
op->do_run();
use_cuda = 1;
}
/*
check compile
if compiled: exec
else: compile
check is fused
check is relay
else
compile func = try exec
if failed: fallback_cpu
else
try compile
if failed: fallback_cpu
*/
extern jit_op_entry_t (*do_compile_hook)(Op *);
jit_op_entry_t do_compile_inner(Op *op);
void try_exec_and_fallback_cpu(Op *op)
{
auto fop = (FusedOp *)op;
std::set<Var *> new_alloced;
map<Op *, int> op_indeg;
map<Var *, int> var_outdeg;
std::queue<Op *> queue;
for (Op *op : fop->ops)
op_indeg[op] = 0;
map<Op *, vector<Op *>> out_map;
map<Var *, vector<Op *>> from;
int len = 0;
for (Op *v : fop->ops)
{
for (auto in : v->inputs())
from[in].push_back(v);
++len;
}
for (Op *u : fop->ops)
{
for (auto out : u->outputs())
{
if (from.find(out) != from.end())
{
for (auto v : from[out])
{
++op_indeg[v];
++var_outdeg[out];
out_map[u].push_back(v);
}
}
}
}
for (Op *op : fop->ops)
{
if (op_indeg[op] == 0)
queue.push(op);
}
int total = 0;
int fallback = 0;
try
{
while (!queue.empty())
{
total++;
for (auto in : op->inputs())
{
ASSERT(in->mem_ptr);
}
auto op = queue.front();
queue.pop();
for (auto out : op->outputs())
{
if (out->mem_ptr)
continue;
out->alloc(exe.allocator);
new_alloced.insert(out);
}
for (auto out : out_map[op])
{
--op_indeg[out];
if (op_indeg[out] == 0)
queue.push(out);
}
if (op->name() == string("unary"))
{
auto uop = (UnaryOp *)op;
UnaryOpRunner op;
op.add(uop->x, true);
op.add(uop->y, false);
auto iter = opname_map.find(uop->ns);
ASSERT(iter != opname_map.end()) << "op " << uop->ns << " not found";
op.name = iter->second;
op.jt_name = uop->name();
op.run();
}
else if (op->name() == string("binary"))
{
auto bop = (BinaryOp *)op;
BinaryOpRunner op;
op.add(bop->x, true);
op.add(bop->y, true);
op.add(bop->z, false);
auto iter = opname_map.find(bop->ns);
ASSERT(iter != opname_map.end()) << "op " << bop->ns << " not found";
op.name = iter->second;
op.jt_name = bop->name();
if (bop->x->dtype() == ns_bool and bop->y->dtype() == ns_bool)
{
// BitwiseOr, BitwiseAnd, BitwiseXor -> LogicalOr, LogicalAnd, LogicalXor
if (bop->ns == ns_bitwise_or)
{
op.name = "LogicalOr";
}
else if (bop->ns == ns_bitwise_and)
{
op.name = "LogicalAnd";
}
else if (bop->ns == ns_bitwise_xor)
{
op.name = "LogicalXor";
}
}
op.run();
}
else if (op->name() == string("ternary"))
{
auto top = (TernaryOp *)op;
TernaryOpRunner op;
op.add(top->cond, true);
op.add(top->x, true);
op.add(top->y, true);
op.add(top->z, false);
op.run();
}
else if (op->name() == string("array"))
{
auto aop = (ArrayOp *)op;
aclrtMemcpy(aop->output->mem_ptr, aop->output->size, aop->ptr<void>(), aop->output->size, ACL_MEMCPY_HOST_TO_DEVICE);
}
else if (op->name() == string("reduce"))
{
auto rop = (ReduceOp *)op;
ReduceOpRunner op;
if (rop->ns == ns_add)
op.op_idx = 9;
else if (rop->ns == ns_multiply)
// TODO unsupported the multi dim
op.op_idx = 999;
else if (rop->ns == ns_maximum)
op.op_idx = 11;
else if (rop->ns == ns_minimum)
op.op_idx = 12;
else if (rop->ns == ns_mean)
op.op_idx = 10;
else
LOGf << "op " << rop->ns << " not supported";
op.add(rop->x, true);
ReduceAttr *attr = new ReduceAttr();
for (int i = 0; i < rop->x->shape.size(); i++)
if (rop->reduce_mask & (1 << i))
attr->axes.push_back(i);
if (rop->x->shape.size() == rop->y->shape.size())
attr->keepdims = true;
else
attr->keepdims = false;
op.op_attr.reset(attr);
op.add(rop->y, false);
op.run();
aclrtSynchronizeStream(aclstream);
}
else if (op->name() == string("broadcast_to"))
{
auto bop = (BroadcastToOp *)op;
ExpandOpRunner op;
op.jt_name = "expand";
NanoVector xshape, xshape_bk = bop->x->shape;
NanoVector zshape = bop->z->shape;
for (int i = 0; i < zshape.size(); i++)
{
if (bop->bcast_mask & (1 << i))
{
xshape.push_back(1);
}
else
{
xshape.push_back(zshape[i]);
}
}
bop->x->shape = xshape;
op.add(bop->x, true);
// bop->x->shape = xshape_bk;
op.add(bop->z, false);
op.run();
bop->x->shape = xshape_bk;
aclrtSynchronizeStream(aclstream);
}
else if (op->name() == string("fuse_transpose"))
{
// replace fuse_transpose with transpose
auto top = (TransposeOp *)op;
TransposeOpRunner op;
op.add(top->x, true);
op.add(top->y, false);
op.jt_name = "transpose";
ReduceAttr *attr = new ReduceAttr();
for (int i = 0; i < top->axes.size(); i++)
attr->axes.push_back(top->axes[i]);
op.op_attr.reset(attr);
op.run();
}
else
{
LOGf << "op " << op->name() << " not supported";
}
for (auto in : op->inputs())
{
--var_outdeg[in];
if (var_outdeg[in] == 0)
{
if (new_alloced.find(in) != new_alloced.end())
{
free_var_mem(in);
new_alloced.erase(in);
}
}
}
}
}
catch (std::exception &e)
{
fallback = 1;
LOGir << "fallback cpu" << e.what();
}
for (auto v : new_alloced)
{
free_var_mem(v);
}
if (fallback)
{
fallback_cpu(op);
}
}
extern int current_seed;
extern int64 current_offset;
static unordered_map<string, std::function<void(Op *)>> acl_ops = {
{"curand_random", [&current_seed, &current_offset](Op *op)
{
auto _op = (RandomOp *)op;
RandomOpRunner runner(_op->type == ns_uniform ? "RandomUniform" : "RandomNormal");
auto out = op->output(0);
RandomAttr *attr = new RandomAttr();
attr->seed = current_seed;
attr->offset = current_offset;
runner.jt_name = "random";
runner.op_attr.reset(attr);
runner.add(out, false);
runner.run();
current_offset += out->numel();
}},
};
static void exec_mapped_acl_ops(Op *op)
{
auto iter = acl_ops.find(op->name());
if (iter != acl_ops.end())
{
LOGv << "exec acl op " << op->name() << op;
iter->second(op);
}
else
{
LOGf << "op " << op->name() << " not supported";
}
}
static jit_op_entry_t acl_do_compile(Op *op)
{
LOGv << "compile" << op;
OpCompiler oc(op);
string *src = &oc.src;
for (auto op_type : op_types)
op_type->post_pass(&oc);
string src_after_passes;
// if is fused op
if (oc.op)
{
TunerManager tm(&oc);
src_after_passes = tm.tune();
src = &src_after_passes;
}
op->compile_optimize(*src);
if (!op->flags.get(NodeFlags::_cuda))
{
LOGv << "compile cpu";
return oc.compile(op->get_jit_key(get_jk()), *src);
}
if (op->name() == string("fused"))
{
FusedOp *fop = (FusedOp *)op;
// if is a relayed op
if (fop->context->vrm.relay_groups.size())
{
LOGv << "relay fused op";
return oc.compile(op->get_jit_key(get_jk()), *src);
}
else
{
return &try_exec_and_fallback_cpu;
}
}
else if (op->name() == string("code"))
{
CodeOp *cop = (CodeOp *)op;
if (cop->cuda_src.find("acl") != string::npos)
{
LOGv << "compile acl op";
return oc.compile(op->get_jit_key(get_jk()), *src);
}
else
{
return &exec_mapped_acl_ops;
}
}
else
{
LOGv << "compile finish" << op;
return &exec_mapped_acl_ops;
}
return do_compile_inner(op);
}
// from op_register.cc
extern unordered_map<string, OpInfo> op_info_map;
void init_acl_ops()
{
do_compile_hook = acl_do_compile;
vector<string> to_erase;
for (auto &kv : op_info_map)
{
if (startswith(kv.first, "cu") && acl_ops.count(kv.first) == 0)
{
to_erase.push_back(kv.first);
}
}
for (auto &k : to_erase)
{
LOGv << "op not supported: " << k << ", erase it.";
op_info_map.erase(k);
}
}
} // jittor
// ***************************************************************
// Copyright (c) 2023 Jittor. All Rights Reserved.
// Maintainers: Dun Liang <randonlang@gmail.com>.
// This file is subject to the terms and conditions defined in
// file 'LICENSE.txt', which is part of this source code package.
// ***************************************************************
#include <acl/acl.h>
#include <acl/acl_op_compiler.h>
#include <Python.h>
#include <pystate.h>
#include <algorithm>
#include <queue>
#include <set>
#include "common.h"
#include "op.h"
#include "acl_jittor.h"
#include "ops/random_op.h"
#include "ops/reduce_op.h"
#include "ops/binary_op.h"
#include "ops/broadcast_to_op.h"
#include "ops/transpose_op.h"
#include "ops/array_op.h"
#include "ops/code_op.h"
#include "fused_op.h"
#include "ops/unary_op.h"
#include "ops/ternary_op.h"
#include "executor.h"
#include "misc/cuda_flags.h"
#include "mem/allocator.h"
#include "op_compiler.h"
#include "ops/op_register.h"
#include "opt/tuner_manager.h"
#include "utils/str_utils.h"
#include "aclnn/aclnn.h"
#include "aclops/aclops.h"
namespace jittor
{
void free_var_mem(Var *v);
// Maps Jittor elementwise op name-symbols (NanoString ids) to the CANN/ACL
// operator type strings used when constructing ACL op runners.
unordered_map<uint32, string> opname_map = {
// unary op
{ns_cast, "Cast"},
{ns_negative, "Neg"},
{ns_abs, "Abs"},
{ns_exp, "Exp"},
{ns_log, "Log"},
{ns_sqrt, "Sqrt"},
{ns_ceil, "Ceil"},
{ns_floor, "Floor"},
{ns_round, "Round"},
// m(round_int)
// m(floor_int)
// m(ceil_int)
{ns_sin, "Sin"},
{ns_cos, "Cos"},
{ns_tan, "Tan"},
{ns_asin, "Asin"},
{ns_acos, "Acos"},
{ns_atan, "Atan"},
{ns_sinh, "Sinh"},
{ns_cosh, "Cosh"},
{ns_tanh, "Tanh"},
{ns_asinh, "Asinh"},
{ns_acosh, "Acosh"},
{ns_atanh, "Atanh"},
{ns_sigmoid, "Sigmoid"},
{ns_erf, "Erf"},
{ns_erfinv, "Erfinv"},
{ns_logical_not, "LogicalNot"},
{ns_bitwise_not, "BitwiseNot"},
// binary op
{ns_pow, "Pow"},
{ns_maximum, "Maximum"},
{ns_minimum, "Minimum"},
{ns_add, "Add"},
{ns_subtract, "Sub"},
{ns_multiply, "Mul"},
{ns_divide, "RealDiv"},
{ns_floor_divide, "FloorDiv"},
{ns_mod, "Mod"},
{ns_less, "Less"},
{ns_less_equal, "LessEqual"},
{ns_greater, "Greater"},
{ns_greater_equal, "GreaterEqual"},
{ns_equal, "Equal"},
{ns_not_equal, "NotEqual"},
{ns_left_shift, "LeftShift"},
{ns_right_shift, "RightShift"},
{ns_logical_and, "LogicalAnd"},
{ns_logical_or, "LogicalOr"},
{ns_logical_xor, "LogicalXor"},
{ns_bitwise_and, "BitwiseAnd"},
{ns_bitwise_or, "BitwiseOr"},
{ns_bitwise_xor, "BitwiseXor"},
};
void fallback_cpu(Op *op)
{
    // Re-run `op` on the CPU: migrate any device-resident inputs/outputs
    // back to host memory, flip the op (and, for a fused op, all its
    // sub-ops) to CPU mode, execute it, then restore the use_cuda flag.
    LOGy << "!!! fallback_cpu " << op;
    use_cuda = 0;
    auto migrate = [](Var *v) {
        if (v->mem_ptr && v->allocator->is_cuda())
            migrate_to_cpu(v, exe.allocator);
    };
    for (auto v : op->inputs())
        migrate(v);
    for (auto v : op->outputs())
        migrate(v);
    op->flags.set(NodeFlags::_cpu);
    op->flags.set(NodeFlags::_cuda, 0);
    if (op->name() == string("fused"))
    {
        // Sub-ops of a fused op carry their own flags; switch them too.
        auto fop = (FusedOp *)op;
        for (auto sub : fop->ops)
        {
            sub->flags.set(NodeFlags::_cpu);
            sub->flags.set(NodeFlags::_cuda, 0);
        }
    }
    op->do_run();
    use_cuda = 1;
}
/*
check compile
if compiled: exec
else: compile
check is fused
check is relay
else
compile func = try exec
if failed: fallback_cpu
else
try compile
if failed: fallback_cpu
*/
extern jit_op_entry_t (*do_compile_hook)(Op *);
jit_op_entry_t do_compile_inner(Op *op);
void try_exec_and_fallback_cpu(Op *op)
{
    // Run each sub-op of a fused op through the ACL op runners in
    // topological order. If any runner throws, every var allocated here is
    // freed and the whole fused op is re-executed on CPU via fallback_cpu().
    aclrtSynchronizeStream(aclstream);
    auto fop = (FusedOp *)op;
    std::set<Var *> new_alloced;     // output vars allocated by this call
    map<Op *, int> op_indeg;         // remaining producer count per sub-op
    map<Var *, int> var_outdeg;      // remaining consumer count per var
    std::queue<Op *> queue;          // sub-ops ready to execute
    for (Op *op : fop->ops)
        op_indeg[op] = 0;
    map<Op *, vector<Op *>> out_map; // producer -> consumer sub-ops
    map<Var *, vector<Op *>> from;   // var -> sub-ops reading it
    int len = 0;
    for (Op *v : fop->ops)
    {
        for (auto in : v->inputs())
            from[in].push_back(v);
        ++len;
    }
    // Wire up intra-fused-op dependencies.
    for (Op *u : fop->ops)
    {
        for (auto out : u->outputs())
        {
            if (from.find(out) != from.end())
            {
                for (auto v : from[out])
                {
                    ++op_indeg[v];
                    ++var_outdeg[out];
                    out_map[u].push_back(v);
                }
            }
        }
    }
    for (Op *op : fop->ops)
    {
        if (op_indeg[op] == 0)
            queue.push(op);
    }
    int total = 0;
    int fallback = 0;
    try
    {
        while (!queue.empty())
        {
            total++;
            // BUGFIX: pop the current sub-op *before* asserting its inputs.
            // Previously the assertion loop preceded this declaration, so it
            // inspected the outer fused op's inputs on every iteration and
            // never validated the current sub-op's inputs.
            auto op = queue.front();
            queue.pop();
            for (auto in : op->inputs())
            {
                ASSERT(in->mem_ptr);
            }
            // Materialize missing outputs and remember them for cleanup.
            for (auto out : op->outputs())
            {
                if (out->mem_ptr)
                    continue;
                out->alloc(exe.allocator);
                new_alloced.insert(out);
            }
            // Schedule consumers whose producers have all run.
            for (auto out : out_map[op])
            {
                --op_indeg[out];
                if (op_indeg[out] == 0)
                    queue.push(out);
            }
            if (op->name() == string("unary"))
            {
                auto uop = (UnaryOp *)op;
                UnaryOpRunner runner;
                runner.add(uop->x, true);
                runner.add(uop->y, false);
                auto iter = opname_map.find(uop->ns);
                ASSERT(iter != opname_map.end()) << "op " << uop->ns << " not found";
                runner.name = iter->second;
                runner.jt_name = uop->name();
                runner.run();
            }
            else if (op->name() == string("binary"))
            {
                auto bop = (BinaryOp *)op;
                BinaryOpRunner runner;
                runner.add(bop->x, true);
                runner.add(bop->y, true);
                runner.add(bop->z, false);
                auto iter = opname_map.find(bop->ns);
                ASSERT(iter != opname_map.end()) << "op " << bop->ns << " not found";
                runner.name = iter->second;
                runner.jt_name = bop->name();
                if (bop->x->dtype() == ns_bool and bop->y->dtype() == ns_bool)
                {
                    // BitwiseOr, BitwiseAnd, BitwiseXor -> LogicalOr, LogicalAnd, LogicalXor
                    if (bop->ns == ns_bitwise_or)
                    {
                        runner.name = "LogicalOr";
                    }
                    else if (bop->ns == ns_bitwise_and)
                    {
                        runner.name = "LogicalAnd";
                    }
                    else if (bop->ns == ns_bitwise_xor)
                    {
                        runner.name = "LogicalXor";
                    }
                }
                runner.run();
            }
            else if (op->name() == string("ternary"))
            {
                auto top = (TernaryOp *)op;
                TernaryOpRunner runner;
                runner.add(top->cond, true);
                runner.add(top->x, true);
                runner.add(top->y, true);
                runner.add(top->z, false);
                runner.run();
            }
            else if (op->name() == string("array"))
            {
                // Array data lives on the host; copy it into the device output.
                auto aop = (ArrayOp *)op;
                aclrtMemcpy(aop->output->mem_ptr, aop->output->size, aop->ptr<void>(), aop->output->size, ACL_MEMCPY_HOST_TO_DEVICE);
            }
            else if (op->name() == string("reduce"))
            {
                auto rop = (ReduceOp *)op;
                ReduceOpRunner runner;
                // op_idx selects the ACL reduce kernel kind.
                if (rop->ns == ns_add)
                    runner.op_idx = 9;
                else if (rop->ns == ns_multiply)
                    // TODO unsupported the multi dim
                    runner.op_idx = 999;
                else if (rop->ns == ns_maximum)
                    runner.op_idx = 11;
                else if (rop->ns == ns_minimum)
                    runner.op_idx = 12;
                else if (rop->ns == ns_mean)
                    runner.op_idx = 10;
                else
                    LOGf << "op " << rop->ns << " not supported";
                runner.add(rop->x, true);
                ReduceAttr *attr = new ReduceAttr();
                for (int i = 0; i < rop->x->shape.size(); i++)
                    if (rop->reduce_mask & (1 << i))
                        attr->axes.push_back(i);
                // keepdims iff input and output ranks match.
                if (rop->x->shape.size() == rop->y->shape.size())
                    attr->keepdims = true;
                else
                    attr->keepdims = false;
                runner.op_attr.reset(attr);
                runner.add(rop->y, false);
                runner.run();
                aclrtSynchronizeStream(aclstream);
            }
            else if (op->name() == string("broadcast_to"))
            {
                auto bop = (BroadcastToOp *)op;
                ExpandOpRunner runner;
                runner.jt_name = "expand";
                // Temporarily reshape x to z's rank with 1s on broadcast
                // axes so ACL Expand sees compatible shapes; restored below.
                NanoVector xshape, xshape_bk = bop->x->shape;
                NanoVector zshape = bop->z->shape;
                for (int i = 0; i < zshape.size(); i++)
                {
                    if (bop->bcast_mask & (1 << i))
                    {
                        xshape.push_back(1);
                    }
                    else
                    {
                        xshape.push_back(zshape[i]);
                    }
                }
                bop->x->shape = xshape;
                runner.add(bop->x, true);
                runner.add(bop->z, false);
                runner.run();
                bop->x->shape = xshape_bk;
                aclrtSynchronizeStream(aclstream);
            }
            else if (op->name() == string("fuse_transpose"))
            {
                // replace fuse_transpose with transpose
                auto top = (TransposeOp *)op;
                TransposeOpRunner runner;
                runner.add(top->x, true);
                runner.add(top->y, false);
                runner.jt_name = "transpose";
                // NOTE(review): ReduceAttr is reused here only to carry the
                // axes permutation -- confirm this is intended.
                ReduceAttr *attr = new ReduceAttr();
                for (int i = 0; i < top->axes.size(); i++)
                    attr->axes.push_back(top->axes[i]);
                runner.op_attr.reset(attr);
                runner.run();
            }
            else
            {
                LOGf << "op " << op->name() << " not supported";
            }
            // Free inputs no remaining sub-op will read, but only those this
            // call allocated itself.
            for (auto in : op->inputs())
            {
                --var_outdeg[in];
                if (var_outdeg[in] == 0)
                {
                    if (new_alloced.find(in) != new_alloced.end())
                    {
                        free_var_mem(in);
                        new_alloced.erase(in);
                    }
                }
            }
        }
    }
    catch (std::exception &e)
    {
        fallback = 1;
        LOGir << "fallback cpu" << e.what();
    }
    // Release any remaining temporaries before (possibly) falling back.
    for (auto v : new_alloced)
    {
        free_var_mem(v);
    }
    if (fallback)
    {
        fallback_cpu(op);
    }
}
extern int current_seed;
extern int64 current_offset;
// Registry mapping op names to ACL execution routines; consulted by
// exec_mapped_acl_ops for ops that are executed directly instead of being
// JIT-compiled.
static unordered_map<string, std::function<void(Op *)>> acl_ops = {
    {"curand_random", [](Op *op)
     {
         // BUGFIX: current_seed/current_offset have static storage duration
         // and are referred to directly. Naming them in the lambda capture
         // list (as before) is ill-formed -- only variables with automatic
         // storage duration may be captured.
         auto _op = (RandomOp *)op;
         RandomOpRunner runner(_op->type == ns_uniform ? "RandomUniform" : "RandomNormal");
         auto out = op->output(0);
         RandomAttr *attr = new RandomAttr();
         attr->seed = current_seed;
         attr->offset = current_offset;
         runner.jt_name = "random";
         runner.op_attr.reset(attr);
         runner.add(out, false);
         runner.run();
         // Advance the RNG offset so successive calls draw fresh numbers.
         current_offset += out->numel();
     }},
};
static void exec_mapped_acl_ops(Op *op)
{
    // Dispatch `op` to its registered ACL routine; fatal-log when no
    // routine is registered for this op name.
    auto it = acl_ops.find(op->name());
    if (it == acl_ops.end())
    {
        LOGf << "op " << op->name() << " not supported";
    }
    else
    {
        LOGv << "exec acl op " << op->name() << op;
        it->second(op);
    }
}
// Compile hook installed for ACL/NPU builds (see init_acl_ops): decides,
// per op, whether to JIT-compile normally or to route execution through
// the ACL runner paths (try_exec_and_fallback_cpu / exec_mapped_acl_ops).
static jit_op_entry_t acl_do_compile(Op *op)
{
LOGv << "compile" << op;
OpCompiler oc(op);
string *src = &oc.src;
for (auto op_type : op_types)
op_type->post_pass(&oc);
string src_after_passes;
// if is fused op
if (oc.op)
{
TunerManager tm(&oc);
src_after_passes = tm.tune();
src = &src_after_passes;
}
op->compile_optimize(*src);
// CPU ops go through the normal compile pipeline.
if (!op->flags.get(NodeFlags::_cuda))
{
LOGv << "compile cpu";
return oc.compile(op->get_jit_key(get_jk()), *src);
}
if (op->name() == string("fused"))
{
FusedOp *fop = (FusedOp *)op;
// if is a relayed op
if (fop->context->vrm.relay_groups.size())
{
LOGv << "relay fused op";
return oc.compile(op->get_jit_key(get_jk()), *src);
}
else
{
// Non-relay fused ops are interpreted sub-op by sub-op, with CPU
// fallback on failure.
return &try_exec_and_fallback_cpu;
}
}
else if (op->name() == string("code"))
{
CodeOp *cop = (CodeOp *)op;
// Code ops whose cuda_src mentions "acl" are real ACL kernels.
if (cop->cuda_src.find("acl") != string::npos)
{
LOGv << "compile acl op";
return oc.compile(op->get_jit_key(get_jk()), *src);
}
else
{
return &exec_mapped_acl_ops;
}
}
else if (strncmp(op->name(), "hccl", 4) == 0)
{
// HCCL collective ops are compiled like ordinary JIT ops.
LOGv << "Compiling HCCL op: " << op->name();
return oc.compile(op->get_jit_key(get_jk()), *src);
}
else
{
// Everything else is dispatched through the acl_ops table at runtime.
LOGv << "compile finish" << op;
return &exec_mapped_acl_ops;
}
// NOTE(review): unreachable -- every branch above returns.
return do_compile_inner(op);
}
// from op_register.cc
extern unordered_map<string, OpInfo> op_info_map;
void init_acl_ops()
{
    // Install the ACL compile hook, then purge every registered CUDA-only
    // ("cu*") op that has no ACL implementation so it cannot be selected.
    do_compile_hook = acl_do_compile;
    vector<string> to_erase;
    for (auto &kv : op_info_map)
        if (startswith(kv.first, "cu") && !acl_ops.count(kv.first))
            to_erase.push_back(kv.first);
    for (auto &name : to_erase)
    {
        LOGv << "op not supported: " << name << ", erase it.";
        op_info_map.erase(name);
    }
}
} // jittor

View File

@ -21,7 +21,7 @@ void PrintOutResult(std::vector<int64_t> &shape, void** deviceAddr) {
}
}
int Init(int32_t deviceId) {
/*int Init(int32_t deviceId) {
// 固定写法AscendCL初始化
auto ret = aclInit(nullptr);
CHECK_RET(ret == ACL_SUCCESS or ret == 100002, LOG_PRINT("aclInit failed. ERROR: %d\n", ret); return ret);
@ -30,7 +30,7 @@ int Init(int32_t deviceId) {
//ret = aclrtCreateStream(stream);
//CHECK_RET(ret == ACL_SUCCESS, LOG_PRINT("aclrtCreateStream failed. ERROR: %d\n", ret); return ret);
return 0;
}
}*/
/*
template <typename T>

View File

@ -125,7 +125,7 @@ int64_t GetShapeSize(const std::vector<int64_t> &shape);
void PrintOutResult(std::vector<int64_t> &shape, void **deviceAddr);
int Init(int32_t deviceId);
//int Init(int32_t deviceId);
/*
template <typename T>

View File

@ -0,0 +1,38 @@
// ***************************************************************
// Copyright (c) 2025 Jittor.
// All Rights Reserved.
// Maintainers:
// Jiapeng Zhang <zjp24@mails.tsinghua.edu.cn>.
//
// This file is subject to the terms and conditions defined in
// file 'LICENSE.txt', which is part of this source code package.
// ***************************************************************
#pragma once
#include "mpi_wrapper.h"
#include <hccl.h>

// Check an ACL runtime call; on failure, log the error code and return
// from the enclosing function (usable in void functions only).
#define ACLCHECK(ret) do {\
    if(ret != ACL_SUCCESS)\
    {\
        LOGe << "retcode: " << ret;\
        return;\
    }\
} while(0)
// BUGFIX: the previous definition ended with "} while(0)\" -- the trailing
// backslash spliced the following "#define HCCLCHECK" line into ACLCHECK's
// replacement list, corrupting both macros (a '#' not followed by a macro
// parameter is ill-formed in a function-like macro).

// Check an HCCL call; on failure, log the error string and code and return
// from the enclosing function (usable in void functions only).
#define HCCLCHECK(ret) do {\
    if(ret != HCCL_SUCCESS)\
    {\
        LOGe << HcclGetErrorString(ret) << " retcode: " << ret;\
        return;\
    }\
} while(0)

namespace jittor {
// Shared HCCL communicator state, defined in hccl_wrapper.cc.
EXTERN_LIB HcclRootInfo root_info;
EXTERN_LIB HcclComm comm;
EXTERN_LIB uint32_t hccl_device_id;
} // jittor

View File

@ -0,0 +1,70 @@
// ***************************************************************
// Copyright (c) 2025 Jittor.
// All Rights Reserved.
// Maintainers:
// Jiapeng Zhang <zjp24@mails.tsinghua.edu.cn>.
//
// This file is subject to the terms and conditions defined in
// file 'LICENSE.txt', which is part of this source code package.
// ***************************************************************
#include "var.h"
#include "hccl_all_gather_op.h"
#include "ops/op_register.h"
#include "utils/str_utils.h"
#include "hccl_wrapper.h"
namespace jittor {
#ifndef JIT
static auto hccl_all_gather =
get_op_info("hccl_all_gather").get_constructor<VarPtr, Var*>();
// All-gather collective: every rank contributes x; y is all ranks' x
// concatenated along dim 0. NPU-only (cpu flag cleared).
HcclAllGatherOp::HcclAllGatherOp(Var* x) : x(x) {
flags.set(NodeFlags::_cpu, 0);
flags.set(NodeFlags::_cuda, 1);
y = create_output(nullptr, x->dtype());
}
// Output shape: [mpi_world_size * n0, n1, ...] for x of shape [n0, n1, ...].
void HcclAllGatherOp::infer_shape() {
NanoVector yshape;
yshape.push_back(mpi_world_size * x->shape[0]);
for (int i=1; i<x->shape.size(); i++)
yshape.push_back(x->shape[i]);
y->set_shape(yshape);
}
// Backward pass is not implemented; LOGf aborts if it is requested.
VarPtr HcclAllGatherOp::grad(Var* out, Var* dout, Var* v, int v_index) {
LOGf << "not implemented";
return nullptr;
}
// JIT key depends only on the element dtype.
void HcclAllGatherOp::jit_prepare(JK& jk) {
jk << "«Tx:" << x->dtype();
}
#else // JIT
void HcclAllGatherOp::jit_run() {
LOGir << "HcclAllGatherOp::jit_run";
@define(T_HCCL,
@if(@strcmp(@Tx,float)==0 || @strcmp(@Tx,float32)==0, HcclDataType::HCCL_DATA_TYPE_FP32)
@if(@strcmp(@Tx,int)==0 || @strcmp(@Tx,int32)==0, HcclDataType::HCCL_DATA_TYPE_INT32)
@if(@strcmp(@Tx,float64)==0, HcclDataType::HCCL_DATA_TYPE_FP64)
@if(@strcmp(@Tx,int64)==0, HcclDataType::HCCL_DATA_TYPE_INT64)
@if(@strcmp(@Tx,uint8)==0, HcclDataType::HCCL_DATA_TYPE_UINT8)
@if(@strcmp(@Tx,float16)==0, HcclDataType::HCCL_DATA_TYPE_FP16)
)
auto* __restrict__ xp = x->ptr<Tx>();
auto* __restrict__ yp = y->ptr<Tx>();
ACLCHECK(aclrtSynchronizeDevice());
ACLCHECK(aclrtSynchronizeStream(aclstream));
HCCLCHECK(HcclAllGather(xp, yp, (uint64_t)x->num, @T_HCCL, comm, aclstream));
ACLCHECK(aclrtSynchronizeDevice());
ACLCHECK(aclrtSynchronizeStream(aclstream));
}
#endif // JIT
} // jittor

View File

@ -0,0 +1,26 @@
// ***************************************************************
// Copyright (c) 2025 Jittor.
// All Rights Reserved.
// Maintainers:
// Jiapeng Zhang <zjp24@mails.tsinghua.edu.cn>.
//
// This file is subject to the terms and conditions defined in
// file 'LICENSE.txt', which is part of this source code package.
// ***************************************************************
#pragma once
#include "op.h"
namespace jittor {
// All-gather collective op: y is every rank's x concatenated along dim 0.
struct HcclAllGatherOp : Op {
// x: local input; y: gathered output (dim 0 scaled by world size).
Var* x, * y;
HcclAllGatherOp(Var* x);
void infer_shape() override;
const char* name() const override { return "hccl_all_gather"; }
VarPtr grad(Var* out, Var* dout, Var* v, int v_index) override;
DECLARE_jit_run;
};
} // jittor

View File

@ -0,0 +1,62 @@
#include "var.h"
#include "hccl_all_reduce_op.h"
#include "ops/op_register.h"
#include "utils/str_utils.h"
#include "hccl_wrapper.h"
namespace jittor {
#ifndef JIT
static auto hccl_all_reduce =
get_op_info("hccl_all_reduce").get_constructor<VarPtr, Var*, string>();
// All-reduce collective: combines x across all ranks with reduce_op and
// leaves the identical result y (same shape as x) on every rank. NPU only.
HcclAllReduceOp::HcclAllReduceOp(Var* x, string reduce_op) : x(x), reduce_op(reduce_op) {
flags.set(NodeFlags::_cpu, 0);
flags.set(NodeFlags::_cuda, 1);
y = create_output(nullptr, x->dtype());
}
// Output shape equals input shape.
void HcclAllReduceOp::infer_shape() {
y->set_shape(x->shape);
}
// Backward: all-reduce the upstream gradient with the same reduce_op.
// NOTE(review): exact for "sum"; for prod/max/min this is not the true
// gradient -- confirm intended usage.
VarPtr HcclAllReduceOp::grad(Var* out, Var* dout, Var* v, int v_index) {
return hccl_all_reduce(dout, reduce_op);
}
// JIT key depends on the element dtype and the reduce kind.
void HcclAllReduceOp::jit_prepare(JK& jk) {
jk << "«Tx:" << x->dtype();
jk << "«Op:" << reduce_op;
}
#else // JIT
void HcclAllReduceOp::jit_run() {
//LOGir << "HcclAllReduceOp::jit_run";
@define(T_HCCL,
@if(@strcmp(@Tx,float)==0 || @strcmp(@Tx,float32)==0, HcclDataType::HCCL_DATA_TYPE_FP32)
@if(@strcmp(@Tx,int)==0 || @strcmp(@Tx,int32)==0, HcclDataType::HCCL_DATA_TYPE_INT32)
@if(@strcmp(@Tx,float64)==0, HcclDataType::HCCL_DATA_TYPE_FP64)
@if(@strcmp(@Tx,int64)==0, HcclDataType::HCCL_DATA_TYPE_INT64)
@if(@strcmp(@Tx,uint8)==0, HcclDataType::HCCL_DATA_TYPE_UINT8)
@if(@strcmp(@Tx,float16)==0, HcclDataType::HCCL_DATA_TYPE_FP16)
)
@define(REDUCE_OP,
@if(@strcmp(@Op,sum)==0, HcclReduceOp::HCCL_REDUCE_SUM)
@if(@strcmp(@Op,prod)==0, HcclReduceOp::HCCL_REDUCE_PROD)
@if(@strcmp(@Op,max)==0, HcclReduceOp::HCCL_REDUCE_MAX)
@if(@strcmp(@Op,min)==0, HcclReduceOp::HCCL_REDUCE_MIN)
)
auto* __restrict__ xp = x->ptr<Tx>();
auto* __restrict__ yp = y->ptr<Tx>();
ACLCHECK(aclrtSynchronizeDevice());
ACLCHECK(aclrtSynchronizeStream(aclstream));
HCCLCHECK(HcclAllReduce(xp, yp, (uint64_t)x->num, @T_HCCL, @REDUCE_OP, comm, aclstream));
ACLCHECK(aclrtSynchronizeDevice());
ACLCHECK(aclrtSynchronizeStream(aclstream));
}
#endif // JIT
} // jittor

View File

@ -0,0 +1,18 @@
#pragma once
#include "op.h"
namespace jittor {
// All-reduce collective op: y = reduce_op over all ranks' x, on every rank.
struct HcclAllReduceOp : Op {
Var* x, * y;
// One of "sum", "prod", "max", "min" (mapped to HcclReduceOp at JIT time).
string reduce_op;
HcclAllReduceOp(Var* x, string reduce_op="sum");
void infer_shape() override;
const char* name() const override { return "hccl_all_reduce"; }
VarPtr grad(Var* out, Var* dout, Var* v, int v_index) override;
DECLARE_jit_run;
};
} // jittor

View File

@ -0,0 +1,63 @@
#include "var.h"
#include "hccl_broadcast_op.h"
#include "ops/op_register.h"
#include "utils/str_utils.h"
#include "hccl_wrapper.h"
#include <cassert>
namespace jittor {
#ifndef JIT
static auto hccl_broadcast =
get_op_info("hccl_broadcast").get_constructor<VarPtr, Var*, int>();
// Broadcast collective: copies x from rank `root` into every rank's y.
// NPU only.
HcclBroadcastOp::HcclBroadcastOp(Var* x, int root) : x(x), root(root) {
flags.set(NodeFlags::_cpu, 0);
flags.set(NodeFlags::_cuda, 1);
y = create_output(nullptr, x->dtype());
}
// Output shape equals input shape.
void HcclBroadcastOp::infer_shape() {
y->set_shape(x->shape);
}
// Backward: broadcasts dout from root again.
// NOTE(review): the mathematical gradient of a broadcast is a sum-reduce of
// dout onto root; re-broadcasting is only equivalent when all ranks hold
// identical gradients -- verify against callers.
VarPtr HcclBroadcastOp::grad(Var* out, Var* dout, Var* v, int v_index) {
return hccl_broadcast(dout, root);
}
// JIT key depends on the element dtype and the root rank.
void HcclBroadcastOp::jit_prepare(JK& jk) {
jk << "«Tx:" << x->dtype();
jk << "«Root:" << root;
}
#else // JIT
void HcclBroadcastOp::jit_run() {
//LOGir << "HcclBroadcastOp::jit_run";
@define(T_HCCL,
@if(@strcmp(@Tx,float)==0 || @strcmp(@Tx,float32)==0, HcclDataType::HCCL_DATA_TYPE_FP32)
@if(@strcmp(@Tx,int)==0 || @strcmp(@Tx,int32)==0, HcclDataType::HCCL_DATA_TYPE_INT32)
@if(@strcmp(@Tx,float64)==0, HcclDataType::HCCL_DATA_TYPE_FP64)
@if(@strcmp(@Tx,int64)==0, HcclDataType::HCCL_DATA_TYPE_INT64)
@if(@strcmp(@Tx,uint8)==0, HcclDataType::HCCL_DATA_TYPE_UINT8)
@if(@strcmp(@Tx,float16)==0, HcclDataType::HCCL_DATA_TYPE_FP16)
)
auto* __restrict__ xp = x->ptr<Tx>();
auto* __restrict__ yp = y->ptr<Tx>();
//LOGir << "HcclBroadcastOp::jit_run " << @Root << " " << hccl_device_id << " " << xp << " " << yp;
//ACLCHECK(aclrtSynchronizeStream(aclstream));
ACLCHECK(aclrtSynchronizeDevice());
ACLCHECK(aclrtSynchronizeStream(aclstream));
HCCLCHECK(HcclBroadcast(@Root == hccl_device_id ? xp : yp, (uint64_t)x->num, @T_HCCL, @Root, comm, aclstream));
if (@Root == hccl_device_id) {
ACLCHECK(aclrtMemcpy(yp, x->num * sizeof(Tx), xp, x->num * sizeof(Tx), ACL_MEMCPY_DEVICE_TO_DEVICE));
ACLCHECK(aclrtSynchronizeDevice());
}
ACLCHECK(aclrtSynchronizeDevice());
ACLCHECK(aclrtSynchronizeStream(aclstream));
}
#endif // JIT
} // jittor

View File

@ -0,0 +1,18 @@
#pragma once
#include "op.h"
namespace jittor {
// Broadcast collective op: y receives rank `root`'s x on every rank.
struct HcclBroadcastOp : Op {
Var* x, * y;
// Rank that owns the source data.
int root;
HcclBroadcastOp(Var* x, int root=0);
void infer_shape() override;
const char* name() const override { return "hccl_broadcast"; }
VarPtr grad(Var* out, Var* dout, Var* v, int v_index) override;
DECLARE_jit_run;
};
} // jittor

View File

@ -0,0 +1,63 @@
#include "var.h"
#include "hccl_reduce_op.h"
#include "ops/op_register.h"
#include "utils/str_utils.h"
#include "hccl_wrapper.h"
namespace jittor {
#ifndef JIT
// Reduce collective: combines x across ranks with reduce_op; the result is
// valid on rank `root`. NPU only.
HcclReduceOp::HcclReduceOp(Var* x, string reduce_op, int root) : x(x), reduce_op(reduce_op), root(root) {
flags.set(NodeFlags::_cpu, 0);
flags.set(NodeFlags::_cuda, 1);
y = create_output(nullptr, x->dtype());
}
// Output shape equals input shape.
void HcclReduceOp::infer_shape() {
y->set_shape(x->shape);
}
// Backward of a sum-reduce: broadcast the root's gradient back to all ranks.
VarPtr HcclReduceOp::grad(Var* out, Var* dout, Var* v, int v_index) {
static auto hccl_broadcast =
get_op_info("hccl_broadcast").get_constructor<VarPtr, Var*, int>();
return hccl_broadcast(dout, root);
}
// JIT key depends on the element dtype, reduce kind and root rank.
void HcclReduceOp::jit_prepare(JK& jk) {
jk << "«Tx:" << x->dtype();
jk << "«Op:" << reduce_op;
jk << "«Root:" << root;
}
#else // JIT
void HcclReduceOp::jit_run() {
LOGir << "HcclReduceOp::jit_run";
@define(T_HCCL,
@if(@strcmp(@Tx,float)==0 || @strcmp(@Tx,float32)==0, HcclDataType::HCCL_DATA_TYPE_FP32)
@if(@strcmp(@Tx,int)==0 || @strcmp(@Tx,int32)==0, HcclDataType::HCCL_DATA_TYPE_INT32)
@if(@strcmp(@Tx,float64)==0, HcclDataType::HCCL_DATA_TYPE_FP64)
@if(@strcmp(@Tx,int64)==0, HcclDataType::HCCL_DATA_TYPE_INT64)
@if(@strcmp(@Tx,uint8)==0, HcclDataType::HCCL_DATA_TYPE_UINT8)
@if(@strcmp(@Tx,float16)==0, HcclDataType::HCCL_DATA_TYPE_FP16)
)
@define(REDUCE_OP,
@if(@strcmp(@Op,sum)==0, HcclReduceOp::HCCL_REDUCE_SUM)
@if(@strcmp(@Op,prod)==0, HcclReduceOp::HCCL_REDUCE_PROD)
@if(@strcmp(@Op,max)==0, HcclReduceOp::HCCL_REDUCE_MAX)
@if(@strcmp(@Op,min)==0, HcclReduceOp::HCCL_REDUCE_MIN)
)
auto* __restrict__ xp = x->ptr<Tx>();
auto* __restrict__ yp = y->ptr<Tx>();
ACLCHECK(aclrtSynchronizeDevice());
ACLCHECK(aclrtSynchronizeStream(aclstream));
HCCLCHECK(HcclReduce(xp, yp, (uint64_t)x->num, @T_HCCL, @REDUCE_OP, @Root, comm, aclstream));
ACLCHECK(aclrtSynchronizeDevice());
ACLCHECK(aclrtSynchronizeStream(aclstream));
}
#endif // JIT
} // jittor

View File

@ -0,0 +1,19 @@
#pragma once
#include "op.h"
namespace jittor {
// Collective reduce op over the HCCL communicator: elementwise reduction
// of x across all ranks, with the result delivered on rank `root`.
struct HcclReduceOp : Op {
// x: per-rank input; y: output (result valid on root after the collective).
Var* x, * y;
// Reduction kind: "sum", "prod", "max" or "min" (dispatched in jit_run).
string reduce_op;
// World rank that receives the reduced result.
int root;
HcclReduceOp(Var* x, string reduce_op="sum", int root=0);
void infer_shape() override;
const char* name() const override { return "hccl_reduce"; }
VarPtr grad(Var* out, Var* dout, Var* v, int v_index) override;
// Jittor JIT plumbing: declares jit_prepare/jit_run.
DECLARE_jit_run;
};
} // jittor

View File

@ -0,0 +1,60 @@
// ***************************************************************
// Copyright (c) 2025 Jittor.
// All Rights Reserved.
// Maintainers:
// Jiapeng Zhang <zjp24@mails.tsinghua.edu.cn>.
//
// This file is subject to the terms and conditions defined in
// file 'LICENSE.txt', which is part of this source code package.
// ***************************************************************
#include "hccl_wrapper.h"
#include "event_queue.h"
#include "acl_jittor.h"
#include <acl/acl.h>
namespace jittor {
HcclRootInfo root_info;
HcclComm comm;
uint32_t hccl_device_id = 0;
struct hccl_initer {
uint32_t device_count = 0;
hccl_initer() {
ACLCHECK(aclrtGetDeviceCount(&device_count));
if (!device_count) return;
if (!inside_mpi) return;
hccl_device_id = mpi_local_rank;
if (mpi_local_rank >= device_count) {
LOGw << "mpi_local_rank(">>mpi_local_rank>>") is larger than device_count("
>>device_count>>")";
hccl_device_id = hccl_device_id % device_count;
}
LOGv << "HCCL init in device" << hccl_device_id << "local_rank" << mpi_local_rank;
//LOGir << aclstream;
//event_queue.run_sync([]() {
ACLCHECK(aclrtSetDevice(hccl_device_id));
//});
use_device_mpi = true;
LOGir << "HCCL init in device" << hccl_device_id << "local_rank" << mpi_local_rank;
if (mpi_world_rank == 0)
HCCLCHECK(HcclGetRootInfo(&root_info));
MPI_CHECK(MPI_Bcast(&root_info, HCCL_ROOT_INFO_BYTES, MPI_CHAR, 0, MPI_COMM_WORLD));
//MPI_Barrier(MPI_COMM_WORLD);
LOGir << "Count:" << device_count << "HCCL init in device" << hccl_device_id;
HCCLCHECK(HcclCommInitRootInfo(device_count, &root_info, hccl_device_id, &comm));
ACLCHECK(aclrtCreateStream(&aclstream));
LOGi << "HCCL init success in device" << hccl_device_id;
}
~hccl_initer() {
if (!device_count) return;
if (!inside_mpi) return;
if (!use_device_mpi) return;
HCCLCHECK(HcclCommDestroy(comm));
}
};
static hccl_initer hccl_initer;
}

View File

@ -37,14 +37,23 @@ MpiAllReduceOp::MpiAllReduceOp(Var* x, NanoString op) : x(x), op(op) {
}
ASSERT(op == ns_add) << "Not supported MPI op" << op;
#ifdef HAS_CUDA
if (use_device_mpi && use_cuda) {
static auto nccl_all_reduce = has_op("nccl_all_reduce")
? get_op_info("nccl_all_reduce").get_constructor<VarPtr, Var*>()
: nullptr;
static auto hccl_all_reduce = has_op("hccl_all_reduce")
? get_op_info("hccl_all_reduce").get_constructor<VarPtr, Var*, string>()
: nullptr;
if (nccl_all_reduce) {
auto var = nccl_all_reduce(x);
forward(var);
return;
} else if (hccl_all_reduce) {
auto var = hccl_all_reduce(x, "sum");
//exe.run_sync({var}, true);
forward(var);
return;
}
}
#endif

View File

@ -26,10 +26,18 @@ MpiBroadcastOp::MpiBroadcastOp(Var* x, int root) : x(x), root(root) {
static auto nccl_broadcast = has_op("nccl_broadcast")
? get_op_info("nccl_broadcast").get_constructor<VarPtr, Var*, int>()
: nullptr;
static auto hccl_broadcast = has_op("hccl_broadcast")
? get_op_info("hccl_broadcast").get_constructor<VarPtr, Var*, int>()
: nullptr;
if (nccl_broadcast) {
auto var = nccl_broadcast(x, root);
forward(var);
return;
} else if (hccl_broadcast) {
auto var = hccl_broadcast(x, root);
//exe.run_sync({var}, true);
forward(var);
return;
}
}
#endif

View File

@ -41,10 +41,18 @@ MpiReduceOp::MpiReduceOp(Var* x, NanoString op, int root) : x(x), op(op), root(r
static auto nccl_reduce = has_op("nccl_reduce")
? get_op_info("nccl_reduce").get_constructor<VarPtr, Var*, int>()
: nullptr;
static auto hccl_reduce = has_op("hccl_reduce")
? get_op_info("hccl_reduce").get_constructor<VarPtr, Var*, string, int>()
: nullptr;
if (nccl_reduce) {
auto var = nccl_reduce(x, root);
forward(var);
return;
} else if (hccl_reduce) {
auto var = hccl_reduce(x, "sum", root);
//exe.run_sync({var}, true);
forward(var);
return;
}
}
#endif

View File

@ -7,7 +7,13 @@
// This file is subject to the terms and conditions defined in
// file 'LICENSE.txt', which is part of this source code package.
// ***************************************************************
#include <cmath>
#include <limits>
#include <cstring>
#if defined(__x86_64__) || defined(_M_X64)
#include <immintrin.h>
#endif
#include <unistd.h>
#include <stdint.h>
#include <stdio.h>
@ -31,8 +37,80 @@ namespace jittor {
MPI_Datatype MPI_HALF;
MPI_Op MPI_HALF_ADD;
#if !defined(__x86_64__) && !defined(_M_X64)
// ARM架构下的FP16-FP32转换辅助函数
// Convert an IEEE-754 binary16 value (raw bits) to float.
// Handles signed zero, subnormals, normals, infinity and NaN.
// Bug fixed: the old code shifted the 10-bit mantissa left by 13 at
// declaration AND by 10 again before packing (23 total), which pushed
// mantissa bits into the exponent/sign fields — e.g. 1.5 decoded as 1.0
// and 0x3c01 decoded as 2.0. The subnormal normalization was also off by
// one (0x0200 decoded as 2^-14 instead of 2^-15).
static inline float fp16_to_fp32_value(uint16_t h) {
    uint32_t sign = ((uint32_t)h & 0x8000u) << 16;  // sign moved to bit 31
    uint32_t exponent = (h >> 10) & 0x1fu;          // 5-bit biased exponent
    uint32_t mantissa = h & 0x3ffu;                 // 10-bit mantissa
    uint32_t bits;
    if (exponent == 0) {
        if (mantissa == 0) {
            bits = sign;                            // signed zero
        } else {
            // Subnormal half: value = mantissa * 2^-24. Normalize so the
            // leading 1 becomes the float's implicit bit.
            exponent = 127 - 15 + 1;                // exponent if leading bit sat at bit 10
            while (!(mantissa & 0x400u)) {
                mantissa <<= 1;
                exponent -= 1;
            }
            mantissa &= 0x3ffu;                     // drop the now-implicit leading bit
            bits = sign | (exponent << 23) | (mantissa << 13);
        }
    } else if (exponent == 31) {
        // Infinity (mantissa == 0) or NaN (mantissa != 0); payload kept.
        bits = sign | 0x7f800000u | (mantissa << 13);
    } else {
        // Normal number: rebias exponent from 15 to 127, widen mantissa.
        bits = sign | ((exponent + (127 - 15)) << 23) | (mantissa << 13);
    }
    float f;
    std::memcpy(&f, &bits, sizeof(float));
    return f;
}
// Convert a float to IEEE-754 binary16 (raw bits) with round-to-nearest-even.
// Improvements over the previous version:
//  - rounds to nearest-even like the x86 path (_MM_FROUND_TO_NEAREST_INT)
//    instead of truncating 13 mantissa bits;
//  - values in the half subnormal range [2^-24, 2^-15) are encoded as
//    half subnormals; the old code flushed new_exp<0 to zero and encoded
//    new_exp==0 without the implicit leading bit (wrong magnitude).
static inline uint16_t fp32_to_fp16_value(float f) {
    uint32_t bits;
    std::memcpy(&bits, &f, sizeof(float));
    uint16_t sign = (uint16_t)((bits >> 16) & 0x8000u);
    int exponent = (int)((bits >> 23) & 0xffu);
    uint32_t mantissa = bits & 0x7fffffu;
    if (exponent == 0xff) {
        // Infinity or NaN; keep a mantissa bit set so NaN stays NaN.
        return (uint16_t)(sign | 0x7c00u | (mantissa ? 0x200u : 0u));
    }
    int new_exp = exponent - 127 + 15;
    if (new_exp >= 31) {
        return (uint16_t)(sign | 0x7c00u);  // overflow -> infinity
    }
    if (new_exp <= 0) {
        // Result is a half subnormal or zero.
        if (exponent == 0) return sign;     // float zero/subnormal -> signed zero
        if (14 - new_exp > 24) return sign; // magnitude < 2^-25 rounds to zero
        int shift = 14 - new_exp;           // bits shifted out get rounded below
        uint32_t full = mantissa | 0x800000u;  // restore implicit bit
        uint16_t half = (uint16_t)(full >> shift);
        uint32_t rem = full & ((1u << shift) - 1u);
        uint32_t mid = 1u << (shift - 1);
        if (rem > mid || (rem == mid && (half & 1u)))
            half++;  // carry into smallest normal is a valid encoding
        return (uint16_t)(sign | half);
    }
    // Normal number: drop 13 mantissa bits with round-to-nearest-even.
    uint16_t half = (uint16_t)(sign | ((uint32_t)new_exp << 10) | (mantissa >> 13));
    uint32_t rem = mantissa & 0x1fffu;
    if (rem > 0x1000u || (rem == 0x1000u && (half & 1u)))
        half++;  // carry may bump the exponent (up to infinity) — correct
    return half;
}
#endif
void HalfAdd(void* invec, void* inoutvec, int* len, MPI_Datatype* type) {
// return;
#if defined(__x86_64__) || defined(_M_X64)
short* in = (short*)invec;
short* inout = (short*)inoutvec;
@ -62,9 +140,27 @@ void HalfAdd(void* invec, void* inoutvec, int* len, MPI_Datatype* type) {
// 将单精度浮点数转换回半精度浮点数,并存储结果
*(inout + i) = _mm_cvtps_ph(out, _MM_FROUND_TO_NEAREST_INT | _MM_FROUND_NO_EXC)[0];
}
#else
// ARM架构实现使用基本的半精度浮点数运算
uint16_t* in = (uint16_t*)invec;
uint16_t* inout = (uint16_t*)inoutvec;
int total = *len;
// 简单的逐元素相加实现
for (int i = 0; i < total; i++) {
// 将FP16转换为FP32
float in_val = fp16_to_fp32_value(in[i]);
float inout_val = fp16_to_fp32_value(inout[i]);
// 执行加法
float result = in_val + inout_val;
// 将结果转回FP16
inout[i] = fp32_to_fp16_value(result);
}
#endif
}
int mpi_world_size = 1;
int mpi_world_rank = 0;
int mpi_local_size = 1;

View File

@ -31,15 +31,15 @@ cudaEvent_t event;
// One-time stream setup for the device-collective wrapper on the ACL
// (Ascend) backend: rather than owning a dedicated stream/event, the
// wrapper now aliases the globally shared aclstream.
// NOTE(review): the uncommented checkCudaErrors(...) lines below appear to
// be diff residue duplicating the commented-out versions — confirm only
// the commented form exists in the final file.
struct Init {
Init() {
if (!get_device_count()) return;
checkCudaErrors(cudaStreamCreateWithFlags(&stream, cudaStreamNonBlocking));
checkCudaErrors(cudaEventCreate(&event, cudaEventDisableTiming));
//checkCudaErrors(cudaStreamCreateWithFlags(&stream, cudaStreamNonBlocking));
//checkCudaErrors(cudaEventCreate(&event, cudaEventDisableTiming));
// `stream` aliases the shared aclstream, so the destructor must never
// destroy it (the teardown below is commented out accordingly).
stream = aclstream;
}
~Init() {
if (!get_device_count()) return;
//peekCudaErrors(cudaDeviceSynchronize());
//peekCudaErrors(cudaStreamDestroy(stream));
//peekCudaErrors(cudaEventDestroy(event));
// NOTE(review): on the ACL path `event` may never be created; any later
// cudaEventRecord on it would fail — confirm no caller still records it.
}
} init;