mirror of https://github.com/Jittor/Jittor
fix cupy mpirun bug
This commit is contained in:
parent
56cc2f50f4
commit
077f5dd8ef
|
@ -18,10 +18,13 @@ if has_cupy:
|
|||
import jittor as jt
|
||||
import os
|
||||
import ctypes
|
||||
print("jt.mpi.local_rank()",jt.mpi.local_rank())
|
||||
cupy_device = cp.cuda.Device(jt.mpi.local_rank())
|
||||
cupy_device.__enter__()
|
||||
|
||||
def cvt(a):
|
||||
a_pointer, read_only_flag = a.__array_interface__['data']
|
||||
aptr=cp.cuda.MemoryPointer(cp.cuda.memory.UnownedMemory(a_pointer,a.size*a.itemsize,a,0),0)
|
||||
aptr=cp.cuda.MemoryPointer(cp.cuda.memory.UnownedMemory(a_pointer,a.size*a.itemsize,a, jt.mpi.local_rank()),0)
|
||||
a = cp.ndarray(a.shape,a.dtype,aptr)
|
||||
return a
|
||||
|
||||
|
|
Loading…
Reference in New Issue