This commit is contained in:
Exusial 2025-03-30 09:18:43 +08:00
commit 917d122e96
28 changed files with 582 additions and 49 deletions

View File

@ -96,7 +96,6 @@ Jittor environment requirements:
| OS | CPU | Python | Compiler | (Optional) GPU platform |
|--------------------------------------------------------|-------------------------------------|--------|--------------|---------------------------------------------|
| Linux<br>(Ubuntu, CentOS, Arch, <br>UOS, KylinOS, ...) | x86 <br>x86_64 <br>ARM <br>loongson | >= 3.7 | g++ >=5.4 | Nvidia CUDA >= 10.0, [cuDNN](https://docs.nvidia.com/deeplearning/cudnn/install-guide/index.html#installlinux-tar) <br> or [AMD ROCm](https://docs.amd.com/) >= 4.0 <br> or [Hygon DCU DTK](https://tycloud.hpccube.com/doc/1.0.6/11277/general-handbook/software-tutorial/jittor.html) >= 22.04 |
| macOS <br>(>= 10.14 Mojave) | intel<br>Apple Silicon | >= 3.7 | clang >= 8.0 | - |
| Windows 10 & 11 | x86_64 | [>= 3.8](https://www.python.org/downloads/windows/) | - | Nvidia CUDA >= 10.2, [cuDNN](https://docs.nvidia.com/deeplearning/cudnn/install-guide/index.html#install-windows) |
@ -116,25 +115,6 @@ python3.7 -m jittor.test.test_example
### macOS install
Please first install additional dependencies with [homebrew](https://brew.sh).
```bash
brew install libomp
```
Then you can install jittor through pip and run the example.
```bash
python3.7 -m pip install jittor
python3.7 -m jittor.test.test_example
```
Currently jittor only supports CPU on macOS.
### Windows install

View File

@ -9,7 +9,7 @@
# file 'LICENSE.txt', which is part of this source code package.
# ***************************************************************
__version__ = '1.3.9.10'
__version__ = '1.3.9.14'
from jittor_utils import lock
with lock.lock_scope():
ori_int = int
@ -26,7 +26,7 @@ with lock.lock_scope():
from .compile_extern import mkl_ops, mpi, mpi_ops, in_mpi, rank, world_size
if core.get_device_count() == 0:
has_cuda = compile_extern.has_cuda = compiler.has_cuda = False
from .compile_extern import cudnn, curand, cublas, cufft
from .compile_extern import cudnn, curand, cublas, cufft, cusparse
from .init_cupy import numpy2cupy
from typing import List, Tuple

View File

@ -2,7 +2,7 @@ from jittor_core import *
from jittor_core.ops import *
from .misc import *
from . import attention as attention, contrib as contrib, dataset as dataset, init as init, linalg as linalg, lr_scheduler as lr_scheduler, numpy2cupy as numpy2cupy, optim as optim, sparse as sparse, gradfunctional as gradfunctional
from .compile_extern import cublas as cublas, cudnn as cudnn, cufft as cufft, curand as curand, mkl_ops as mkl_ops, mpi_ops as mpi_ops, world_size as world_size
from .compile_extern import cublas as cublas, cudnn as cudnn, cufft as cufft, curand as curand, cusparse as cusparse, mkl_ops as mkl_ops, mpi_ops as mpi_ops, world_size as world_size
from .compiler import compile_custom_op as compile_custom_op, compile_custom_ops as compile_custom_ops
from .contrib import concat as concat
from .nn import bmm as bmm, bmm_transpose as bmm_transpose, matmul as matmul

View File

@ -77,7 +77,7 @@ def scaled_dot_product_attention(query, key, value, attn_mask=None, dropout_p=0.
attn_weight = query @ key.transpose(-2, -1) * scale_factor
attn_weight += attn_bias
attn_weight = softmax(attn_weight, dim=-1)
attn_weight = dropout(attn_weight, dropout_p, train=True)
attn_weight = dropout(attn_weight, dropout_p, is_train=True)
return attn_weight @ value
def _mha_shape_check(query: Var, key: Var, value: Var,
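The keyword change above aligns the attention dropout call with the `dropout(x, p=0.5, is_train=False)` signature that appears later in this commit. A minimal sketch of calling the fixed path; the import path and the shapes below are assumptions, only the function signature comes from this hunk.
```python
# Minimal sketch of the fixed attention-dropout path; import path and shapes
# are illustrative assumptions.
import jittor as jt
from jittor.attention import scaled_dot_product_attention

q = jt.randn(2, 4, 8, 16)   # (batch, heads, seq_len, head_dim)
k = jt.randn(2, 4, 8, 16)
v = jt.randn(2, 4, 8, 16)
# dropout_p > 0 exercises the dropout(..., is_train=True) call fixed above.
out = scaled_dot_product_attention(q, k, v, dropout_p=0.1)
print(out.shape)            # expected: [2, 4, 8, 16]
```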

View File

@ -224,7 +224,7 @@ def setup_cuda_extern():
line = traceback.format_exc()
LOG.w(f"CUDA found but cub is not loaded:\n{line}")
libs = ["cublas", "cudnn", "curand", "cufft"]
libs = ["cublas", "cudnn", "curand", "cufft", "cusparse"]
# in cuda 11.4, module memory consumptions:
# default context: 259 MB
# cublas: 340 MB
@ -240,6 +240,9 @@ def setup_cuda_extern():
msg += """Develop version of CUDNN not found,
please refer to CUDA official tar file installation:
https://docs.nvidia.com/deeplearning/cudnn/install-guide/index.html#installlinux-tar"""
if lib_name == "cusparse":
msg += """CUSPARSE library is not loaded,
please ensure it is installed along with the CUDA toolkit."""
if platform.machine() in ["x86_64", "AMD64"]:
msg += f"""
or you can let jittor install cuda and cudnn for you:
@ -300,6 +303,13 @@ def setup_cuda_lib(lib_name, link=True, extra_flags=""):
link_flags = f"-l{lib_name} -L\"{os.path.dirname(culib_path)}\""
# print("link_flags", link_flags, culib_path)
if lib_name == "cusparse" :
try:
cusparse_spmv_path = search_file([cuda_lib, extra_lib_path], "libcusparse.so")
ctypes.CDLL(cusparse_spmv_path, dlopen_flags)
except:
LOG.w("Failed to load cusparse-specific shared libraries.")
# find all source files
culib_src_dir = os.path.join(jittor_path, "extern", "cuda", lib_name)
culib_src_files = []
@ -693,7 +703,7 @@ if FIX_TORCH_ERROR:
except:
pass
cudnn = cublas = curand = cufft = None
cudnn = cublas = curand = cufft = cusparse = None
setup_mpi()
rank = mpi.world_rank() if in_mpi else 0
world_size = mpi.world_size() if in_mpi else 1
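With cusparse added to `libs`, `setup_cuda_extern` compiles and loads it like cublas/cudnn, and `setup_cuda_lib` now dlopens `libcusparse.so` explicitly. A hedged way to check from Python that the extension is usable; the `cusparse_ops` name is the one used by the test added later in this commit.
```python
# Sketch: probe whether the cusparse extension compiled and loaded.
# On CPU-only builds, or if libcusparse.so is missing, cusparse_ops may be
# None or the import itself may fail, hence the guard and try/except.
import jittor as jt

if jt.has_cuda:
    jt.flags.use_cuda = 1
    try:
        from jittor.compile_extern import cusparse_ops
        print("cusparse ops available:", cusparse_ops is not None)
    except ImportError:
        print("cusparse extension not built")
else:
    print("CUDA not available; cusparse setup is skipped")
```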

View File

@ -1371,7 +1371,7 @@ if has_cuda and is_cuda:
nvcc_flags = " " + os.environ.get("nvcc_flags", "") + " "
nvcc_flags += convert_nvcc_flags(cc_flags)
nvcc_version = list(jit_utils.get_int_version(nvcc_path))
max_arch = 89
max_arch = 90
if nvcc_version < [11,]:
max_arch = 75
elif nvcc_version < [11,1]:

View File

@ -95,8 +95,8 @@ VarPtr CudnnConv3dBackwardWOp::grad(Var* out, Var* dout, Var* v, int v_index) {
EXTERN_LIB unordered_map<string, cudnnConvolutionBwdFilterAlgo_t> bwdw_algo_cache;
template <typename T_ELEM> __inline__ cudnnDataType_t getDataType();
template <> __inline__ cudnnDataType_t getDataType<half1>() { return CUDNN_DATA_HALF; }
template <> __inline__ cudnnDataType_t getDataType<float>() { return CUDNN_DATA_FLOAT; }
// template <> __inline__ cudnnDataType_t getDataType<half1>() { return CUDNN_DATA_HALF; }
// template <> __inline__ cudnnDataType_t getDataType<float>() { return CUDNN_DATA_FLOAT; }
void CudnnConv3dBackwardWOp::jit_run() {
auto w = dw;

View File

@ -85,8 +85,8 @@ VarPtr CudnnConv3dBackwardXOp::grad(Var* out, Var* dout, Var* v, int v_index) {
EXTERN_LIB unordered_map<string, cudnnConvolutionBwdDataAlgo_t> bwdx_algo_cache;
template <typename T_ELEM> __inline__ cudnnDataType_t getDataType();
template <> __inline__ cudnnDataType_t getDataType<half1>() { return CUDNN_DATA_HALF; }
template <> __inline__ cudnnDataType_t getDataType<float>() { return CUDNN_DATA_FLOAT; }
// template <> __inline__ cudnnDataType_t getDataType<half1>() { return CUDNN_DATA_HALF; }
// template <> __inline__ cudnnDataType_t getDataType<float>() { return CUDNN_DATA_FLOAT; }
void CudnnConv3dBackwardXOp::jit_run() {
auto x = dx;

View File

@ -88,8 +88,8 @@ VarPtr CudnnConv3dOp::grad(Var* out, Var* dout, Var* v, int v_index) {
EXTERN_LIB unordered_map<string, cudnnConvolutionFwdAlgo_t> fwd_algo_cache;
template <typename T_ELEM> __inline__ cudnnDataType_t getDataType();
template <> __inline__ cudnnDataType_t getDataType<half1>() { return CUDNN_DATA_HALF; }
template <> __inline__ cudnnDataType_t getDataType<float>() { return CUDNN_DATA_FLOAT; }
// template <> __inline__ cudnnDataType_t getDataType<half1>() { return CUDNN_DATA_HALF; }
// template <> __inline__ cudnnDataType_t getDataType<float>() { return CUDNN_DATA_FLOAT; }
void CudnnConv3dOp::jit_run() {
cudnnHandle_t& handle_ = cudnn_handle;

View File

@ -88,8 +88,8 @@ void CudnnRnnBackwardXOp::jit_prepare(JK& jk) {
#ifdef JIT_cuda
template <typename T_ELEM> __inline__ cudnnDataType_t getDataType();
template <> __inline__ cudnnDataType_t getDataType<half1>() { return CUDNN_DATA_HALF; }
template <> __inline__ cudnnDataType_t getDataType<float>() { return CUDNN_DATA_FLOAT; }
// template <> __inline__ cudnnDataType_t getDataType<half1>() { return CUDNN_DATA_HALF; }
// template <> __inline__ cudnnDataType_t getDataType<float>() { return CUDNN_DATA_FLOAT; }
void CudnnRnnBackwardXOp::jit_run() {
int num_directions = 1 + bidirectional;

View File

@ -137,8 +137,8 @@ void CudnnRnnOp::grads(Var** dout, VarPtr* dins) {
#pragma clang diagnostic ignored "-Wtautological-compare"
template <typename T_ELEM> __inline__ cudnnDataType_t getDataType();
template <> __inline__ cudnnDataType_t getDataType<half1>() { return CUDNN_DATA_HALF; }
template <> __inline__ cudnnDataType_t getDataType<float>() { return CUDNN_DATA_FLOAT; }
// template <> __inline__ cudnnDataType_t getDataType<half1>() { return CUDNN_DATA_HALF; }
// template <> __inline__ cudnnDataType_t getDataType<float>() { return CUDNN_DATA_FLOAT; }
void CudnnRnnOp::jit_run() {
int num_directions = bidirectional + 1;

View File

@ -0,0 +1,42 @@
// ***************************************************************
// Copyright (c) 2023 Jittor. All Rights Reserved.
// Maintainers: Shizhan Lu <578752274@qq.com>.
// This file is subject to the terms and conditions defined in
// file 'LICENSE.txt', which is part of this source code package.
// ***************************************************************
#pragma once
#include <cuda_runtime.h>
#include <cusparse.h>
#include "utils/log.h"
#include "helper_cuda.h"
#include "fp16_emu.h"
#include "common.h"
#include "misc/nano_string.h"
namespace jittor {
EXTERN_LIB cusparseHandle_t cusparse_handle;
static inline cusparseIndexType_t get_index_dtype(NanoString dtype) {
if (dtype == ns_int32) return CUSPARSE_INDEX_32I;
if (dtype == ns_int64) return CUSPARSE_INDEX_64I;
LOGf << "not support type" << dtype;
return CUSPARSE_INDEX_32I;
}
static inline cudaDataType get_dtype(NanoString dtype) {
if (dtype == ns_float32) return CUDA_R_32F;
if (dtype == ns_float64) return CUDA_R_64F;
if (dtype == ns_float16) return CUDA_R_16F;
#ifndef IS_ROCM
if (dtype == ns_bfloat16) return CUDA_R_16BF;
#endif
LOGf << "not support type" << dtype;
return CUDA_R_32F;
}
static inline cusparseOperation_t get_trans_type(bool trans) {
if (trans) return CUSPARSE_OPERATION_TRANSPOSE;
else return CUSPARSE_OPERATION_NON_TRANSPOSE;
}
} // jittor
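These helpers translate Jittor NanoString dtypes into cuSPARSE/CUDA enums: int32/int64 for indices, float16/float32/float64 (plus bfloat16 outside ROCm) for values, and a transpose flag for the operation. From Python this determines which dtype combinations the new SpMM ops accept; a hedged probe below, limited to pairs the test file added by this commit also exercises.
```python
# Sketch: loop over index/value dtype pairs that map cleanly through
# get_index_dtype/get_dtype above. Whether a pair works end to end also
# depends on the installed cuSPARSE version, so treat this as a probe.
import jittor as jt
from jittor.compile_extern import cusparse_ops

jt.flags.use_cuda = 1
for val_dtype in ("float16", "float32"):
    for idx_dtype in ("int32", "int64"):
        x = jt.ones((3, 3), dtype=val_dtype)
        col = jt.array([0, 1, 1, 2], dtype=idx_dtype)
        row_off = jt.array([0, 2, 3, 4], dtype=idx_dtype)
        vals = jt.ones((4,), dtype=val_dtype)
        out = jt.zeros((3, 3), dtype=val_dtype)
        cusparse_ops.cusparse_spmmcsr(out, x, col, vals, row_off, 3, 3, False, False).fetch_sync()
        print(val_dtype, idx_dtype, "ok")
```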

View File

@ -0,0 +1,70 @@
// ***************************************************************
// Copyright (c) 2023 Jittor. All Rights Reserved.
// Maintainers: Shizhan Lu <578752274@qq.com>.
// This file is subject to the terms and conditions defined in
// file 'LICENSE.txt', which is part of this source code package.
// ***************************************************************
#include "var.h"
#include "cusparse_spmmcoo_op.h"
#include "cusparse_wrapper.h"
using namespace std;
namespace jittor {
#ifndef JIT
CusparseSpmmcooOp::CusparseSpmmcooOp(Var* outputVar_, Var* x_, Var* row_indices_, Var* col_indices_, Var* value_, int A_row_, int A_col_, bool trans_A_, bool trans_B_)
: outputVar(outputVar_), x(x_), row_indices(row_indices_), col_indices(col_indices_), value(value_), A_row(A_row_), A_col(A_col_), trans_A(trans_A_), trans_B(trans_B_) {
flags.set(NodeFlags::_cuda, 1);
flags.set(NodeFlags::_cpu, 0);
flags.set(NodeFlags::_manual_set_vnbb);
ASSERT(x->dtype().is_float() && outputVar->dtype().is_float()) << "x and output must be floating-point";
output = create_output(nullptr, x->dtype());
}
void CusparseSpmmcooOp::jit_prepare(JK& jk) {
add_jit_define(jk, "T", x->dtype());
add_jit_define(jk, "Tindex", col_indices->dtype());
}
#else // JIT
void CusparseSpmmcooOp::jit_run() {
cusparseSpMatDescr_t matA;
cusparseDnMatDescr_t matB, matC;
cusparseHandle_t &handle_ = cusparse_handle;
// void* dBuffer = NULL;
// size_t bufferSize = 0;
const auto& xs = x->shape;
const auto& vs = value->shape;
const auto& os = outputVar->shape;
ASSERT(xs==os)<<"matrix A and matrix C size not match";
ASSERT(A_col==xs[0])<<"matrix A and matrix B size not match";
auto dtype_A = get_dtype(value->dtype());
auto dtype_B = get_dtype(x->dtype());
auto dtype_C = get_dtype(outputVar->dtype());
auto dtype_index = get_index_dtype(col_indices->dtype());
checkCudaErrors( cusparseCreateCoo(&matA, A_row, A_col, vs[0], row_indices->ptr<Tindex>(), col_indices->ptr<Tindex>(), value->ptr<T>(), dtype_index, CUSPARSE_INDEX_BASE_ZERO, dtype_A) );
checkCudaErrors( cusparseCreateDnMat(&matB, xs[0], xs[1], xs[1], x->ptr<T>(), dtype_B, CUSPARSE_ORDER_ROW) );
checkCudaErrors( cusparseCreateDnMat(&matC, os[0], os[1],os[1], outputVar->ptr<T>(), dtype_C, CUSPARSE_ORDER_ROW) );
float alpha = 1.0f;
float beta = 0.0f;
// checkCudaErrors( cusparseSpMM_bufferSize(
// handle_,
// CUSPARSE_OPERATION_NON_TRANSPOSE,
// CUSPARSE_OPERATION_NON_TRANSPOSE,
// &alpha, matA, matB, &beta, matC, CUDA_R_32F,
// CUSPARSE_SPMM_ALG_DEFAULT , &bufferSize) );
// checkCudaErrors( cudaMalloc(&dBuffer, bufferSize) );
checkCudaErrors( cusparseSpMM(handle_,
get_trans_type(trans_A),
get_trans_type(trans_B),
&alpha, matA, matB, &beta, matC, CUDA_R_32F,
CUSPARSE_SPMM_ALG_DEFAULT, NULL) );
// checkCudaErrors( cudaFree(dBuffer) );
checkCudaErrors( cusparseDestroySpMat(matA) );
checkCudaErrors( cusparseDestroyDnMat(matB) );
checkCudaErrors( cusparseDestroyDnMat(matC) );
}
#endif // JIT
} // jittor
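The COO op multiplies a sparse matrix given as (row_indices, col_indices, value) triplets with shape (A_row, A_col) against a dense matrix `x`, writing into a preallocated output; note that the compute type passed to `cusparseSpMM` is fixed at `CUDA_R_32F`. A hedged Python-side sketch of the call with a dense NumPy reference; the argument order follows the constructor above and matches the test added later in this commit.
```python
# Sketch: sparse 3x3 A (COO) times dense 3x3 x, checked against dense A @ x.
# Call order: (output, x, row_indices, col_indices, value, A_row, A_col, trans_A, trans_B)
import numpy as np
import jittor as jt
from jittor.compile_extern import cusparse_ops

jt.flags.use_cuda = 1
x = jt.array([[3.0, 2.0, 1.0], [4.0, 2.0, 2.0], [1.0, 2.0, 3.0]], dtype="float32")
rows = jt.array([0, 0, 1, 2], dtype="int32")
cols = jt.array([1, 2, 2, 1], dtype="int32")
vals = jt.array([1.0, 1.0, 1.0, 1.0], dtype="float32")
out = jt.zeros((3, 3), dtype="float32")
cusparse_ops.cusparse_spmmcoo(out, x, rows, cols, vals, 3, 3, False, False).fetch_sync()

A = np.zeros((3, 3), dtype=np.float32)          # dense reference of the same COO matrix
A[[0, 0, 1, 2], [1, 2, 2, 1]] = 1.0
np.testing.assert_allclose(out.data, A @ x.data, atol=1e-5)
```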

View File

@ -0,0 +1,28 @@
// ***************************************************************
// Copyright (c) 2023 Jittor. All Rights Reserved.
// Maintainers: Shizhan Lu <578752274@qq.com>.
// This file is subject to the terms and conditions defined in
// file 'LICENSE.txt', which is part of this source code package.
// ***************************************************************
#pragma once
#include "op.h"
#include "cusparse.h"
namespace jittor {
struct CusparseSpmmcooOp : Op {
Var* x;
Var* outputVar;
Var* row_indices;
Var* col_indices;
Var* value;
Var* output;
int A_row;
int A_col;
bool trans_A;
bool trans_B;
CusparseSpmmcooOp(Var* outputVar_, Var* x_, Var* row_indices_, Var* col_indices_, Var* value_, int A_row, int A_col, bool trans_A, bool trans_B);
const char* name() const override { return "cusparse_spmmcoo"; }
DECLARE_jit_run;
};
} // jittor

View File

@ -0,0 +1,75 @@
// ***************************************************************
// Copyright (c) 2023 Jittor. All Rights Reserved.
// Maintainers: Shizhan Lu <578752274@qq.com>.
// This file is subject to the terms and conditions defined in
// file 'LICENSE.txt', which is part of this source code package.
// ***************************************************************
#include "var.h"
#include "cusparse_spmmcsr_op.h"
#include "cusparse_wrapper.h"
using namespace std;
namespace jittor {
#ifndef JIT
CusparseSpmmcsrOp::CusparseSpmmcsrOp(Var* outputVar_, Var* x_, Var* col_indices_, Var* value_, Var* row_offset_, int A_row_, int A_col_, bool trans_A_, bool trans_B_)
: outputVar(outputVar_), x(x_), col_indices(col_indices_), value(value_), row_offset(row_offset_), A_row(A_row_), A_col(A_col_), trans_A(trans_A_), trans_B(trans_B_) {
flags.set(NodeFlags::_cuda, 1);
flags.set(NodeFlags::_cpu, 0);
flags.set(NodeFlags::_manual_set_vnbb);
ASSERT(x->dtype().is_float() && outputVar->dtype().is_float()) << "x and output must be floating-point";
output = create_output(nullptr, x->dtype());
}
void CusparseSpmmcsrOp::jit_prepare(JK& jk) {
add_jit_define(jk, "T", x->dtype());
add_jit_define(jk, "Tindex", col_indices->dtype());
}
#else // JIT
void CusparseSpmmcsrOp::jit_run() {
cusparseSpMatDescr_t matA;
cusparseDnMatDescr_t matB, matC;
void* dBuffer = NULL;
size_t bufferSize = 0;
cusparseHandle_t &handle_ = cusparse_handle;
const auto& xs = x->shape;
const auto& vs = value->shape;
const auto& os = outputVar->shape;
ASSERT(xs==os)<<"matrix A and matrix C size not match";
ASSERT(A_col==xs[0])<<"matrix A and matrix B size not match";
auto dtype_A = get_dtype(value->dtype());
auto dtype_B = get_dtype(x->dtype());
auto dtype_C = get_dtype(outputVar->dtype());
auto dtype_index = get_index_dtype(col_indices->dtype());
checkCudaErrors( cusparseCreateCsr(&matA, A_row, A_col, vs[0], row_offset->ptr<Tindex>(), col_indices->ptr<Tindex>(), value->ptr<T>(), dtype_index, dtype_index, CUSPARSE_INDEX_BASE_ZERO, dtype_A) );
checkCudaErrors( cusparseCreateDnMat(&matB, xs[0], xs[1], xs[1], x->ptr<T>(), dtype_B, CUSPARSE_ORDER_ROW) );
checkCudaErrors( cusparseCreateDnMat(&matC, os[0], os[1],os[1], outputVar->ptr<T>(), dtype_C, CUSPARSE_ORDER_ROW) );
float alpha = 1.0f;
float beta = 0.0f;
// checkCudaErrors( cusparseSpMM_bufferSize(
// handle_,
// CUSPARSE_OPERATION_NON_TRANSPOSE,
// CUSPARSE_OPERATION_NON_TRANSPOSE,
// &alpha, matA, matB, &beta, matC, dtype_C,
// CUSPARSE_SPMM_CSR_ALG2, &bufferSize) );
checkCudaErrors( cusparseSpMM_bufferSize(
handle_,
get_trans_type(trans_A),
get_trans_type(trans_B),
&alpha, matA, matB, &beta, matC, CUDA_R_32F,
CUSPARSE_SPMM_CSR_ALG2, &bufferSize) );
checkCudaErrors( cudaMalloc(&dBuffer, bufferSize) );
// checkCudaErrors(cusparseSpMM(handle_, CUSPARSE_OPERATION_NON_TRANSPOSE, CUSPARSE_OPERATION_NON_TRANSPOSE, &alpha, matA, matB, &beta, matC, dtype_C, CUSPARSE_SPMM_CSR_ALG2 , dBuffer)); //CUSPARSE_MM_ALG_DEFAULT, CUSPARSE_SPMM_CSR_ALG2 , CUSPARSE_SPMM_COO_ALG4
checkCudaErrors(cusparseSpMM(handle_, get_trans_type(trans_A), get_trans_type(trans_B), &alpha, matA, matB, &beta, matC, CUDA_R_32F, CUSPARSE_SPMM_CSR_ALG2 , dBuffer));
checkCudaErrors( cudaFree(dBuffer) );
checkCudaErrors( cusparseDestroySpMat(matA) );
checkCudaErrors( cusparseDestroyDnMat(matB) );
checkCudaErrors( cusparseDestroyDnMat(matC) );
}
#endif // JIT
} // jittor
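The CSR variant takes (col_indices, value, row_offset) plus the (A_row, A_col) shape, sizes a workspace with `cusparseSpMM_bufferSize`, and runs `cusparseSpMM` with `CUSPARSE_SPMM_CSR_ALG2`. A hedged cross-check against SciPy's CSR product; SciPy is used only to build the reference result and is not a dependency introduced by this commit.
```python
# Sketch: validate cusparse_spmmcsr on a small matrix against scipy.sparse.
import numpy as np
import jittor as jt
from scipy.sparse import csr_matrix            # reference only
from jittor.compile_extern import cusparse_ops

jt.flags.use_cuda = 1
col = np.array([0, 1, 1, 2], dtype=np.int32)
row_off = np.array([0, 2, 3, 4], dtype=np.int32)
vals = np.array([3.0, 1.0, 4.0, 2.0], dtype=np.float32)
x = np.tile(np.array([3.0, 2.0, 1.0], dtype=np.float32), (3, 1))

out = jt.zeros((3, 3), dtype="float32")
cusparse_ops.cusparse_spmmcsr(out, jt.array(x), jt.array(col), jt.array(vals),
                              jt.array(row_off), 3, 3, False, False).fetch_sync()

ref = csr_matrix((vals, col, row_off), shape=(3, 3)).dot(x)   # expected [[12,8,4],[12,8,4],[6,4,2]]
np.testing.assert_allclose(out.data, ref, atol=1e-5)
```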

View File

@ -0,0 +1,28 @@
// ***************************************************************
// Copyright (c) 2023 Jittor. All Rights Reserved.
// Maintainers: Shizhan Lu <578752274@qq.com>.
// This file is subject to the terms and conditions defined in
// file 'LICENSE.txt', which is part of this source code package.
// ***************************************************************
#pragma once
#include "op.h"
#include "cusparse.h"
namespace jittor {
struct CusparseSpmmcsrOp : Op {
Var* x;
Var* outputVar;
Var* col_indices;
Var* row_offset;
Var* value;
Var* output;
int A_row;
int A_col;
bool trans_A;
bool trans_B;
CusparseSpmmcsrOp(Var* outputVar_, Var* x_, Var* col_indices_, Var* value_, Var* row_offset_, int A_row, int A_col, bool trans_A, bool trans_B);
const char* name() const override { return "cusparse_spmmcsr"; }
DECLARE_jit_run;
};
} // jittor

View File

@ -0,0 +1,31 @@
// ***************************************************************
// Copyright (c) 2023 Jittor. All Rights Reserved.
// Maintainers: Shizhan Lu <578752274@qq.com>.
// This file is subject to the terms and conditions defined in
// file 'LICENSE.txt', which is part of this source code package.
// ***************************************************************
#include "cusparse_wrapper.h"
#include "misc/cuda_flags.h"
namespace jittor {
cusparseHandle_t cusparse_handle;
struct cusparse_initer {
inline cusparse_initer() {
if (!get_device_count()) return;
checkCudaErrors(cusparseCreate(&cusparse_handle));
LOGv << "cusparseCreate finished" << (void*)cusparse_handle;
}
inline ~cusparse_initer() {
if (!get_device_count()) return;
LOGv << "cusparseDestroy:" << (void*)cusparse_handle;
checkCudaErrors(cusparseDestroy(cusparse_handle));
LOGv << "cusparseDestroy finished";
}
} init;
} // jittor

View File

@ -0,0 +1,57 @@
/**
* Copyright 1993-2017 NVIDIA Corporation. All rights reserved.
*
* Please refer to the NVIDIA end user license agreement (EULA) associated
* with this source code for terms and conditions that govern your use of
* this software. Any use, reproduction, disclosure, or distribution of
* this software and related documentation outside the terms of the EULA
* is strictly prohibited.
*
*/
////////////////////////////////////////////////////////////////////////////////
// These are CUDA Helper functions for initialization and error checking
#include <cuda_runtime.h>
#include <cusparse.h>
#include "helper_cuda.h"
#ifdef CUSPARSEAPI
// cuSPARSE API errors
const char *_cudaGetErrorEnum(cusparseStatus_t status) {
switch(status)
{
case CUSPARSE_STATUS_SUCCESS:
return "success";
case CUSPARSE_STATUS_NOT_INITIALIZED:
return "library not initialized";
case CUSPARSE_STATUS_ALLOC_FAILED:
return "resource allocation failed";
case CUSPARSE_STATUS_INVALID_VALUE:
return "an invalid numeric value was used as an argument";
case CUSPARSE_STATUS_ARCH_MISMATCH:
return "an absent device architectural feature is required";
case CUSPARSE_STATUS_MAPPING_ERROR:
return "an access to GPU memory space failed";
case CUSPARSE_STATUS_EXECUTION_FAILED:
return "the GPU program failed to execute";
case CUSPARSE_STATUS_INTERNAL_ERROR:
return "an internal operation failed";
case CUSPARSE_STATUS_MATRIX_TYPE_NOT_SUPPORTED:
return "the matrix type is not supported by this function";
case CUSPARSE_STATUS_ZERO_PIVOT:
return "an entry of the matrix is either structural zero or numerical zero (singular block)";
default:
return "unknown error";
}
}
#endif

View File

@ -347,6 +347,8 @@ def stack(x, dim=0):
[[4 5 6]]]
'''
assert isinstance(x, Sequence)
if isinstance(x, tuple):
x = list(x)
for i,x_ in enumerate(x):
x[i] = jt.array(x_)
if len(x) < 2:
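The `isinstance(x, tuple)` branch exists because the loop above assigns back into `x[i]`, which tuples do not allow; converting a tuple to a list first lets `stack` accept either container. A minimal sketch of the newly supported call:
```python
# Sketch: jt.stack now accepts a tuple of arrays as well as a list.
import jittor as jt

a = jt.array([1, 2, 3])
b = jt.array([4, 5, 6])
print(jt.stack((a, b), dim=0).shape)   # tuple input; expected shape [2, 3]
print(jt.stack([a, b], dim=0).shape)   # list input continues to work
```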
@ -2268,10 +2270,18 @@ bfloat16_finfo.max = 1e38
def finfo(dtype):
if dtype == "bfloat16":
return bfloat16_finfo
return np.finfo(str(dtype).split('.')[-1])
if callable(dtype) and hasattr(dtype, "__name__"):
dtype = dtype.__name__.split('.')[-1]
else:
dtype = str(dtype).split('.')[-1]
return np.finfo(dtype)
def iinfo(dtype):
return np.iinfo(str(dtype).split('.')[-1])
if callable(dtype) and hasattr(dtype, "__name__"):
dtype = dtype.__name__.split('.')[-1]
else:
dtype = str(dtype).split('.')[-1]
return np.iinfo(dtype)
def index_select(input,dim,indices):
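`finfo`/`iinfo` used to stringify their argument, which turns a Jittor dtype object such as `jt.float32` into a string NumPy cannot parse; the callable branch now extracts `__name__` instead. A minimal sketch of both accepted forms (the same pairs appear in the test file added by this commit):
```python
# Sketch: jt.finfo / jt.iinfo accept both dtype names and jt dtype objects.
import numpy as np
import jittor as jt

assert jt.finfo("float32").eps == np.finfo("float32").eps    # string path (unchanged)
assert jt.finfo(jt.float32).max == np.finfo("float32").max   # dtype-object path (new)
assert jt.iinfo(jt.int32).max == np.iinfo("int32").max
```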

View File

@ -572,6 +572,7 @@ class Dropout(Module):
noise = jt.random(input.shape)
noise = (noise > self.p).int()
output = output * noise / (1.0 - self.p) # div keep prob
output = output.to(input.dtype)
return output
def dropout(x,p=0.5,is_train=False):
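The added cast matters for reduced-precision inputs: the noise mask comes from `jt.random` (float32), so the multiply and divide used to promote a float16 input to float32; `output.to(input.dtype)` restores the input dtype. A hedged check, building the float16 input via NumPy to avoid assuming any particular cast helper:
```python
# Sketch: dropout output keeps the input dtype (e.g. float16) after this change.
import numpy as np
import jittor as jt
from jittor import nn

x = jt.array(np.random.rand(4, 4).astype("float16"))
drop = nn.Dropout(p=0.5)
drop.train()                     # assumption: train() enables the noise branch as usual
y = drop(x)
print(x.dtype, "->", y.dtype)    # expected: float16 -> float16
```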

View File

@ -13,6 +13,7 @@
# ***************************************************************
import jittor as jt
import numpy as np
from copy import deepcopy
class Optimizer(object):
""" Basic class of Optimizer.
@ -38,10 +39,29 @@ class Optimizer(object):
# so we can omit 0+x
self.__zero_grad = True
self._grad_map = {}
self.__input_params = []
def add_param_group(self, group):
self.param_groups.append(group)
def set_input_into_param_group(self, inputs):
""" This function adds inputs to the optimizer as variables that need tuning.
This is to enforce the calculation of gradients from the output to the input,
ensuring that the backward hook is called correctly.
Args:
inputs: List of the input
"""
self.__input_params = []
if isinstance(inputs, jt.Var):
self.__input_params.append(inputs)
elif isinstance(inputs, (list, tuple)):
for v in inputs:
if isinstance(v, jt.Var):
self.__input_params.append(v)
else:
raise NotImplementedError
def clip_grad_norm(self, max_norm:float, norm_type:int=2):
r"""Clips gradient norm of this optimizer.
The norm is computed over all gradients together.
@ -163,6 +183,9 @@ class Optimizer(object):
params.append(p)
if not p.is_stop_grad():
params_has_grad.append(p)
for p in self.__input_params:
if not p.is_stop_grad():
params_has_grad.append(p)
# sync prev params
jt.sync(params_has_grad)
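`set_input_into_param_group` only affects `params_has_grad`: the registered inputs are included when gradients are computed in `step`, but they are not added to `param_groups` and are not updated. A hedged usage sketch; the linear model, loss, and hook are illustrative placeholders, and `register_hook` is assumed to be the Var-level hook API the docstring refers to.
```python
# Sketch: force gradient flow back to an input Var so its backward hook fires.
import jittor as jt
from jittor import nn

model = nn.Linear(3, 1)
opt = jt.optim.SGD(model.parameters(), lr=0.1)

x = jt.randn(4, 3)
x.register_hook(lambda grad: print("input grad shape:", grad.shape))
opt.set_input_into_param_group(x)      # a list/tuple of Vars is also accepted

loss = model(x).sqr().mean()
opt.step(loss)                         # backward now reaches x and invokes the hook
```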

View File

@ -269,7 +269,7 @@ def generate_error_code_from_func_header(func_head, target_scope_name, name, dfs
help_name = ""+target_scope_name+'.'+name
else:
help_name = name
if lib_name in ["mpi", "nccl", "cudnn", "curand" "cufft", "cublas", "mkl"]:
if lib_name in ["mpi", "nccl", "cudnn", "curand" "cufft", "cublas", "mkl", "cusparse"]:
help_name = lib_name+'.'+help_name
help_cmd = f"help(jt.{help_name})"

View File

@ -15,8 +15,8 @@
namespace jittor {
DEFINE_FLAG(int, use_sfrl_allocator, 1, "Enable sfrl allocator");
DEFINE_FLAG(int64, sfrl_large_block_size_device, 20971520, "sfrl_large_block_size, larger will reduce memory shard, only affect device");
constexpr int64 sfrl_large_block_size_cpu=20971520;
DEFINE_FLAG(int64, sfrl_large_block_size_device, 5242880, "sfrl_large_block_size, larger will reduce memory shard, only affect device");
constexpr int64 sfrl_large_block_size_cpu=5242880;
std::vector<size_t> CachingBlockPool::block_ids;
//start from 1

View File

@ -6283,7 +6283,7 @@ mz_bool mz_zip_writer_add_mem_ex_v2(mz_zip_archive *pZip, const char *pArchive_n
if (!(level_and_flags & MZ_ZIP_FLAG_COMPRESSED_DATA))
{
// uncomp_crc32 = (mz_uint32)mz_crc32(MZ_CRC32_INIT, (const mz_uint8 *)pBuf, buf_size);
uncomp_crc32 = (mz_uint32)mz_crc32(MZ_CRC32_INIT, (const mz_uint8 *)pBuf, buf_size);
uncomp_size = buf_size;
if (uncomp_size <= 3)
{

View File

@ -1333,16 +1333,16 @@ struct ZipFile {
memset(zip_archive.get(), 0, sizeof(mz_zip_archive));
if (mode == "r") {
this->mode = 'r';
if (!mz_zip_reader_init_file(zip_archive.get(), filename.c_str(), 0))
zip_archive = nullptr;
if (!mz_zip_reader_init_file(zip_archive.get(), filename.c_str(), 0))
zip_archive = nullptr;
} else if (mode == "w") {
this->mode = 'w';
if (!mz_zip_writer_init_file_v2(zip_archive.get(), filename.c_str(), 0, MZ_ZIP_FLAG_WRITE_ZIP64)) {
zip_archive = nullptr;
}
}
if (!zip_archive)
throw std::runtime_error("Failed to open zip file: " + filename);
// if (!zip_archive)
// throw std::runtime_error("Failed to open zip file: " + filename);
}
// @pyjt(__dealloc__)
inline ~ZipFile() {

View File

@ -0,0 +1,112 @@
# ***************************************************************
# Copyright (c) 2023 Jittor. All Rights Reserved.
# Maintainers: Shizhan Lu <578752274@qq.com>.
# This file is subject to the terms and conditions defined in
# file 'LICENSE.txt', which is part of this source code package.
# ***************************************************************
import unittest
import jittor as jt
import numpy as np
from jittor.compile_extern import cusparse_ops
class TestSpmmCsrOp(unittest.TestCase):
@unittest.skipIf(not jt.has_cuda, "No CUDA support, skipping test")
@jt.flag_scope(use_cuda=1, lazy_execution=0)
def test_spmm_csr_forward_float32_int32(self):
x = jt.array([[3.0, 2.0, 1.0], [3.0, 2.0, 1.0], [3.0, 2.0, 1.0]], dtype="float32")
col_indices = jt.array([0, 1, 1, 2], dtype="int32")
row_offset = jt.array([0, 2, 3, 4], dtype="int32")
csr_weight = jt.array([3.0, 1.0, 4.0, 2.0], dtype="float32")
output = jt.zeros((3, 3), dtype="float32")
cusparse_ops.cusparse_spmmcsr(
output, x, col_indices, csr_weight, row_offset,
3, 3, False, False
).fetch_sync()
expected_output = np.array([
[12.0, 8.0, 4.0],
[12.0, 8.0, 4.0],
[6.0, 4.0, 2.0]
])
np.testing.assert_allclose(output.data, expected_output, atol=1e-5)
@unittest.skipIf(not jt.has_cuda, "No CUDA support, skipping test")
@jt.flag_scope(use_cuda=1, lazy_execution=0)
def test_spmm_csr_forward_float16_int32(self):
x = jt.array([[3.0, 2.0, 1.0], [3.0, 2.0, 1.0], [3.0, 2.0, 1.0]], dtype="float16")
col_indices = jt.array([0, 1, 1, 2], dtype="int32")
row_offset = jt.array([0, 2, 3, 4], dtype="int32")
csr_weight = jt.array([3.0, 1.0, 4.0, 2.0], dtype="float16")
output = jt.zeros((3, 3), dtype="float16")
cusparse_ops.cusparse_spmmcsr(
output, x, col_indices, csr_weight, row_offset,
3, 3, False, False
).fetch_sync()
expected_output = np.array([
[12.0, 8.0, 4.0],
[12.0, 8.0, 4.0],
[6.0, 4.0, 2.0]
], dtype="float16")
np.testing.assert_allclose(output.data, expected_output, atol=1e-5)
# @unittest.skipIf(not jt.has_cuda, "No CUDA support, skipping test")
# @jt.flag_scope(use_cuda=1, lazy_execution=0)
# def test_spmm_csr_forward_float64_int32(self):
# x = jt.array([[3.0, 2.0, 1.0], [3.0, 2.0, 1.0], [3.0, 2.0, 1.0]], dtype="float64")
# col_indices = jt.array([0, 1, 1, 2], dtype="int32")
# row_offset = jt.array([0, 2, 3, 4], dtype="int32")
# csr_weight = jt.array([3.0, 1.0, 4.0, 2.0], dtype="float64")
# output = jt.zeros((3, 3), dtype="float64")
# cusparse_ops.cusparse_spmmcsr(
# output, x, col_indices, csr_weight, row_offset,
# 3, 3,False, False
# ).fetch_sync()
# expected_output = np.array([
# [12.0, 8.0, 4.0],
# [12.0, 8.0, 4.0],
# [6.0, 4.0, 2.0]
# ], dtype="float64")
# np.testing.assert_allclose(output.data, expected_output, atol=1e-5)
@unittest.skipIf(not jt.has_cuda, "No CUDA support, skipping test")
@jt.flag_scope(use_cuda=1, lazy_execution=0)
def test_spmm_csr_forward_float32_int64(self):
x = jt.array([[3.0, 2.0, 1.0], [3.0, 2.0, 1.0], [3.0, 2.0, 1.0]], dtype="float32")
col_indices = jt.array([0, 1, 1, 2], dtype="int64")
row_offset = jt.array([0, 2, 3, 4], dtype="int64")
csr_weight = jt.array([3.0, 1.0, 4.0, 2.0], dtype="float32")
output = jt.zeros((3, 3), dtype="float32")
cusparse_ops.cusparse_spmmcsr(
output, x, col_indices, csr_weight, row_offset,
3, 3, False, False
).fetch_sync()
expected_output = np.array([
[12.0, 8.0, 4.0],
[12.0, 8.0, 4.0],
[6.0, 4.0, 2.0]
], dtype="float32")
np.testing.assert_allclose(output.data, expected_output, atol=1e-5)
@unittest.skipIf(not jt.has_cuda, "No CUDA support, skipping test")
@jt.flag_scope(use_cuda=1, lazy_execution=0)
def test_spmm_coo(self):
x = jt.array([[3.0, 2.0, 1.0], [4.0, 2.0, 2.0], [1.0, 2.0, 3.0]], dtype="float32")
edge_index = jt.array([[0, 0, 1, 2], [1, 2, 2, 1]], dtype="int32")
row_indices = edge_index[0, :]
col_indices = edge_index[1, :]
edge_weight = jt.array([1.0, 1.0, 1.0, 1.0], dtype="float32")
feature_dim = jt.size(x, 1)
output = jt.zeros((3, feature_dim), dtype="float32")
cusparse_ops.cusparse_spmmcoo(output, x, row_indices, col_indices, edge_weight, 3, 3, False, False).fetch_sync()
print("Output:", output)
expected_output = np.array([
[5.0, 4.0, 5.0],
[1.0, 2.0, 3.0],
[4.0, 2.0, 2.0]
], dtype="float32")
np.testing.assert_allclose(output.data, expected_output, atol=1e-5)
if __name__ == "__main__":
unittest.main()

View File

@ -0,0 +1,63 @@
import unittest
import jittor as jt
import numpy as np
class TestFinfo(unittest.TestCase):
def test(self):
for dtype in ['float16', 'float32', 'float64']:
finfo = jt.finfo(dtype)
np_finfo = np.finfo(dtype)
assert finfo.bits == np_finfo.bits
assert finfo.eps == np_finfo.eps
assert finfo.max == np_finfo.max
assert finfo.min == np_finfo.min
assert finfo.nexp == np_finfo.nexp
assert finfo.nmant == np_finfo.nmant
assert finfo.iexp == np_finfo.iexp
assert finfo.precision == np_finfo.precision
assert finfo.resolution == np_finfo.resolution
assert finfo.tiny == np_finfo.tiny
for dtype_jt, dtype in [
(jt.float16, 'float16'),
(jt.float32, 'float32'),
(jt.float64, 'float64'),
]:
finfo = jt.finfo(dtype_jt)
np_finfo = np.finfo(dtype)
assert finfo.bits == np_finfo.bits
assert finfo.eps == np_finfo.eps
assert finfo.max == np_finfo.max
assert finfo.min == np_finfo.min
assert finfo.nexp == np_finfo.nexp
assert finfo.nmant == np_finfo.nmant
assert finfo.iexp == np_finfo.iexp
assert finfo.precision == np_finfo.precision
assert finfo.resolution == np_finfo.resolution
assert finfo.tiny == np_finfo.tiny
class TestIinfo(unittest.TestCase):
def test(self):
for dtype in ['int16', 'int32', 'int64']:
iinfo = jt.iinfo(dtype)
np_iinfo = np.iinfo(dtype)
assert iinfo.bits == np_iinfo.bits
assert iinfo.max == np_iinfo.max
assert iinfo.min == np_iinfo.min
assert iinfo.dtype == np.dtype(dtype)
for dtype_jt, dtype in [
(jt.int16, 'int16'),
(jt.int32, 'int32'),
(jt.int64, 'int64'),
]:
iinfo = jt.iinfo(dtype_jt)
np_iinfo = np.iinfo(dtype)
assert iinfo.bits == np_iinfo.bits
assert iinfo.max == np_iinfo.max
assert iinfo.min == np_iinfo.min
assert iinfo.dtype == np.dtype(dtype)
if __name__ == "__main__":
unittest.main()

View File

@ -78,8 +78,11 @@ def check_cuda_env():
with open("/proc/self/cmdline", "r") as f:
argv = f.read().split("\x00")
if len(argv[-1]) == 0: del argv[-1]
LOG.i(f"restart {sys.executable} {argv[1:]}")
os.execl(sys.executable, sys.executable, *argv[1:])
if 'ipykernel_launcher' in argv:
LOG.i(f"needed restart but not {sys.executable} {argv[1:]}, you can ignore this warning.")
else:
LOG.i(f"restart {sys.executable} {argv[1:]}")
os.execl(sys.executable, sys.executable, *argv[1:])
except:
pass