mirror of https://github.com/Jittor/Jittor
Update HCCL to support multiple NPUs
parent 7ba878bf49
commit e4be9b1f78
@@ -600,6 +600,26 @@ def setup_nccl():
     nccl_ops = nccl.ops
     LOG.vv("Get nccl_ops: "+str(dir(nccl_ops)))
 
+def setup_hccl():
+    global hccl_ops
+
+    hccl_src_dir = os.path.join(jittor_path, "extern", "acl", "hccl")
+    hccl_src_files = []
+    for r, _, f in os.walk(hccl_src_dir):
+        for fname in f:
+            hccl_src_files.append(os.path.join(r, fname))
+
+    hccl_include_path = os.path.join(os.environ.get("ASCEND_TOOLKIT_HOME"), "aarch64-linux/include/hccl")
+    hccl_lib_name = os.path.join(os.environ.get("ASCEND_TOOLKIT_HOME"), "aarch64-linux/lib64/libhccl.so")
+    ctypes.CDLL(hccl_lib_name, dlopen_flags)
+
+    hccl = compile_custom_ops(hccl_src_files,
+                              extra_flags=f" -I\"{hccl_include_path}\" {mpi_compile_flags} ",
+                              return_module=True, dlopen_flags=os.RTLD_GLOBAL | os.RTLD_NOW,
+                              gen_name_="jittor_hccl_core")
+    hccl_ops = hccl.ops
+    LOG.vv("Get hccl_ops: "+str(dir(hccl_ops)))
+
 def manual_link(flags):
     lib_dirs = []
     libs = []
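Note that os.environ.get("ASCEND_TOOLKIT_HOME") returns None when the toolkit environment is not sourced, and os.path.join(None, ...) then raises a TypeError. A minimal defensive sketch; the ascend_home helper is hypothetical, not part of the commit:

    import os

    def ascend_home():
        # Hypothetical guard: fail with a clear message instead of a
        # TypeError from os.path.join(None, ...) when the env var is unset.
        home = os.environ.get("ASCEND_TOOLKIT_HOME")
        if home is None:
            raise RuntimeError("ASCEND_TOOLKIT_HOME is not set; "
                               "source the Ascend toolkit env script first")
        return home

    hccl_include_path = os.path.join(ascend_home(), "aarch64-linux/include/hccl")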
@@ -697,10 +717,12 @@ cudnn = cublas = curand = cufft = None
 setup_mpi()
 rank = mpi.world_rank() if in_mpi else 0
 world_size = mpi.world_size() if in_mpi else 1
-setup_nccl()
-
-setup_cutt()
-setup_cutlass()
+if has_acl:
+    setup_hccl()
+elif has_cuda:
+    setup_nccl()
+    setup_cutt()
+    setup_cutlass()
 
 # try:
 setup_mkl()
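Downstream code can tell which collective backend was initialized by checking which ops module got populated. A hedged sketch against the module-level globals this file defines (attribute names follow the diff; treat the access pattern as an assumption):

    from jittor import compile_extern

    if getattr(compile_extern, "hccl_ops", None):
        print("collectives via HCCL (Ascend NPU)")
    elif getattr(compile_extern, "nccl_ops", None):
        print("collectives via NCCL (CUDA GPU)")
    else:
        print("no device collectives; plain MPI only")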
File diff suppressed because it is too large
@@ -38,7 +38,6 @@ namespace jittor
 static void *acl_jittor_process_callback(void *)
 {
     acl_jittor_thread_running = 1;
-    int deviceId = 0;
 
     while (acl_jittor_thread_running)
     {
@@ -1,496 +1,502 @@
 // ***************************************************************
 // Copyright (c) 2023 Jittor. All Rights Reserved.
 // Maintainers: Dun Liang <randonlang@gmail.com>.
 // This file is subject to the terms and conditions defined in
 // file 'LICENSE.txt', which is part of this source code package.
 // ***************************************************************
 #include <acl/acl.h>
 #include <acl/acl_op_compiler.h>
 #include <Python.h>
 #include <pystate.h>
 #include <algorithm>
 #include <queue>
 #include <set>
 #include "common.h"
 #include "op.h"
 #include "acl_jittor.h"
 #include "ops/random_op.h"
 #include "ops/reduce_op.h"
 #include "ops/binary_op.h"
 #include "ops/broadcast_to_op.h"
 #include "ops/transpose_op.h"
 #include "ops/array_op.h"
 #include "ops/code_op.h"
 #include "fused_op.h"
 #include "ops/unary_op.h"
 #include "ops/ternary_op.h"
 #include "executor.h"
 #include "misc/cuda_flags.h"
 #include "mem/allocator.h"
 #include "op_compiler.h"
 #include "ops/op_register.h"
 #include "opt/tuner_manager.h"
 #include "utils/str_utils.h"
 #include "aclnn/aclnn.h"
 #include "aclops/aclops.h"
 namespace jittor
 {
 void free_var_mem(Var *v);
 
 unordered_map<uint32, string> opname_map = {
     // unary op
     {ns_cast, "Cast"},
     {ns_negative, "Neg"},
     {ns_abs, "Abs"},
     {ns_exp, "Exp"},
     {ns_log, "Log"},
     {ns_sqrt, "Sqrt"},
     {ns_ceil, "Ceil"},
     {ns_floor, "Floor"},
     {ns_round, "Round"},
     // m(round_int)
     // m(floor_int)
     // m(ceil_int)
     {ns_sin, "Sin"},
     {ns_cos, "Cos"},
     {ns_tan, "Tan"},
     {ns_asin, "Asin"},
     {ns_acos, "Acos"},
     {ns_atan, "Atan"},
     {ns_sinh, "Sinh"},
     {ns_cosh, "Cosh"},
     {ns_tanh, "Tanh"},
     {ns_asinh, "Asinh"},
     {ns_acosh, "Acosh"},
     {ns_atanh, "Atanh"},
     {ns_sigmoid, "Sigmoid"},
     {ns_erf, "Erf"},
     {ns_erfinv, "Erfinv"},
     {ns_logical_not, "LogicalNot"},
     {ns_bitwise_not, "BitwiseNot"},
     // binary op
     {ns_pow, "Pow"},
     {ns_maximum, "Maximum"},
     {ns_minimum, "Minimum"},
     {ns_add, "Add"},
     {ns_subtract, "Sub"},
     {ns_multiply, "Mul"},
     {ns_divide, "RealDiv"},
     {ns_floor_divide, "FloorDiv"},
     {ns_mod, "Mod"},
     {ns_less, "Less"},
     {ns_less_equal, "LessEqual"},
     {ns_greater, "Greater"},
     {ns_greater_equal, "GreaterEqual"},
     {ns_equal, "Equal"},
     {ns_not_equal, "NotEqual"},
     {ns_left_shift, "LeftShift"},
     {ns_right_shift, "RightShift"},
     {ns_logical_and, "LogicalAnd"},
     {ns_logical_or, "LogicalOr"},
     {ns_logical_xor, "LogicalXor"},
     {ns_bitwise_and, "BitwiseAnd"},
     {ns_bitwise_or, "BitwiseOr"},
     {ns_bitwise_xor, "BitwiseXor"},
 };
 
 void fallback_cpu(Op *op)
 {
     LOGy << "!!! fallback_cpu " << op;
     use_cuda = 0;
     for (auto v : op->inputs())
     {
         if (v->mem_ptr && v->allocator->is_cuda())
         {
             migrate_to_cpu(v, exe.allocator);
         }
     }
     for (auto v : op->outputs())
     {
         if (v->mem_ptr && v->allocator->is_cuda())
         {
             migrate_to_cpu(v, exe.allocator);
         }
     }
     op->flags.set(NodeFlags::_cpu);
     op->flags.set(NodeFlags::_cuda, 0);
     if (op->name() == string("fused"))
     {
         auto fop = (FusedOp *)op;
         for (auto op : fop->ops)
         {
             op->flags.set(NodeFlags::_cpu);
             op->flags.set(NodeFlags::_cuda, 0);
         }
     }
     op->do_run();
     use_cuda = 1;
 }
 
 /*
 check compile
 if compiled: exec
 else: compile
     check is fused
         check is relay
         else
             compile func = try exec
             if failed: fallback_cpu
     else
         try compile
         if failed: fallback_cpu
 */
 
 extern jit_op_entry_t (*do_compile_hook)(Op *);
 jit_op_entry_t do_compile_inner(Op *op);
 
 void try_exec_and_fallback_cpu(Op *op)
 {
+    aclrtSynchronizeStream(aclstream);
     auto fop = (FusedOp *)op;
 
     std::set<Var *> new_alloced;
     map<Op *, int> op_indeg;
     map<Var *, int> var_outdeg;
     std::queue<Op *> queue;
 
     for (Op *op : fop->ops)
         op_indeg[op] = 0;
 
     map<Op *, vector<Op *>> out_map;
     map<Var *, vector<Op *>> from;
 
     int len = 0;
     for (Op *v : fop->ops)
     {
         for (auto in : v->inputs())
             from[in].push_back(v);
         ++len;
     }
     for (Op *u : fop->ops)
     {
         for (auto out : u->outputs())
         {
             if (from.find(out) != from.end())
             {
                 for (auto v : from[out])
                 {
                     ++op_indeg[v];
                     ++var_outdeg[out];
                     out_map[u].push_back(v);
                 }
             }
         }
     }
     for (Op *op : fop->ops)
     {
         if (op_indeg[op] == 0)
             queue.push(op);
     }
 
     int total = 0;
     int fallback = 0;
     try
     {
         while (!queue.empty())
         {
             total++;
 
             auto op = queue.front();
             queue.pop();
             // every input must already be materialized before this op runs
             // (this check belongs after the pop, so it sees the current op)
             for (auto in : op->inputs())
             {
                 ASSERT(in->mem_ptr);
             }
             for (auto out : op->outputs())
             {
                 if (out->mem_ptr)
                     continue;
                 out->alloc(exe.allocator);
                 new_alloced.insert(out);
             }
             for (auto out : out_map[op])
             {
                 --op_indeg[out];
                 if (op_indeg[out] == 0)
                     queue.push(out);
             }
             if (op->name() == string("unary"))
             {
                 auto uop = (UnaryOp *)op;
                 UnaryOpRunner op;
                 op.add(uop->x, true);
                 op.add(uop->y, false);
                 auto iter = opname_map.find(uop->ns);
                 ASSERT(iter != opname_map.end()) << "op " << uop->ns << " not found";
                 op.name = iter->second;
                 op.jt_name = uop->name();
                 op.run();
             }
             else if (op->name() == string("binary"))
             {
                 auto bop = (BinaryOp *)op;
                 BinaryOpRunner op;
                 op.add(bop->x, true);
                 op.add(bop->y, true);
                 op.add(bop->z, false);
                 auto iter = opname_map.find(bop->ns);
                 ASSERT(iter != opname_map.end()) << "op " << bop->ns << " not found";
                 op.name = iter->second;
                 op.jt_name = bop->name();
 
                 if (bop->x->dtype() == ns_bool and bop->y->dtype() == ns_bool)
                 {
                     // BitwiseOr, BitwiseAnd, BitwiseXor -> LogicalOr, LogicalAnd, LogicalXor
                     if (bop->ns == ns_bitwise_or)
                     {
                         op.name = "LogicalOr";
                     }
                     else if (bop->ns == ns_bitwise_and)
                     {
                         op.name = "LogicalAnd";
                     }
                     else if (bop->ns == ns_bitwise_xor)
                     {
                         op.name = "LogicalXor";
                     }
                 }
                 op.run();
             }
             else if (op->name() == string("ternary"))
             {
                 auto top = (TernaryOp *)op;
                 TernaryOpRunner op;
                 op.add(top->cond, true);
                 op.add(top->x, true);
                 op.add(top->y, true);
                 op.add(top->z, false);
                 op.run();
             }
             else if (op->name() == string("array"))
             {
                 auto aop = (ArrayOp *)op;
                 aclrtMemcpy(aop->output->mem_ptr, aop->output->size, aop->ptr<void>(), aop->output->size, ACL_MEMCPY_HOST_TO_DEVICE);
             }
             else if (op->name() == string("reduce"))
             {
                 auto rop = (ReduceOp *)op;
                 ReduceOpRunner op;
                 if (rop->ns == ns_add)
                     op.op_idx = 9;
                 else if (rop->ns == ns_multiply)
                     // TODO: multi-dim multiply reduce is not supported yet
                     op.op_idx = 999;
                 else if (rop->ns == ns_maximum)
                     op.op_idx = 11;
                 else if (rop->ns == ns_minimum)
                     op.op_idx = 12;
                 else if (rop->ns == ns_mean)
                     op.op_idx = 10;
                 else
                     LOGf << "op " << rop->ns << " not supported";
                 op.add(rop->x, true);
 
                 ReduceAttr *attr = new ReduceAttr();
                 for (int i = 0; i < rop->x->shape.size(); i++)
                     if (rop->reduce_mask & (1 << i))
                         attr->axes.push_back(i);
                 if (rop->x->shape.size() == rop->y->shape.size())
                     attr->keepdims = true;
                 else
                     attr->keepdims = false;
 
                 op.op_attr.reset(attr);
                 op.add(rop->y, false);
                 op.run();
                 aclrtSynchronizeStream(aclstream);
             }
             else if (op->name() == string("broadcast_to"))
             {
                 auto bop = (BroadcastToOp *)op;
                 ExpandOpRunner op;
                 op.jt_name = "expand";
                 NanoVector xshape, xshape_bk = bop->x->shape;
                 NanoVector zshape = bop->z->shape;
 
                 for (int i = 0; i < zshape.size(); i++)
                 {
                     if (bop->bcast_mask & (1 << i))
                     {
                         xshape.push_back(1);
                     }
                     else
                     {
                         xshape.push_back(zshape[i]);
                     }
                 }
                 bop->x->shape = xshape;
                 op.add(bop->x, true);
                 // bop->x->shape = xshape_bk;
                 op.add(bop->z, false);
                 op.run();
                 bop->x->shape = xshape_bk;
                 aclrtSynchronizeStream(aclstream);
             }
             else if (op->name() == string("fuse_transpose"))
             {
                 // replace fuse_transpose with transpose
                 auto top = (TransposeOp *)op;
                 TransposeOpRunner op;
                 op.add(top->x, true);
                 op.add(top->y, false);
                 op.jt_name = "transpose";
 
                 // the permutation axes are carried via a ReduceAttr
                 ReduceAttr *attr = new ReduceAttr();
                 for (int i = 0; i < top->axes.size(); i++)
                     attr->axes.push_back(top->axes[i]);
                 op.op_attr.reset(attr);
 
                 op.run();
             }
             else
             {
                 LOGf << "op " << op->name() << " not supported";
             }
 
             for (auto in : op->inputs())
             {
                 --var_outdeg[in];
                 if (var_outdeg[in] == 0)
                 {
                     if (new_alloced.find(in) != new_alloced.end())
                     {
                         free_var_mem(in);
                         new_alloced.erase(in);
                     }
                 }
             }
         }
     }
     catch (std::exception &e)
     {
         fallback = 1;
         LOGir << "fallback cpu" << e.what();
     }
     for (auto v : new_alloced)
     {
         free_var_mem(v);
     }
     if (fallback)
     {
         fallback_cpu(op);
     }
 }
 
 extern int current_seed;
 extern int64 current_offset;
 
 static unordered_map<string, std::function<void(Op *)>> acl_ops = {
     // current_seed / current_offset are file-scope globals, so the lambda
     // needs no capture to use them
     {"curand_random", [](Op *op)
      {
          auto _op = (RandomOp *)op;
          RandomOpRunner runner(_op->type == ns_uniform ? "RandomUniform" : "RandomNormal");
          auto out = op->output(0);
          RandomAttr *attr = new RandomAttr();
          attr->seed = current_seed;
          attr->offset = current_offset;
          runner.jt_name = "random";
          runner.op_attr.reset(attr);
 
          runner.add(out, false);
          runner.run();
          current_offset += out->numel();
      }},
 };
 
 static void exec_mapped_acl_ops(Op *op)
 {
     auto iter = acl_ops.find(op->name());
     if (iter != acl_ops.end())
     {
         LOGv << "exec acl op " << op->name() << op;
         iter->second(op);
     }
     else
     {
         LOGf << "op " << op->name() << " not supported";
     }
 }
 
 static jit_op_entry_t acl_do_compile(Op *op)
 {
     LOGv << "compile" << op;
     OpCompiler oc(op);
     string *src = &oc.src;
     for (auto op_type : op_types)
         op_type->post_pass(&oc);
     string src_after_passes;
     // if is fused op
     if (oc.op)
     {
         TunerManager tm(&oc);
         src_after_passes = tm.tune();
         src = &src_after_passes;
     }
     op->compile_optimize(*src);
     if (!op->flags.get(NodeFlags::_cuda))
     {
         LOGv << "compile cpu";
         return oc.compile(op->get_jit_key(get_jk()), *src);
     }
     if (op->name() == string("fused"))
     {
         FusedOp *fop = (FusedOp *)op;
         // if is a relayed op
         if (fop->context->vrm.relay_groups.size())
         {
             LOGv << "relay fused op";
             return oc.compile(op->get_jit_key(get_jk()), *src);
         }
         else
         {
             return &try_exec_and_fallback_cpu;
         }
     }
     else if (op->name() == string("code"))
     {
         CodeOp *cop = (CodeOp *)op;
         if (cop->cuda_src.find("acl") != string::npos)
         {
             LOGv << "compile acl op";
             return oc.compile(op->get_jit_key(get_jk()), *src);
         }
         else
         {
             return &exec_mapped_acl_ops;
         }
     }
+    else if (strncmp(op->name(), "hccl", 4) == 0)
+    {
+        LOGv << "Compiling HCCL op: " << op->name();
+        return oc.compile(op->get_jit_key(get_jk()), *src);
+    }
     else
     {
         LOGv << "compile finish" << op;
         return &exec_mapped_acl_ops;
     }
     return do_compile_inner(op);
 }
 
 // from op_register.cc
 extern unordered_map<string, OpInfo> op_info_map;
 
 void init_acl_ops()
 {
     do_compile_hook = acl_do_compile;
     vector<string> to_erase;
     for (auto &kv : op_info_map)
     {
         if (startswith(kv.first, "cu") && acl_ops.count(kv.first) == 0)
         {
             to_erase.push_back(kv.first);
         }
     }
     for (auto &k : to_erase)
     {
         LOGv << "op not supported: " << k << ", erase it.";
         op_info_map.erase(k);
     }
 }
 
 } // jittor
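try_exec_and_fallback_cpu above runs the fused op's sub-ops in topological order (Kahn's algorithm): in-degrees are counted over producer-to-consumer edges, zero-in-degree ops seed the queue, and finishing an op decrements its consumers. A minimal Python sketch of the same scheduling, with the Var edges reduced to a plain adjacency list (illustration only, not Jittor API):

    from collections import deque

    def topo_exec(ops, consumers, run):
        # consumers[a] = ops that read one of a's outputs
        indeg = {op: 0 for op in ops}
        for a in ops:
            for b in consumers.get(a, []):
                indeg[b] += 1
        queue = deque(op for op in ops if indeg[op] == 0)
        while queue:
            op = queue.popleft()
            run(op)  # allocate outputs, dispatch to the ACL runner, free dead vars
            for b in consumers.get(op, []):
                indeg[b] -= 1
                if indeg[b] == 0:
                    queue.append(b)

    topo_exec(["a", "b", "c"], {"a": ["b", "c"], "b": ["c"]}, print)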
@@ -21,7 +21,7 @@ void PrintOutResult(std::vector<int64_t> &shape, void** deviceAddr) {
     }
 }
 
-int Init(int32_t deviceId) {
+/*int Init(int32_t deviceId) {
     // boilerplate: AscendCL initialization
     auto ret = aclInit(nullptr);
     CHECK_RET(ret == ACL_SUCCESS or ret == 100002, LOG_PRINT("aclInit failed. ERROR: %d\n", ret); return ret);
@@ -30,7 +30,7 @@ int Init(int32_t deviceId) {
     //ret = aclrtCreateStream(stream);
     //CHECK_RET(ret == ACL_SUCCESS, LOG_PRINT("aclrtCreateStream failed. ERROR: %d\n", ret); return ret);
     return 0;
-}
+}*/
 
 /*
 template <typename T>
@@ -125,7 +125,7 @@ int64_t GetShapeSize(const std::vector<int64_t> &shape);
 
 void PrintOutResult(std::vector<int64_t> &shape, void **deviceAddr);
 
-int Init(int32_t deviceId);
+//int Init(int32_t deviceId);
 
 /*
 template <typename T>
@@ -0,0 +1,38 @@
+// ***************************************************************
+// Copyright (c) 2025 Jittor.
+// All Rights Reserved.
+// Maintainers:
+//     Jiapeng Zhang <zjp24@mails.tsinghua.edu.cn>.
+//
+// This file is subject to the terms and conditions defined in
+// file 'LICENSE.txt', which is part of this source code package.
+// ***************************************************************
+
+#pragma once
+#include "mpi_wrapper.h"
+
+#define ACLCHECK(ret) do {\
+    if(ret != ACL_SUCCESS)\
+    {\
+        LOGe << "retcode: " << ret;\
+        return;\
+    }\
+} while(0)
+
+#define HCCLCHECK(ret) do {\
+    if(ret != HCCL_SUCCESS)\
+    {\
+        LOGe << HcclGetErrorString(ret) << " retcode: " << ret;\
+        return;\
+    }\
+} while(0)
+
+#include <hccl.h>
+
+namespace jittor {
+
+EXTERN_LIB HcclRootInfo root_info;
+EXTERN_LIB HcclComm comm;
+EXTERN_LIB uint32_t hccl_device_id;
+
+} // jittor
@@ -0,0 +1,70 @@
+// ***************************************************************
+// Copyright (c) 2025 Jittor.
+// All Rights Reserved.
+// Maintainers:
+//     Jiapeng Zhang <zjp24@mails.tsinghua.edu.cn>.
+//
+// This file is subject to the terms and conditions defined in
+// file 'LICENSE.txt', which is part of this source code package.
+// ***************************************************************
+
+#include "var.h"
+#include "hccl_all_gather_op.h"
+#include "ops/op_register.h"
+#include "utils/str_utils.h"
+#include "hccl_wrapper.h"
+
+namespace jittor {
+
+#ifndef JIT
+
+static auto hccl_all_gather =
+    get_op_info("hccl_all_gather").get_constructor<VarPtr, Var*>();
+
+HcclAllGatherOp::HcclAllGatherOp(Var* x) : x(x) {
+    flags.set(NodeFlags::_cpu, 0);
+    flags.set(NodeFlags::_cuda, 1);
+    y = create_output(nullptr, x->dtype());
+}
+
+void HcclAllGatherOp::infer_shape() {
+    NanoVector yshape;
+    yshape.push_back(mpi_world_size * x->shape[0]);
+    for (int i=1; i<x->shape.size(); i++)
+        yshape.push_back(x->shape[i]);
+    y->set_shape(yshape);
+}
+
+VarPtr HcclAllGatherOp::grad(Var* out, Var* dout, Var* v, int v_index) {
+    LOGf << "not implemented";
+    return nullptr;
+}
+
+void HcclAllGatherOp::jit_prepare(JK& jk) {
+    jk << "«Tx:" << x->dtype();
+}
+
+#else // JIT
+
+void HcclAllGatherOp::jit_run() {
+    LOGir << "HcclAllGatherOp::jit_run";
+    @define(T_HCCL,
+        @if(@strcmp(@Tx,float)==0 || @strcmp(@Tx,float32)==0, HcclDataType::HCCL_DATA_TYPE_FP32)
+        @if(@strcmp(@Tx,int)==0 || @strcmp(@Tx,int32)==0, HcclDataType::HCCL_DATA_TYPE_INT32)
+        @if(@strcmp(@Tx,float64)==0, HcclDataType::HCCL_DATA_TYPE_FP64)
+        @if(@strcmp(@Tx,int64)==0, HcclDataType::HCCL_DATA_TYPE_INT64)
+        @if(@strcmp(@Tx,uint8)==0, HcclDataType::HCCL_DATA_TYPE_UINT8)
+        @if(@strcmp(@Tx,float16)==0, HcclDataType::HCCL_DATA_TYPE_FP16)
+    )
+    auto* __restrict__ xp = x->ptr<Tx>();
+    auto* __restrict__ yp = y->ptr<Tx>();
+    ACLCHECK(aclrtSynchronizeDevice());
+    ACLCHECK(aclrtSynchronizeStream(aclstream));
+    HCCLCHECK(HcclAllGather(xp, yp, (uint64_t)x->num, @T_HCCL, comm, aclstream));
+    ACLCHECK(aclrtSynchronizeDevice());
+    ACLCHECK(aclrtSynchronizeStream(aclstream));
+}
+
+#endif // JIT
+
+} // jittor
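Once registered through compile_custom_ops, the op should be reachable from Python. A hedged usage sketch, assuming the ops land under jittor.compile_extern.hccl_ops (as setup_hccl's "hccl_ops = hccl.ops" suggests) and that the script runs under mpirun:

    # mpirun -np 2 python this_script.py  (assumed typical Jittor MPI launch)
    import jittor as jt
    from jittor import compile_extern

    x = jt.ones((4, 8))
    # all_gather concatenates along dim 0: output shape is (world_size * 4, 8),
    # matching HcclAllGatherOp::infer_shape above
    y = compile_extern.hccl_ops.hccl_all_gather(x)
    print(y.shape)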
@@ -0,0 +1,26 @@
+// ***************************************************************
+// Copyright (c) 2025 Jittor.
+// All Rights Reserved.
+// Maintainers:
+//     Jiapeng Zhang <zjp24@mails.tsinghua.edu.cn>.
+//
+// This file is subject to the terms and conditions defined in
+// file 'LICENSE.txt', which is part of this source code package.
+// ***************************************************************
+#pragma once
+#include "op.h"
+
+namespace jittor {
+
+struct HcclAllGatherOp : Op {
+    Var* x, * y;
+
+    HcclAllGatherOp(Var* x);
+    void infer_shape() override;
+
+    const char* name() const override { return "hccl_all_gather"; }
+    VarPtr grad(Var* out, Var* dout, Var* v, int v_index) override;
+    DECLARE_jit_run;
+};
+
+} // jittor
@@ -0,0 +1,62 @@
+#include "var.h"
+#include "hccl_all_reduce_op.h"
+#include "ops/op_register.h"
+#include "utils/str_utils.h"
+#include "hccl_wrapper.h"
+
+namespace jittor {
+
+#ifndef JIT
+
+static auto hccl_all_reduce =
+    get_op_info("hccl_all_reduce").get_constructor<VarPtr, Var*, string>();
+
+HcclAllReduceOp::HcclAllReduceOp(Var* x, string reduce_op) : x(x), reduce_op(reduce_op) {
+    flags.set(NodeFlags::_cpu, 0);
+    flags.set(NodeFlags::_cuda, 1);
+    y = create_output(nullptr, x->dtype());
+}
+
+void HcclAllReduceOp::infer_shape() {
+    y->set_shape(x->shape);
+}
+
+VarPtr HcclAllReduceOp::grad(Var* out, Var* dout, Var* v, int v_index) {
+    return hccl_all_reduce(dout, reduce_op);
+}
+
+void HcclAllReduceOp::jit_prepare(JK& jk) {
+    jk << "«Tx:" << x->dtype();
+    jk << "«Op:" << reduce_op;
+}
+
+#else // JIT
+
+void HcclAllReduceOp::jit_run() {
+    //LOGir << "HcclAllReduceOp::jit_run";
+    @define(T_HCCL,
+        @if(@strcmp(@Tx,float)==0 || @strcmp(@Tx,float32)==0, HcclDataType::HCCL_DATA_TYPE_FP32)
+        @if(@strcmp(@Tx,int)==0 || @strcmp(@Tx,int32)==0, HcclDataType::HCCL_DATA_TYPE_INT32)
+        @if(@strcmp(@Tx,float64)==0, HcclDataType::HCCL_DATA_TYPE_FP64)
+        @if(@strcmp(@Tx,int64)==0, HcclDataType::HCCL_DATA_TYPE_INT64)
+        @if(@strcmp(@Tx,uint8)==0, HcclDataType::HCCL_DATA_TYPE_UINT8)
+        @if(@strcmp(@Tx,float16)==0, HcclDataType::HCCL_DATA_TYPE_FP16)
+    )
+    @define(REDUCE_OP,
+        @if(@strcmp(@Op,sum)==0, HcclReduceOp::HCCL_REDUCE_SUM)
+        @if(@strcmp(@Op,prod)==0, HcclReduceOp::HCCL_REDUCE_PROD)
+        @if(@strcmp(@Op,max)==0, HcclReduceOp::HCCL_REDUCE_MAX)
+        @if(@strcmp(@Op,min)==0, HcclReduceOp::HCCL_REDUCE_MIN)
+    )
+    auto* __restrict__ xp = x->ptr<Tx>();
+    auto* __restrict__ yp = y->ptr<Tx>();
+    ACLCHECK(aclrtSynchronizeDevice());
+    ACLCHECK(aclrtSynchronizeStream(aclstream));
+    HCCLCHECK(HcclAllReduce(xp, yp, (uint64_t)x->num, @T_HCCL, @REDUCE_OP, comm, aclstream));
+    ACLCHECK(aclrtSynchronizeDevice());
+    ACLCHECK(aclrtSynchronizeStream(aclstream));
+}
+
+#endif // JIT
+
+} // jittor
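For reduce_op="sum", the backward of an all_reduce is itself an all_reduce of the incoming gradient, which is exactly what grad() returns above. A hedged sanity-check sketch, under the same naming assumptions as the all_gather example:

    # Run under mpirun; every rank should print the same reduced value.
    import jittor as jt
    from jittor import compile_extern

    mpi = compile_extern.mpi
    x = jt.array([float(mpi.world_rank() + 1)])
    y = compile_extern.hccl_ops.hccl_all_reduce(x, "sum")
    # with 2 ranks: 1.0 + 2.0 = 3.0 on both devices
    print(mpi.world_rank(), y.numpy())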
@@ -0,0 +1,18 @@
+#pragma once
+#include "op.h"
+
+namespace jittor {
+
+struct HcclAllReduceOp : Op {
+    Var* x, * y;
+    string reduce_op;
+
+    HcclAllReduceOp(Var* x, string reduce_op="sum");
+    void infer_shape() override;
+
+    const char* name() const override { return "hccl_all_reduce"; }
+    VarPtr grad(Var* out, Var* dout, Var* v, int v_index) override;
+    DECLARE_jit_run;
+};
+
+} // jittor
@@ -0,0 +1,63 @@
+#include "var.h"
+#include "hccl_broadcast_op.h"
+#include "ops/op_register.h"
+#include "utils/str_utils.h"
+#include "hccl_wrapper.h"
+#include <cassert>
+
+namespace jittor {
+
+#ifndef JIT
+
+static auto hccl_broadcast =
+    get_op_info("hccl_broadcast").get_constructor<VarPtr, Var*, int>();
+
+HcclBroadcastOp::HcclBroadcastOp(Var* x, int root) : x(x), root(root) {
+    flags.set(NodeFlags::_cpu, 0);
+    flags.set(NodeFlags::_cuda, 1);
+    y = create_output(nullptr, x->dtype());
+}
+
+void HcclBroadcastOp::infer_shape() {
+    y->set_shape(x->shape);
+}
+
+VarPtr HcclBroadcastOp::grad(Var* out, Var* dout, Var* v, int v_index) {
+    return hccl_broadcast(dout, root);
+}
+
+void HcclBroadcastOp::jit_prepare(JK& jk) {
+    jk << "«Tx:" << x->dtype();
+    jk << "«Root:" << root;
+}
+
+#else // JIT
+
+void HcclBroadcastOp::jit_run() {
+    //LOGir << "HcclBroadcastOp::jit_run";
+    @define(T_HCCL,
+        @if(@strcmp(@Tx,float)==0 || @strcmp(@Tx,float32)==0, HcclDataType::HCCL_DATA_TYPE_FP32)
+        @if(@strcmp(@Tx,int)==0 || @strcmp(@Tx,int32)==0, HcclDataType::HCCL_DATA_TYPE_INT32)
+        @if(@strcmp(@Tx,float64)==0, HcclDataType::HCCL_DATA_TYPE_FP64)
+        @if(@strcmp(@Tx,int64)==0, HcclDataType::HCCL_DATA_TYPE_INT64)
+        @if(@strcmp(@Tx,uint8)==0, HcclDataType::HCCL_DATA_TYPE_UINT8)
+        @if(@strcmp(@Tx,float16)==0, HcclDataType::HCCL_DATA_TYPE_FP16)
+    )
+    auto* __restrict__ xp = x->ptr<Tx>();
+    auto* __restrict__ yp = y->ptr<Tx>();
+    //LOGir << "HcclBroadcastOp::jit_run " << @Root << " " << hccl_device_id << " " << xp << " " << yp;
+    //ACLCHECK(aclrtSynchronizeStream(aclstream));
+    ACLCHECK(aclrtSynchronizeDevice());
+    ACLCHECK(aclrtSynchronizeStream(aclstream));
+    HCCLCHECK(HcclBroadcast(@Root == hccl_device_id ? xp : yp, (uint64_t)x->num, @T_HCCL, @Root, comm, aclstream));
+    if (@Root == hccl_device_id) {
+        ACLCHECK(aclrtMemcpy(yp, x->num * sizeof(Tx), xp, x->num * sizeof(Tx), ACL_MEMCPY_DEVICE_TO_DEVICE));
+        ACLCHECK(aclrtSynchronizeDevice());
+    }
+    ACLCHECK(aclrtSynchronizeDevice());
+    ACLCHECK(aclrtSynchronizeStream(aclstream));
+}
+
+#endif // JIT
+
+} // jittor
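Because the op passes x's buffer to HcclBroadcast on the root and y's buffer on the other ranks, the root copies x into y afterwards so that y is defined on every rank. A hedged usage sketch, same naming assumptions as above:

    # After the broadcast, every rank's y equals rank 0's x.
    import jittor as jt
    from jittor import compile_extern

    mpi = compile_extern.mpi
    x = jt.array([float(mpi.world_rank())])
    y = compile_extern.hccl_ops.hccl_broadcast(x, 0)
    assert float(y.numpy()[0]) == 0.0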
@@ -0,0 +1,18 @@
+#pragma once
+#include "op.h"
+
+namespace jittor {
+
+struct HcclBroadcastOp : Op {
+    Var* x, * y;
+    int root;
+
+    HcclBroadcastOp(Var* x, int root=0);
+    void infer_shape() override;
+
+    const char* name() const override { return "hccl_broadcast"; }
+    VarPtr grad(Var* out, Var* dout, Var* v, int v_index) override;
+    DECLARE_jit_run;
+};
+
+} // jittor
@@ -0,0 +1,63 @@
+#include "var.h"
+#include "hccl_reduce_op.h"
+#include "ops/op_register.h"
+#include "utils/str_utils.h"
+#include "hccl_wrapper.h"
+
+namespace jittor {
+
+#ifndef JIT
+
+HcclReduceOp::HcclReduceOp(Var* x, string reduce_op, int root) : x(x), reduce_op(reduce_op), root(root) {
+    flags.set(NodeFlags::_cpu, 0);
+    flags.set(NodeFlags::_cuda, 1);
+    y = create_output(nullptr, x->dtype());
+}
+
+void HcclReduceOp::infer_shape() {
+    y->set_shape(x->shape);
+}
+
+VarPtr HcclReduceOp::grad(Var* out, Var* dout, Var* v, int v_index) {
+    static auto hccl_broadcast =
+        get_op_info("hccl_broadcast").get_constructor<VarPtr, Var*, int>();
+    return hccl_broadcast(dout, root);
+}
+
+void HcclReduceOp::jit_prepare(JK& jk) {
+    jk << "«Tx:" << x->dtype();
+    jk << "«Op:" << reduce_op;
+    jk << "«Root:" << root;
+}
+
+#else // JIT
+
+void HcclReduceOp::jit_run() {
+    LOGir << "HcclReduceOp::jit_run";
+    @define(T_HCCL,
+        @if(@strcmp(@Tx,float)==0 || @strcmp(@Tx,float32)==0, HcclDataType::HCCL_DATA_TYPE_FP32)
+        @if(@strcmp(@Tx,int)==0 || @strcmp(@Tx,int32)==0, HcclDataType::HCCL_DATA_TYPE_INT32)
+        @if(@strcmp(@Tx,float64)==0, HcclDataType::HCCL_DATA_TYPE_FP64)
+        @if(@strcmp(@Tx,int64)==0, HcclDataType::HCCL_DATA_TYPE_INT64)
+        @if(@strcmp(@Tx,uint8)==0, HcclDataType::HCCL_DATA_TYPE_UINT8)
+        @if(@strcmp(@Tx,float16)==0, HcclDataType::HCCL_DATA_TYPE_FP16)
+    )
+    @define(REDUCE_OP,
+        @if(@strcmp(@Op,sum)==0, HcclReduceOp::HCCL_REDUCE_SUM)
+        @if(@strcmp(@Op,prod)==0, HcclReduceOp::HCCL_REDUCE_PROD)
+        @if(@strcmp(@Op,max)==0, HcclReduceOp::HCCL_REDUCE_MAX)
+        @if(@strcmp(@Op,min)==0, HcclReduceOp::HCCL_REDUCE_MIN)
+    )
+    auto* __restrict__ xp = x->ptr<Tx>();
+    auto* __restrict__ yp = y->ptr<Tx>();
+
+    ACLCHECK(aclrtSynchronizeDevice());
+    ACLCHECK(aclrtSynchronizeStream(aclstream));
+    HCCLCHECK(HcclReduce(xp, yp, (uint64_t)x->num, @T_HCCL, @REDUCE_OP, @Root, comm, aclstream));
+    ACLCHECK(aclrtSynchronizeDevice());
+    ACLCHECK(aclrtSynchronizeStream(aclstream));
+}
+
+#endif // JIT
+
+} // jittor
@@ -0,0 +1,19 @@
+#pragma once
+#include "op.h"
+
+namespace jittor {
+
+struct HcclReduceOp : Op {
+    Var* x, * y;
+    string reduce_op;
+    int root;
+
+    HcclReduceOp(Var* x, string reduce_op="sum", int root=0);
+    void infer_shape() override;
+
+    const char* name() const override { return "hccl_reduce"; }
+    VarPtr grad(Var* out, Var* dout, Var* v, int v_index) override;
+    DECLARE_jit_run;
+};
+
+} // jittor
@@ -0,0 +1,60 @@
+// ***************************************************************
+// Copyright (c) 2025 Jittor.
+// All Rights Reserved.
+// Maintainers:
+//     Jiapeng Zhang <zjp24@mails.tsinghua.edu.cn>.
+//
+// This file is subject to the terms and conditions defined in
+// file 'LICENSE.txt', which is part of this source code package.
+// ***************************************************************
+
+#include "hccl_wrapper.h"
+#include "event_queue.h"
+#include "acl_jittor.h"
+#include <acl/acl.h>
+
+namespace jittor {
+
+HcclRootInfo root_info;
+HcclComm comm;
+uint32_t hccl_device_id = 0;
+
+struct hccl_initer {
+    uint32_t device_count = 0;
+    hccl_initer() {
+        ACLCHECK(aclrtGetDeviceCount(&device_count));
+        if (!device_count) return;
+        if (!inside_mpi) return;
+        hccl_device_id = mpi_local_rank;
+        if (mpi_local_rank >= device_count) {
+            LOGw << "mpi_local_rank(">>mpi_local_rank>>") is larger than device_count("
+                >>device_count>>")";
+            hccl_device_id = hccl_device_id % device_count;
+        }
+        LOGv << "HCCL init in device" << hccl_device_id << "local_rank" << mpi_local_rank;
+        //LOGir << aclstream;
+        //event_queue.run_sync([]() {
+        ACLCHECK(aclrtSetDevice(hccl_device_id));
+        //});
+        use_device_mpi = true;
+        LOGir << "HCCL init in device" << hccl_device_id << "local_rank" << mpi_local_rank;
+        if (mpi_world_rank == 0)
+            HCCLCHECK(HcclGetRootInfo(&root_info));
+        MPI_CHECK(MPI_Bcast(&root_info, HCCL_ROOT_INFO_BYTES, MPI_CHAR, 0, MPI_COMM_WORLD));
+        //MPI_Barrier(MPI_COMM_WORLD);
+        LOGir << "Count:" << device_count << "HCCL init in device" << hccl_device_id;
+        HCCLCHECK(HcclCommInitRootInfo(device_count, &root_info, hccl_device_id, &comm));
+        ACLCHECK(aclrtCreateStream(&aclstream));
+        LOGi << "HCCL init success in device" << hccl_device_id;
+    }
+
+    ~hccl_initer() {
+        if (!device_count) return;
+        if (!inside_mpi) return;
+        if (!use_device_mpi) return;
+        HCCLCHECK(HcclCommDestroy(comm));
+    }
+};
+
+static hccl_initer hccl_initer;
+}
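The initializer pins each process to an NPU by its MPI local rank, wrapping around when ranks outnumber devices (the LOGw branch above). A minimal Python mirror of that mapping, for illustration only; the real logic is the C++ above:

    def pick_device(mpi_local_rank, device_count):
        # mirrors hccl_initer: device id = local rank, modulo device count
        if device_count == 0:
            raise RuntimeError("no NPU devices visible")
        return mpi_local_rank % device_count

    assert pick_device(0, 4) == 0
    assert pick_device(5, 4) == 1  # wraps around, as the warning branch allows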
@@ -37,14 +37,23 @@ MpiAllReduceOp::MpiAllReduceOp(Var* x, NanoString op) : x(x), op(op) {
     }
     ASSERT(op == ns_add) << "Not supported MPI op" << op;
 #ifdef HAS_CUDA
 
     if (use_device_mpi && use_cuda) {
         static auto nccl_all_reduce = has_op("nccl_all_reduce")
             ? get_op_info("nccl_all_reduce").get_constructor<VarPtr, Var*>()
             : nullptr;
+        static auto hccl_all_reduce = has_op("hccl_all_reduce")
+            ? get_op_info("hccl_all_reduce").get_constructor<VarPtr, Var*, string>()
+            : nullptr;
         if (nccl_all_reduce) {
             auto var = nccl_all_reduce(x);
             forward(var);
             return;
-        }
+        } else if (hccl_all_reduce) {
+            auto var = hccl_all_reduce(x, "sum");
+            //exe.run_sync({var}, true);
+            forward(var);
+            return;
+        }
     }
 #endif
@@ -26,10 +26,18 @@ MpiBroadcastOp::MpiBroadcastOp(Var* x, int root) : x(x), root(root) {
         static auto nccl_broadcast = has_op("nccl_broadcast")
             ? get_op_info("nccl_broadcast").get_constructor<VarPtr, Var*, int>()
             : nullptr;
+        static auto hccl_broadcast = has_op("hccl_broadcast")
+            ? get_op_info("hccl_broadcast").get_constructor<VarPtr, Var*, int>()
+            : nullptr;
         if (nccl_broadcast) {
             auto var = nccl_broadcast(x, root);
             forward(var);
             return;
-        }
+        } else if (hccl_broadcast) {
+            auto var = hccl_broadcast(x, root);
+            //exe.run_sync({var}, true);
+            forward(var);
+            return;
+        }
     }
 #endif
@@ -41,10 +41,18 @@ MpiReduceOp::MpiReduceOp(Var* x, NanoString op, int root) : x(x), op(op), root(root) {
         static auto nccl_reduce = has_op("nccl_reduce")
             ? get_op_info("nccl_reduce").get_constructor<VarPtr, Var*, int>()
             : nullptr;
+        static auto hccl_reduce = has_op("hccl_reduce")
+            ? get_op_info("hccl_reduce").get_constructor<VarPtr, Var*, string, int>()
+            : nullptr;
         if (nccl_reduce) {
             auto var = nccl_reduce(x, root);
             forward(var);
             return;
-        }
+        } else if (hccl_reduce) {
+            auto var = hccl_reduce(x, "sum", root);
+            //exe.run_sync({var}, true);
+            forward(var);
+            return;
+        }
     }
 #endif
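With these three hooks in place, the existing MPI op path routes to HCCL on Ascend builds whenever device MPI is enabled, so user code does not change. A hedged end-to-end sketch, assuming the usual Jittor MPI workflow where registered MPI ops are also bound as Var methods:

    # mpirun -np 2 python train_step.py   (typical Jittor MPI launch)
    import jittor as jt

    x = jt.random((16, 16))
    if jt.in_mpi:
        # mpi_all_reduce now forwards to hccl_all_reduce on Ascend builds,
        # per the branch added in MpiAllReduceOp above
        x = x.mpi_all_reduce()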
@@ -7,7 +7,13 @@
 // This file is subject to the terms and conditions defined in
 // file 'LICENSE.txt', which is part of this source code package.
 // ***************************************************************
+#include <cmath>
+#include <limits>
+#include <cstring>
+
+#if defined(__x86_64__) || defined(_M_X64)
 #include <immintrin.h>
+#endif
 #include <unistd.h>
 #include <stdint.h>
 #include <stdio.h>
@@ -31,8 +37,80 @@ namespace jittor {
 MPI_Datatype MPI_HALF;
 MPI_Op MPI_HALF_ADD;
 
+#if !defined(__x86_64__) && !defined(_M_X64)
+// FP16 <-> FP32 conversion helpers for ARM builds (no F16C intrinsics)
+static inline float fp16_to_fp32_value(uint16_t h) {
+    unsigned sign = ((h >> 15) & 1);
+    unsigned exponent = ((h >> 10) & 0x1f);
+    unsigned mantissa = ((h & 0x3ff) << 13);
+
+    if (exponent == 0) {
+        if (mantissa == 0) {
+            return sign ? -0.0f : 0.0f;
+        } else {
+            // subnormal: normalize the mantissa
+            while (!(mantissa & 0x400000)) {
+                mantissa <<= 1;
+                exponent -= 1;
+            }
+            exponent += 1;
+            mantissa &= ~0x400000;
+        }
+    } else if (exponent == 31) {
+        if (mantissa == 0) {
+            return sign ? -std::numeric_limits<float>::infinity() : std::numeric_limits<float>::infinity();
+        } else {
+            return std::numeric_limits<float>::quiet_NaN();
+        }
+    }
+
+    exponent += (127 - 15);
+    // (mantissa is already shifted into fp32 position by the << 13 above)
+
+    unsigned int i = ((sign << 31) | (exponent << 23) | mantissa);
+    float f;
+    std::memcpy(&f, &i, sizeof(float));
+    return f;
+}
+
+static inline uint16_t fp32_to_fp16_value(float f) {
+    unsigned int i;
+    std::memcpy(&i, &f, sizeof(float));
+
+    unsigned sign = ((i >> 31) & 0x1);
+    unsigned exponent = ((i >> 23) & 0xff);
+    unsigned mantissa = (i & 0x7fffff);
+
+    unsigned short h = 0;
+
+    if (exponent == 0) {
+        // zero or subnormal
+        h = (sign << 15);
+    } else if (exponent == 0xff) {
+        // infinity or NaN
+        h = (sign << 15) | 0x7c00;
+        if (mantissa) h |= 0x200;
+    } else {
+        // normal number
+        int new_exp = exponent - 127 + 15;
+        if (new_exp < 0) {
+            // underflow to zero
+            h = (sign << 15);
+        } else if (new_exp > 30) {
+            // overflow to infinity
+            h = (sign << 15) | 0x7c00;
+        } else {
+            // regular conversion
+            h = (sign << 15) | (new_exp << 10) | (mantissa >> 13);
+        }
+    }
+
+    return h;
+}
+#endif
+
 void HalfAdd(void* invec, void* inoutvec, int* len, MPI_Datatype* type) {
     // return;
+#if defined(__x86_64__) || defined(_M_X64)
     short* in = (short*)invec;
     short* inout = (short*)inoutvec;
 
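The two helpers implement the standard IEEE binary16/binary32 bit manipulations. A hedged round-trip check against numpy's float16 (an illustrative test, not part of the commit):

    import numpy as np

    vals = np.array([0.0, 1.0, -2.5, 65504.0, 1e-8], dtype=np.float32)
    h = vals.astype(np.float16)      # reference for fp32_to_fp16_value
    bits = h.view(np.uint16)         # the raw uint16 patterns the C helper should produce
    back = h.astype(np.float32)      # reference for fp16_to_fp32_value
    for v, b, r in zip(vals, bits, back):
        print(f"{v!r:>12} -> 0x{b:04x} -> {r!r}")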
@@ -62,9 +140,27 @@ void HalfAdd(void* invec, void* inoutvec, int* len, MPI_Datatype* type) {
         // convert the single-precision results back to half precision and store them
         *(inout + i) = _mm_cvtps_ph(out, _MM_FROUND_TO_NEAREST_INT | _MM_FROUND_NO_EXC)[0];
     }
+#else
+    // ARM implementation: plain scalar half-precision arithmetic
+    uint16_t* in = (uint16_t*)invec;
+    uint16_t* inout = (uint16_t*)inoutvec;
+    int total = *len;
+
+    // simple element-wise addition
+    for (int i = 0; i < total; i++) {
+        // widen FP16 to FP32
+        float in_val = fp16_to_fp32_value(in[i]);
+        float inout_val = fp16_to_fp32_value(inout[i]);
+
+        // add
+        float result = in_val + inout_val;
+
+        // narrow the result back to FP16
+        inout[i] = fp32_to_fp16_value(result);
+    }
+#endif
 }
 
 
 int mpi_world_size = 1;
 int mpi_world_rank = 0;
 int mpi_local_size = 1;
@@ -31,15 +31,15 @@ cudaEvent_t event;
 struct Init {
     Init() {
         if (!get_device_count()) return;
-        checkCudaErrors(cudaStreamCreateWithFlags(&stream, cudaStreamNonBlocking));
-        checkCudaErrors(cudaEventCreate(&event, cudaEventDisableTiming));
+        //checkCudaErrors(cudaStreamCreateWithFlags(&stream, cudaStreamNonBlocking));
+        //checkCudaErrors(cudaEventCreate(&event, cudaEventDisableTiming));
+        stream = aclstream;
     }
     ~Init() {
         if (!get_device_count()) return;
-        peekCudaErrors(cudaDeviceSynchronize());
-        peekCudaErrors(cudaStreamDestroy(stream));
-        peekCudaErrors(cudaEventDestroy(event));
+        //peekCudaErrors(cudaDeviceSynchronize());
+        //peekCudaErrors(cudaStreamDestroy(stream));
+        //peekCudaErrors(cudaEventDestroy(event));
     }
 } init;