mirror of https://github.com/Jittor/Jittor
polish getitem inplace array
This commit is contained in:
parent
3734af07c0
commit
b9f6f048cc
|
@ -132,6 +132,22 @@ class TestSetitem(unittest.TestCase):
|
|||
a = jt.zeros((3,))
|
||||
a[0:] = 1.0
|
||||
assert a.data[2] == 1
|
||||
|
||||
@unittest.skipIf(not jt.compiler.has_cuda, "No CUDA found")
|
||||
@jt.flag_scope(use_cuda=1)
|
||||
def test_getitem_inplace_array(self):
|
||||
a = jt.array([[1,2],[3,4]])
|
||||
assert (a[0].numpy() == [1,2]).all(), a[0].numpy()
|
||||
assert (a[1].numpy() == [3,4]).all(), a[1].numpy()
|
||||
|
||||
@unittest.skipIf(not jt.compiler.has_cuda, "No CUDA found")
|
||||
@jt.flag_scope(use_cuda=1)
|
||||
def test_setitem_inplace_array(self):
|
||||
a = jt.array([[1,2],[3,4]])
|
||||
a[0,0] = -1
|
||||
a[1,1] = -2
|
||||
assert (a[0].numpy() == [-1,2]).all(), a[0].numpy()
|
||||
assert (a[1].numpy() == [3,-2]).all(), a[1].numpy()
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
|
@ -102,8 +102,13 @@ struct DelayFree final : Allocator {
|
|||
void migrate_to_cpu(void*& mem_ptr, size_t& allocation, size_t size, Allocator* allocator) {
|
||||
auto da = cuda_dual_allocator.get_dual_allocation(allocation);
|
||||
auto pre_allocation = allocation;
|
||||
auto offset = (int64)mem_ptr - (int64)da.device_ptr;
|
||||
|
||||
mem_ptr = allocator->alloc(size, allocation);
|
||||
std::memcpy(mem_ptr, da.host_ptr, size);
|
||||
|
||||
checkCudaErrors(cudaMemcpy(mem_ptr,
|
||||
(void*)((int64)da.device_ptr+offset), size, cudaMemcpyDeviceToHost));
|
||||
// std::memcpy(mem_ptr, (void*)((int64)da.host_ptr+offset), size);
|
||||
free(da.device_ptr, size, pre_allocation);
|
||||
}
|
||||
};
|
||||
|
|
|
@ -14,6 +14,7 @@
|
|||
#include "executor.h"
|
||||
#include "graph.h"
|
||||
#include "update_queue.h"
|
||||
#include "mem/allocator/cuda_dual_allocator.h"
|
||||
|
||||
namespace jittor {
|
||||
|
||||
|
|
Loading…
Reference in New Issue