mirror of https://github.com/Jittor/Jittor
commit
7ba878bf49
|
@ -81,6 +81,34 @@ namespace jittor
|
|||
use_nchw = true;
|
||||
}
|
||||
|
||||
void Conv2dBackwardOpRunner::setupOutputDesc()
|
||||
{
|
||||
auto output_num = out_.size();
|
||||
|
||||
for (int output_idx = 0; output_idx < output_num; output_idx++)
|
||||
{
|
||||
std::vector<int64_t> shape;
|
||||
for (int j = 0; j < out_[output_idx]->shape.size(); j++)
|
||||
{
|
||||
shape.push_back(out_[output_idx]->shape[j]);
|
||||
}
|
||||
outputShapes.push_back(shape);
|
||||
}
|
||||
|
||||
for (int idx = 0; idx < 2; idx++)
|
||||
{
|
||||
outputTensors.push_back(nullptr);
|
||||
auto ret = CreateAclTensor(outputShapes[idx], out_[idx]->mem_ptr, out_[idx]->size, get_dtype(out_[idx]->dtype()), &outputTensors[idx], use_nchw);
|
||||
CHECK_RET(ret == ACL_SUCCESS, return);
|
||||
}
|
||||
// biasgrad nd format
|
||||
{
|
||||
outputTensors.push_back(nullptr);
|
||||
auto ret = CreateAclTensor(outputShapes[2], out_[2]->mem_ptr, out_[2]->size, get_dtype(out_[2]->dtype()), &outputTensors[2], false);
|
||||
CHECK_RET(ret == ACL_SUCCESS, return);
|
||||
}
|
||||
}
|
||||
|
||||
void Conv2dBackwardOpRunner::executeOp(std::unordered_map<string, AclOpFunctions>::iterator &it)
|
||||
{
|
||||
// for conv
|
||||
|
|
|
@ -19,6 +19,7 @@ namespace jittor
|
|||
|
||||
protected:
|
||||
void executeOp(std::unordered_map<string, AclOpFunctions>::iterator &it) override;
|
||||
void setupOutputDesc() override;
|
||||
|
||||
public:
|
||||
Conv2dBackwardOpRunner();
|
||||
|
|
Loading…
Reference in New Issue