fix cupy mpirun bug

This commit is contained in:
Gword 2021-01-17 13:31:39 +08:00
parent 56cc2f50f4
commit 077f5dd8ef
1 changed files with 4 additions and 1 deletions

View File

@ -18,10 +18,13 @@ if has_cupy:
import jittor as jt
import os
import ctypes
print("jt.mpi.local_rank()",jt.mpi.local_rank())
cupy_device = cp.cuda.Device(jt.mpi.local_rank())
cupy_device.__enter__()
def cvt(a):
a_pointer, read_only_flag = a.__array_interface__['data']
aptr=cp.cuda.MemoryPointer(cp.cuda.memory.UnownedMemory(a_pointer,a.size*a.itemsize,a,0),0)
aptr=cp.cuda.MemoryPointer(cp.cuda.memory.UnownedMemory(a_pointer,a.size*a.itemsize,a, jt.mpi.local_rank()),0)
a = cp.ndarray(a.shape,a.dtype,aptr)
return a