Merge pull request #17 from CSCG-Lab/splits

fix conv2dbackward
This commit is contained in:
Yuxuan Han 2025-01-02 14:40:25 +08:00 committed by GitHub
commit 7ba878bf49
2 changed files with 29 additions and 0 deletions

View File

@ -81,6 +81,34 @@ namespace jittor
use_nchw = true;
}
// Builds the ACL tensor descriptors for the outputs of the conv2d backward op.
//
// Copies the shape of every entry in out_ into outputShapes, then creates one
// ACL tensor per output in outputTensors for the first three outputs:
//   - outputs 0 and 1 are created with the runner's use_nchw flag;
//   - output 2 (the bias gradient) is always created with the flag set to
//     false ("nd format").
//
// Returns early without creating the remaining tensors if fewer than three
// outputs are present, or if CreateAclTensor does not return ACL_SUCCESS.
//
// NOTE(review): the fixed indexing into outputShapes / outputTensors assumes
// both containers are empty when this is called -- confirm against the caller.
void Conv2dBackwardOpRunner::setupOutputDesc()
{
    // Outputs 0..2 are addressed by fixed index below; bail out instead of
    // reading past the end of out_ / outputShapes.
    const size_t kNumOutputs = 3;
    CHECK_RET(out_.size() >= kNumOutputs, return);

    const size_t output_num = out_.size();
    for (size_t output_idx = 0; output_idx < output_num; output_idx++)
    {
        std::vector<int64_t> shape;
        for (int j = 0; j < out_[output_idx]->shape.size(); j++)
        {
            shape.push_back(out_[output_idx]->shape[j]);
        }
        outputShapes.push_back(shape);
    }

    for (size_t idx = 0; idx < kNumOutputs; idx++)
    {
        // biasgrad (index 2) uses nd format; the other two follow use_nchw.
        const bool nchw = (idx < 2) ? use_nchw : false;
        outputTensors.push_back(nullptr);
        auto ret = CreateAclTensor(outputShapes[idx], out_[idx]->mem_ptr, out_[idx]->size, get_dtype(out_[idx]->dtype()), &outputTensors[idx], nchw);
        CHECK_RET(ret == ACL_SUCCESS, return);
    }
}
void Conv2dBackwardOpRunner::executeOp(std::unordered_map<string, AclOpFunctions>::iterator &it)
{
// for conv

View File

@ -19,6 +19,7 @@ namespace jittor
protected:
// Overrides the base-class hook that runs the operator. `it` is an iterator
// into a string-keyed map of AclOpFunctions (presumably keyed by op name --
// TODO confirm against the caller).
void executeOp(std::unordered_map<string, AclOpFunctions>::iterator &it) override;
// Overrides the base-class hook that creates the ACL tensors describing this
// op's outputs.
void setupOutputDesc() override;
public:
Conv2dBackwardOpRunner();