Merge branch 'master' of github.com:Jittor/jittor

This commit is contained in:
Dun Liang 2022-03-30 14:36:09 +08:00
commit b8c3c82c40
11 changed files with 566 additions and 5 deletions

View File

@@ -26,7 +26,7 @@ with lock.lock_scope():
from .compile_extern import mkl_ops, mpi, mpi_ops, in_mpi, rank, world_size
if core.get_device_count() == 0:
has_cuda = compile_extern.has_cuda = compiler.has_cuda = False
from .compile_extern import cudnn, curand, cublas
from .compile_extern import cudnn, curand, cublas, cufft
from .init_cupy import numpy2cupy
import contextlib

View File

@@ -2,7 +2,7 @@ from jittor_core import *
from jittor_core.ops import *
from .misc import *
from . import attention as attention, contrib as contrib, dataset as dataset, init as init, linalg as linalg, lr_scheduler as lr_scheduler, numpy2cupy as numpy2cupy, optim as optim, sparse as sparse
from .compile_extern import cublas as cublas, cudnn as cudnn, curand as curand, mkl_ops as mkl_ops, mpi_ops as mpi_ops, world_size as world_size
from .compile_extern import cublas as cublas, cudnn as cudnn, curand as curand, cufft as cufft, mkl_ops as mkl_ops, mpi_ops as mpi_ops, world_size as world_size
from .compiler import compile_custom_op as compile_custom_op, compile_custom_ops as compile_custom_ops
from .contrib import concat as concat
from .nn import bmm as bmm, bmm_transpose as bmm_transpose, matmul as matmul

View File

@@ -219,7 +219,7 @@ def setup_cuda_extern():
line = traceback.format_exc()
LOG.w(f"CUDA found but cub is not loaded:\n{line}")
libs = ["cublas", "cudnn", "curand"]
libs = ["cublas", "cudnn", "curand", "cufft"]
# in cuda 11.4, module memory consumption:
# default context: 259 MB
# cublas: 340 MB
@@ -566,7 +566,7 @@ if os.environ.get("FIX_TORCH_ERROR", "0") == "1":
except:
pass
cudnn = cublas = curand = None
cudnn = cublas = curand = cufft = None
setup_mpi()
in_mpi = inside_mpi()
rank = mpi.world_rank() if in_mpi else 0
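As a side note (not part of the commit): since cufft falls back to None when the CUDA extern setup fails, caller code may want to check availability before using the new op. A minimal, hedged sketch follows; whether cufft_ops is defined at all in the fallback path is not shown in this hunk, so getattr is used defensively, and all variable names are illustrative only.

import jittor as jt

cufft_ops = getattr(jt.compile_extern, "cufft_ops", None)  # None (or absent) if the cuFFT extern was not set up
if jt.has_cuda and cufft_ops is not None:
    jt.flags.use_cuda = 1
    x = jt.random((1, 8, 8, 2))        # (batch, h, w, real/imag), as expected by nn._fft2
    y = cufft_ops.cufft_fft(x, False)  # forward 2D FFT via the raw op
else:
    print("cuFFT extern not available")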

View File

@@ -0,0 +1,102 @@
/*
* Copyright 2020 NVIDIA Corporation. All rights reserved.
*
* NOTICE TO LICENSEE:
*
* This source code and/or documentation ("Licensed Deliverables") are
* subject to NVIDIA intellectual property rights under U.S. and
* international Copyright laws.
*
* These Licensed Deliverables contained herein is PROPRIETARY and
* CONFIDENTIAL to NVIDIA and is being provided under the terms and
* conditions of a form of NVIDIA software license agreement by and
* between NVIDIA and Licensee ("License Agreement") or electronically
* accepted by Licensee. Notwithstanding any terms or conditions to
* the contrary in the License Agreement, reproduction or disclosure
* of the Licensed Deliverables to any third party without the express
* written consent of NVIDIA is prohibited.
*
* NOTWITHSTANDING ANY TERMS OR CONDITIONS TO THE CONTRARY IN THE
* LICENSE AGREEMENT, NVIDIA MAKES NO REPRESENTATION ABOUT THE
* SUITABILITY OF THESE LICENSED DELIVERABLES FOR ANY PURPOSE. IT IS
* PROVIDED "AS IS" WITHOUT EXPRESS OR IMPLIED WARRANTY OF ANY KIND.
* NVIDIA DISCLAIMS ALL WARRANTIES WITH REGARD TO THESE LICENSED
* DELIVERABLES, INCLUDING ALL IMPLIED WARRANTIES OF MERCHANTABILITY,
* NONINFRINGEMENT, AND FITNESS FOR A PARTICULAR PURPOSE.
* NOTWITHSTANDING ANY TERMS OR CONDITIONS TO THE CONTRARY IN THE
* LICENSE AGREEMENT, IN NO EVENT SHALL NVIDIA BE LIABLE FOR ANY
* SPECIAL, INDIRECT, INCIDENTAL, OR CONSEQUENTIAL DAMAGES, OR ANY
* DAMAGES WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS,
* WHETHER IN AN ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS
* ACTION, ARISING OUT OF OR IN CONNECTION WITH THE USE OR PERFORMANCE
* OF THESE LICENSED DELIVERABLES.
*
* U.S. Government End Users. These Licensed Deliverables are a
* "commercial item" as that term is defined at 48 C.F.R. 2.101 (OCT
* 1995), consisting of "commercial computer software" and "commercial
* computer software documentation" as such terms are used in 48
* C.F.R. 12.212 (SEPT 1995) and is provided to the U.S. Government
* only as a commercial end item. Consistent with 48 C.F.R.12.212 and
* 48 C.F.R. 227.7202-1 through 227.7202-4 (JUNE 1995), all
* U.S. Government End Users acquire the Licensed Deliverables with
* only those rights set forth herein.
*
* Any use of the Licensed Deliverables in individual and commercial
* software must include, in the user documentation and internal
* comments to the code, the above Disclaimer and U.S. Government End
* Users Notice.
*/
#pragma once
// CUDA API error checking
#ifndef CUDA_RT_CALL
#define CUDA_RT_CALL( call ) \
{ \
auto status = static_cast<cudaError_t>( call ); \
if ( status != cudaSuccess ) \
fprintf( stderr, \
"ERROR: CUDA RT call \"%s\" in line %d of file %s failed " \
"with " \
"%s (%d).\n", \
#call, \
__LINE__, \
__FILE__, \
cudaGetErrorString( status ), \
status ); \
}
#endif // CUDA_RT_CALL
// cuFFT API error checking
#ifndef CUFFT_CALL
#define CUFFT_CALL( call ) \
{ \
auto status = static_cast<cufftResult>( call ); \
if ( status != CUFFT_SUCCESS ) \
fprintf( stderr, \
"ERROR: CUFFT call \"%s\" in line %d of file %s failed " \
"with " \
"code (%d).\n", \
#call, \
__LINE__, \
__FILE__, \
status ); \
}
#endif // CUFFT_CALL
// template <> struct traits<CUFFT_C2C> {
// // scalar type
// typedef float T;
// using input_host_type = std::complex<T>;
// using input_device_type = cufftComplex;
// using output_host_type = std::complex<T>;
// using output_device_type = cufftComplex;
// static constexpr cufftType_t transformType = CUDA_R_64F;
// template <typename RNG> inline static T rand(RNG &gen) {
// return make_cuFloatComplex((S)gen(), (S)gen());
// }
// };

View File

@@ -0,0 +1,24 @@
// ***************************************************************
// Copyright (c) 2021 Jittor. All Rights Reserved.
// Maintainers:
// Guoye Yang <498731903@qq.com>.
// Dun Liang <randonlang@gmail.com>.
//
// This file is subject to the terms and conditions defined in
// file 'LICENSE.txt', which is part of this source code package.
// ***************************************************************
#pragma once
#include <cuda_runtime.h>
#include <cufftXt.h>
#include "cufft_utils.h"
#include "utils/log.h"
#include "helper_cuda.h"
#include "fp16_emu.h"
#include "common.h"
namespace jittor {
EXTERN_LIB unordered_map<string, cufftHandle> cufft_handle_cache;
} // jittor

View File

@@ -0,0 +1,101 @@
// ***************************************************************
// Copyright (c) 2021 Jittor. All Rights Reserved.
// Maintainers:
// Guoye Yang <498731903@qq.com>.
// Dun Liang <randonlang@gmail.com>.
//
// This file is subject to the terms and conditions defined in
// file 'LICENSE.txt', which is part of this source code package.
// ***************************************************************
#include "var.h"
#include "init.h"
#include <cuda_runtime.h>
#include <cufft.h>
#include "helper_cuda.h"
#include "cufft_fft_op.h"
#include "cufft_wrapper.h"
#include <complex>
#include <iostream>
#include <random>
#include <vector>
#include <cufftXt.h>
#include "cufft_utils.h"
#include "ops/op_register.h"
namespace jittor {
#ifndef JIT
static auto make_cufft_fft = get_op_info("cufft_fft")
.get_constructor<VarPtr, Var*, bool>();
CufftFftOp::CufftFftOp(Var* x, bool inverse) : x(x), inverse(inverse) {
flags.set(NodeFlags::_cuda, 1);
y = create_output(x->shape, x->dtype());
}
VarPtr CufftFftOp::grad(Var* out, Var* dout, Var* v, int v_index) {
return make_cufft_fft(dout, !inverse);
}
void CufftFftOp::jit_prepare(JK& jk) {
if ((y->dtype() != "float32") && (y->dtype() != "float64")){
printf("not supported fft dtype: %s\n", y->dtype().to_cstring());
ASSERT(false);
}
jk << _CS("[T:") << y->dtype();
jk << _CS("][I:")<<inverse<<"]";
jk << _CS("[TS:\"")<<y->dtype()<<"\"]";
}
#else // JIT
#ifdef JIT_cpu
void CufftFftOp::jit_run() {
}
#else // JIT_cuda
void CufftFftOp::jit_run() {
auto* __restrict__ xp = x->mem_ptr;
auto* __restrict__ yp = y->mem_ptr;
int batch_size = x->shape[0];
int n1 = x->shape[1], n2 = x->shape[2];
int fft_size = batch_size * n1 * n2;
std::array<int, 2> fft = {n1, n2};
auto op_type = CUFFT_C2C;
if (TS == "float32") {
op_type = CUFFT_C2C;
} else if (TS == "float64") {
op_type = CUFFT_Z2Z;
}
JK& jk = get_jk();
jk.clear();
jk << fft[0] << "," << fft[1] << "," << TS << "," << batch_size;
auto iter = cufft_handle_cache.find(jk.to_string());
cufftHandle plan;
if (iter!=cufft_handle_cache.end()) plan = iter->second;
else {
CUFFT_CALL(cufftCreate(&plan));
CUFFT_CALL(cufftPlanMany(&plan, 2, fft.data(),
nullptr, 1, fft[0] * fft[1], // *inembed, istride, idist
nullptr, 1, fft[0] * fft[1], // *onembed, ostride, odist
op_type, batch_size));
CUFFT_CALL(cufftSetStream(plan, 0));
cufft_handle_cache[jk.to_string()] = plan;
}
/*
* Note:
* Identical pointers to the data and output arrays imply an in-place transformation.
*/
if (TS == "float32") {
CUFFT_CALL(cufftExecC2C(plan, (cufftComplex *)xp, (cufftComplex *)yp, I ? CUFFT_INVERSE : CUFFT_FORWARD));
} else if (TS == "float64") {
CUFFT_CALL(cufftExecZ2Z(plan, (cufftDoubleComplex *)xp, (cufftDoubleComplex *)yp, I ? CUFFT_INVERSE : CUFFT_FORWARD));
}
}
#endif // JIT_cpu
#endif // JIT
} // jittor

View File

@@ -0,0 +1,27 @@
// ***************************************************************
// Copyright (c) 2021 Jittor. All Rights Reserved.
// Maintainers:
// Guoye Yang <498731903@qq.com>.
// Dun Liang <randonlang@gmail.com>.
//
// This file is subject to the terms and conditions defined in
// file 'LICENSE.txt', which is part of this source code package.
// ***************************************************************
#pragma once
#include "op.h"
namespace jittor {
// TODO: only 2D FFT is supported for now.
struct CufftFftOp : Op {
bool inverse;
Var* x, * y;
NanoString type;
CufftFftOp(Var* x, bool inverse=false);
VarPtr grad(Var* out, Var* dout, Var* v, int v_index) override;
const char* name() const override { return "cufft_fft"; }
DECLARE_jit_run;
};
} // jittor

View File

@@ -0,0 +1,35 @@
// ***************************************************************
// Copyright (c) 2021 Jittor. All Rights Reserved.
// Maintainers:
// Guoye Yang <498731903@qq.com>.
// Dun Liang <randonlang@gmail.com>.
//
// This file is subject to the terms and conditions defined in
// file 'LICENSE.txt', which is part of this source code package.
// ***************************************************************
#include "cufft_wrapper.h"
#include "misc/cuda_flags.h"
namespace jittor {
unordered_map<string, cufftHandle> cufft_handle_cache;
struct cufft_initer {
inline cufft_initer() {
if (!get_device_count()) return;
LOGv << "cufftCreate finished";
}
inline ~cufft_initer() {
if (!get_device_count()) return;
for (auto it = cufft_handle_cache.begin(); it != cufft_handle_cache.end(); it++) {
CUFFT_CALL(cufftDestroy(it->second));
}
cufft_handle_cache.clear();
LOGv << "cufftDestroy finished";
}
} init;
} // jittor

View File

@@ -2789,3 +2789,13 @@ class Bilinear(Module):
    def execute(self, in1, in2):
        return bilinear(in1, in2, self.weight, self.bias)

# TODO: only 2D FFT is supported for now.
def _fft2(x, inverse=False):
    assert(jt.flags.use_cuda == 1)
    assert(len(x.shape) == 4)
    assert(x.shape[3] == 2)
    y = jt.compile_extern.cufft_ops.cufft_fft(x, inverse)
    if inverse:
        y /= x.shape[1] * x.shape[2]
    return y
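A brief usage sketch (not part of the diff, grounded in the tests later in this commit): _fft2 expects a CUDA float32/float64 input of shape (batch, h, w, 2), with the real and imaginary parts packed in the last dimension. cuFFT's inverse transform is unnormalized, which is why the inverse path divides by h*w, matching torch.fft.ifft2's default scaling. Variable names below are illustrative only.

import jittor as jt
from jittor import nn

jt.flags.use_cuda = 1
x = jt.random((2, 256, 300))             # real-valued batch
x = jt.stack([x, jt.zeros_like(x)], 3)   # pack real/imag into the last dimension
y = nn._fft2(x)                          # forward 2D FFT
x_rec = nn._fft2(y, inverse=True)        # inverse FFT, scaled by 1/(h*w)
real, imag = y[:, :, :, 0], y[:, :, :, 1]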

View File

@@ -263,7 +263,7 @@ def generate_error_code_from_func_header(func_head, target_scope_name, name, dfs
help_name = ""+target_scope_name+'.'+name
else:
help_name = name
if lib_name in ["mpi", "nccl", "cudnn", "curand", "cublas", "mkl"]:
if lib_name in ["mpi", "nccl", "cudnn", "curand" "cufft", "cublas", "mkl"]:
help_name = lib_name+'.'+help_name
help_cmd = f"help(jt.{help_name})"

View File

@@ -0,0 +1,262 @@
# ***************************************************************
# Copyright (c) 2021 Jittor. All Rights Reserved.
# Maintainers:
# Guoye Yang <498731903@qq.com>
# Dun Liang <randonlang@gmail.com>.
#
# This file is subject to the terms and conditions defined in
# file 'LICENSE.txt', which is part of this source code package.
# ***************************************************************
import jittor as jt
import unittest
from .test_log import find_log_with_re
import torch # torch >= 1.9.0 needed
import numpy as np
from jittor import nn
#requires torch>=1.10.1
class TestFFTOp(unittest.TestCase):
    @unittest.skipIf(not jt.has_cuda, "Cuda not found")
    @jt.flag_scope(use_cuda=1)
    def test_fft_forward(self):
        img = np.random.rand(256, 300)
        img2 = np.random.rand(256, 300)
        X = np.stack([img, img2], 0)
        # torch
        x = torch.Tensor(X)
        y = torch.fft.fft2(x)
        y_torch_real = y.numpy().real
        y_torch_imag = y.numpy().imag
        # jittor
        x = jt.array(X, dtype=jt.float32)
        x = jt.stack([x, jt.zeros_like(x)], 3)
        y = nn._fft2(x)
        y_jt_real = y[:, :, :, 0].data
        y_jt_imag = y[:, :, :, 1].data
        assert(np.allclose(y_torch_real, y_jt_real, atol=1))
        assert(np.allclose(y_torch_imag, y_jt_imag, atol=1))

    @unittest.skipIf(not jt.has_cuda, "Cuda not found")
    @jt.flag_scope(use_cuda=1)
    def test_ifft_forward(self):
        img = np.random.rand(256, 300)
        img2 = np.random.rand(256, 300)
        X = np.stack([img, img2], 0)
        # torch
        x = torch.Tensor(X)
        y = torch.fft.fft2(x)
        y_torch_real = y.numpy().real
        y_torch_imag = y.numpy().imag
        y_ori = torch.fft.ifft2(y)
        y_ori_torch_real = y_ori.real.numpy()
        assert(np.allclose(y_ori_torch_real, X, atol=1))
        # jittor
        x = jt.array(X, dtype=jt.float32)
        x = jt.stack([x, jt.zeros_like(x)], 3)
        y = nn._fft2(x)
        y_ori = nn._fft2(y, True)
        y_jt_real = y[:, :, :, 0].data
        y_jt_imag = y[:, :, :, 1].data
        y_ori_jt_real = y_ori[:, :, :, 0].data
        assert(np.allclose(y_torch_real, y_jt_real, atol=1))
        assert(np.allclose(y_torch_imag, y_jt_imag, atol=1))
        assert(np.allclose(y_ori_jt_real, X, atol=1))
        assert(np.allclose(y_ori_jt_real, y_ori_torch_real, atol=1))

    @unittest.skipIf(not jt.has_cuda, "Cuda not found")
    @jt.flag_scope(use_cuda=1)
    def test_fft_backward(self):
        img = np.random.rand(256, 300)
        img2 = np.random.rand(256, 300)
        X = np.stack([img, img2], 0)
        T1 = np.random.rand(1, 256, 300)
        T2 = np.random.rand(1, 256, 300)
        # torch
        x = torch.Tensor(X)
        x.requires_grad = True
        t1 = torch.Tensor(T1)
        t2 = torch.Tensor(T2)
        y_mid = torch.fft.fft2(x)
        y = torch.fft.fft2(y_mid)
        real = y.real
        imag = y.imag
        loss = (real * t1).sum() + (imag * t2).sum()
        loss.backward()
        grad_x_torch = x.grad.detach().numpy()
        # jittor
        x = jt.array(X, dtype=jt.float32)
        t1 = jt.array(T1, dtype=jt.float32)
        t2 = jt.array(T2, dtype=jt.float32)
        x = jt.stack([x, jt.zeros_like(x)], 3)
        y_mid = nn._fft2(x)
        y = nn._fft2(y_mid)
        real = y[:, :, :, 0]
        imag = y[:, :, :, 1]
        loss = (real * t1).sum() + (imag * t2).sum()
        grad_x_jt = jt.grad(loss, x).data[:, :, :, 0]
        assert(np.allclose(grad_x_jt, grad_x_torch, atol=1))

    @unittest.skipIf(not jt.has_cuda, "Cuda not found")
    @jt.flag_scope(use_cuda=1)
    def test_ifft_backward(self):
        img = np.random.rand(256, 300)
        img2 = np.random.rand(256, 300)
        X = np.stack([img, img2], 0)
        T1 = np.random.rand(1, 256, 300)
        T2 = np.random.rand(1, 256, 300)
        # torch
        x = torch.Tensor(X)
        x.requires_grad = True
        t1 = torch.Tensor(T1)
        t2 = torch.Tensor(T2)
        y_mid = torch.fft.ifft2(x)
        y = torch.fft.ifft2(y_mid)
        real = y.real
        imag = y.imag
        loss = (real * t1).sum() + (imag * t2).sum()
        loss.backward()
        grad_x_torch = x.grad.detach().numpy()
        # jittor
        x = jt.array(X, dtype=jt.float32)
        t1 = jt.array(T1, dtype=jt.float32)
        t2 = jt.array(T2, dtype=jt.float32)
        x = jt.stack([x, jt.zeros_like(x)], 3)
        y_mid = nn._fft2(x, True)
        y = nn._fft2(y_mid, True)
        real = y[:, :, :, 0]
        imag = y[:, :, :, 1]
        loss = (real * t1).sum() + (imag * t2).sum()
        grad_x_jt = jt.grad(loss, x).data[:, :, :, 0]
        assert(np.allclose(grad_x_jt, grad_x_torch))

    @unittest.skipIf(not jt.has_cuda, "Cuda not found")
    @jt.flag_scope(use_cuda=1)
    def test_fft_float64_forward(self):
        img = np.random.rand(256, 300)
        img2 = np.random.rand(256, 300)
        X = np.stack([img, img2], 0)
        # torch
        x = torch.DoubleTensor(X)
        y = torch.fft.fft2(x)
        y_torch_real = y.numpy().real
        y_torch_imag = y.numpy().imag
        # jittor
        x = jt.array(X).float64()
        x = jt.stack([x, jt.zeros_like(x)], 3)
        y = nn._fft2(x)
        y_jt_real = y[:, :, :, 0].data
        y_jt_imag = y[:, :, :, 1].data
        assert(np.allclose(y_torch_real, y_jt_real, atol=1))
        assert(np.allclose(y_torch_imag, y_jt_imag, atol=1))

    @unittest.skipIf(not jt.has_cuda, "Cuda not found")
    @jt.flag_scope(use_cuda=1)
    def test_ifft_float64_forward(self):
        img = np.random.rand(256, 300)
        img2 = np.random.rand(256, 300)
        X = np.stack([img, img2], 0)
        # torch
        x = torch.DoubleTensor(X)
        y = torch.fft.fft2(x)
        y_torch_real = y.numpy().real
        y_torch_imag = y.numpy().imag
        y_ori = torch.fft.ifft2(y)
        y_ori_torch_real = y_ori.real.numpy()
        assert(np.allclose(y_ori_torch_real, X, atol=1))
        # jittor
        x = jt.array(X).float64()
        x = jt.stack([x, jt.zeros_like(x)], 3)
        y = nn._fft2(x)
        y_ori = nn._fft2(y, True)
        y_jt_real = y[:, :, :, 0].data
        y_jt_imag = y[:, :, :, 1].data
        y_ori_jt_real = y_ori[:, :, :, 0].data
        assert(np.allclose(y_torch_real, y_jt_real, atol=1))
        assert(np.allclose(y_torch_imag, y_jt_imag, atol=1))
        assert(np.allclose(y_ori_jt_real, X, atol=1))
        assert(np.allclose(y_ori_jt_real, y_ori_torch_real, atol=1))

    @unittest.skipIf(not jt.has_cuda, "Cuda not found")
    @jt.flag_scope(use_cuda=1)
    def test_fft_float64_backward(self):
        img = np.random.rand(256, 300)
        img2 = np.random.rand(256, 300)
        X = np.stack([img, img2], 0)
        T1 = np.random.rand(1, 256, 300)
        T2 = np.random.rand(1, 256, 300)
        # torch
        x = torch.DoubleTensor(X)
        x.requires_grad = True
        t1 = torch.DoubleTensor(T1)
        t2 = torch.DoubleTensor(T2)
        y_mid = torch.fft.fft2(x)
        y = torch.fft.fft2(y_mid)
        real = y.real
        imag = y.imag
        loss = (real * t1).sum() + (imag * t2).sum()
        loss.backward()
        grad_x_torch = x.grad.detach().numpy()
        # jittor
        x = jt.array(X).float64()
        t1 = jt.array(T1).float64()
        t2 = jt.array(T2).float64()
        x = jt.stack([x, jt.zeros_like(x)], 3)
        y_mid = nn._fft2(x)
        y = nn._fft2(y_mid)
        real = y[:, :, :, 0]
        imag = y[:, :, :, 1]
        loss = (real * t1).sum() + (imag * t2).sum()
        grad_x_jt = jt.grad(loss, x).data[:, :, :, 0]
        assert(np.allclose(grad_x_jt, grad_x_torch, atol=1))

    @unittest.skipIf(not jt.has_cuda, "Cuda not found")
    @jt.flag_scope(use_cuda=1)
    def test_ifft_float64_backward(self):
        img = np.random.rand(256, 300)
        img2 = np.random.rand(256, 300)
        X = np.stack([img, img2], 0)
        T1 = np.random.rand(1, 256, 300)
        T2 = np.random.rand(1, 256, 300)
        # torch
        x = torch.DoubleTensor(X)
        x.requires_grad = True
        t1 = torch.DoubleTensor(T1)
        t2 = torch.DoubleTensor(T2)
        y_mid = torch.fft.ifft2(x)
        y = torch.fft.ifft2(y_mid)
        real = y.real
        imag = y.imag
        loss = (real * t1).sum() + (imag * t2).sum()
        loss.backward()
        grad_x_torch = x.grad.detach().numpy()
        # jittor
        x = jt.array(X).float64()
        t1 = jt.array(T1).float64()
        t2 = jt.array(T2).float64()
        x = jt.stack([x, jt.zeros_like(x)], 3)
        y_mid = nn._fft2(x, True)
        y = nn._fft2(y_mid, True)
        real = y[:, :, :, 0]
        imag = y[:, :, :, 1]
        loss = (real * t1).sum() + (imag * t2).sum()
        grad_x_jt = jt.grad(loss, x).data[:, :, :, 0]
        assert(np.allclose(grad_x_jt, grad_x_torch))


if __name__ == "__main__":
    unittest.main()