polish getitem inplace array

This commit is contained in:
Dun Liang 2021-01-18 15:03:40 +08:00
parent 3734af07c0
commit b9f6f048cc
3 changed files with 23 additions and 1 deletions

View File

@ -132,6 +132,22 @@ class TestSetitem(unittest.TestCase):
a = jt.zeros((3,))
a[0:] = 1.0
assert a.data[2] == 1
@unittest.skipIf(not jt.compiler.has_cuda, "No CUDA found")
@jt.flag_scope(use_cuda=1)
def test_getitem_inplace_array(self):
a = jt.array([[1,2],[3,4]])
assert (a[0].numpy() == [1,2]).all(), a[0].numpy()
assert (a[1].numpy() == [3,4]).all(), a[1].numpy()
@unittest.skipIf(not jt.compiler.has_cuda, "No CUDA found")
@jt.flag_scope(use_cuda=1)
def test_setitem_inplace_array(self):
a = jt.array([[1,2],[3,4]])
a[0,0] = -1
a[1,1] = -2
assert (a[0].numpy() == [-1,2]).all(), a[0].numpy()
assert (a[1].numpy() == [3,-2]).all(), a[1].numpy()
if __name__ == "__main__":
unittest.main()

View File

@ -102,8 +102,13 @@ struct DelayFree final : Allocator {
void migrate_to_cpu(void*& mem_ptr, size_t& allocation, size_t size, Allocator* allocator) {
auto da = cuda_dual_allocator.get_dual_allocation(allocation);
auto pre_allocation = allocation;
auto offset = (int64)mem_ptr - (int64)da.device_ptr;
mem_ptr = allocator->alloc(size, allocation);
std::memcpy(mem_ptr, da.host_ptr, size);
checkCudaErrors(cudaMemcpy(mem_ptr,
(void*)((int64)da.device_ptr+offset), size, cudaMemcpyDeviceToHost));
// std::memcpy(mem_ptr, (void*)((int64)da.host_ptr+offset), size);
free(da.device_ptr, size, pre_allocation);
}
};

View File

@ -14,6 +14,7 @@
#include "executor.h"
#include "graph.h"
#include "update_queue.h"
#include "mem/allocator/cuda_dual_allocator.h"
namespace jittor {