JittorMirror/python/jittor/src/ops/copy_op.cc

57 lines
1.5 KiB
C++

// ***************************************************************
// Copyright (c) 2023 Jittor. All Rights Reserved.
// Maintainers:
// Dun Liang <randonlang@gmail.com>.
//
// This file is subject to the terms and conditions defined in
// file 'LICENSE.txt', which is part of this source code package.
// ***************************************************************
#include "var.h"
#include "ops/op_register.h"
#include "ops/copy_op.h"
#ifdef HAS_CUDA
#include <cuda_runtime.h>
#include "helper_cuda.h"
#include "misc/cuda_flags.h"
#endif
namespace jittor {
// EXTERN_LIB aclrtStream aclstream;
CopyOp::CopyOp(Var* x) {
flags.set(NodeFlags::_cpu);
flags.set(NodeFlags::_cuda);
flags.set(NodeFlags::_manual_set_vnbb);
auto y = create_output(nullptr, x->dtype());
if (x->name.ptr)
y->name = x->name;
}
VarPtr CopyOp::grad(Var* out, Var* dout, Var* v, int v_index) {
return dout;
}
void CopyOp::infer_shape() {
outputs().front()->set_shape(inputs().front()->shape);
}
void CopyOp::run() {
auto x = inputs().front();
auto size = x->size;
auto x_ptr = x->mem_ptr;
auto y_ptr = outputs().front()->mem_ptr;
#ifdef HAS_CUDA
if (flags.get(NodeFlags::_cuda)) {
checkCudaErrors(cudaMemcpyAsync(y_ptr, x_ptr, size, cudaMemcpyDeviceToDevice, 0));
// checkCudaErrors(aclrtMemcpyAsync(y_ptr, size, x_ptr, size, cudaMemcpyDeviceToDevice, aclstream));
// checkCudaErrors(aclrtSynchronizeStream(aclstream));
} else
#endif
{
std::memcpy(y_ptr, x_ptr, size);
}
}
} // jittor