mirror of https://github.com/Jittor/Jittor
Merge remote-tracking branch 'upstream/master'
This commit is contained in:
commit
746794064a
|
@ -9,7 +9,7 @@
|
|||
# file 'LICENSE.txt', which is part of this source code package.
|
||||
# ***************************************************************
|
||||
|
||||
__version__ = '1.2.3.22'
|
||||
__version__ = '1.2.3.34'
|
||||
from jittor_utils import lock
|
||||
with lock.lock_scope():
|
||||
ori_int = int
|
||||
|
@ -437,11 +437,11 @@ def pow(x, y):
|
|||
Var.pow = Var.__pow__ = pow
|
||||
|
||||
def argmax(x, dim, keepdims:bool=False):
|
||||
return x.arg_reduce("max", dim, keepdims)
|
||||
return jt.arg_reduce(x, "max", dim, keepdims)
|
||||
Var.argmax = argmax
|
||||
|
||||
def argmin(x, dim, keepdims:bool=False):
|
||||
return x.arg_reduce("min", dim, keepdims)
|
||||
return jt.arg_reduce(x, "min", dim, keepdims)
|
||||
Var.argmin = argmin
|
||||
|
||||
def randn(*size, dtype="float32", requires_grad=True) -> Var:
|
||||
|
|
|
@ -23,29 +23,35 @@ def search_file(dirs, name, prefer_version=()):
|
|||
def install_mkl(root_folder):
|
||||
# origin url is
|
||||
# url = "https://github.com/intel/mkl-dnn/releases/download/v1.0.2/mkldnn_lnx_1.0.2_cpu_gomp.tgz"
|
||||
url = "https://cloud.tsinghua.edu.cn/f/da02bf62b55b4aa3b8ee/?dl=1"
|
||||
filename = "mkldnn_lnx_1.0.2_cpu_gomp.tgz"
|
||||
# newest version for oneDNN
|
||||
# url = "https://github.com/oneapi-src/oneDNN/releases/download/v2.2/dnnl_lnx_2.2.0_cpu_gomp.tgz"
|
||||
# filename = "dnnl_lnx_2.2.0_cpu_gomp.tgz"
|
||||
import platform
|
||||
if platform.system()=="Linux":
|
||||
if platform.machine()=='x86_64':
|
||||
filename = "dnnl_lnx_2.2.0_cpu_gomp.tgz"
|
||||
md5 = "35bbbdf550a9d8ad54db798e372000f6"
|
||||
elif platform.machine()=='aarch64':
|
||||
filename = "dnnl_lnx_2.2.0_cpu_gomp_aarch64.tgz"
|
||||
md5 = "72cf9b0b8fd6c3c786d35a9daaee22b8"
|
||||
else:
|
||||
raise RuntimeError(f"platform.machine()=={platform.machine()} not support yet,"
|
||||
" Please contact us on https://github.com/jittor/jittor ")
|
||||
else:
|
||||
raise RuntimeError(f"platform.machine()=={platform.machine()} not support yet,"
|
||||
" Please contact us on https://github.com/jittor/jittor ")
|
||||
|
||||
url = "https://cg.cs.tsinghua.edu.cn/jittor/assets/" + filename
|
||||
fullname = os.path.join(root_folder, filename)
|
||||
dirname = os.path.join(root_folder, filename.replace(".tgz",""))
|
||||
|
||||
if not os.path.isfile(os.path.join(dirname, "examples", "test")):
|
||||
if not os.path.isfile(os.path.join(dirname, "lib", "libmkldnn.so")):
|
||||
LOG.i("Downloading mkl...")
|
||||
download_url_to_local(url, filename, root_folder, "47187284ede27ad3bd64b5f0e7d5e730")
|
||||
# newest version for oneDNN
|
||||
# download_url_to_local(url, filename, root_folder, "35bbbdf550a9d8ad54db798e372000f6")
|
||||
download_url_to_local(url, filename, root_folder, md5)
|
||||
import tarfile
|
||||
|
||||
with tarfile.open(fullname, "r") as tar:
|
||||
tar.extractall(root_folder)
|
||||
|
||||
assert 0 == os.system(f"cd {dirname}/examples && "
|
||||
f"{cc_path} -std=c++14 cpu_cnn_inference_f32.cpp -Ofast -lmkldnn -I ../include -L ../lib -o test && LD_LIBRARY_PATH=../lib/ ./test")
|
||||
# newest version for oneDNN
|
||||
# assert 0 == os.system(f"cd {dirname}/examples && "
|
||||
# f"{cc_path} -std=c++14 cnn_inference_f32.cpp -Ofast -lmkldnn -I ../include -L ../lib -o test && LD_LIBRARY_PATH=../lib/ ./test")
|
||||
f"{cc_path} -std=c++14 cnn_inference_f32.cpp -Ofast -lmkldnn -I ../include -L ../lib -o test && LD_LIBRARY_PATH=../lib/ ./test")
|
||||
|
||||
def setup_mkl():
|
||||
global mkl_ops, use_mkl
|
||||
|
@ -80,7 +86,7 @@ def setup_mkl():
|
|||
install_mkl(mkl_path)
|
||||
mkl_home = ""
|
||||
for name in os.listdir(mkl_path):
|
||||
if name.startswith("mkldnn_lnx") and os.path.isdir(os.path.join(mkl_path, name)):
|
||||
if name.startswith("dnnl") and os.path.isdir(os.path.join(mkl_path, name)):
|
||||
mkl_home = os.path.join(mkl_path, name)
|
||||
break
|
||||
assert mkl_home!=""
|
||||
|
@ -197,8 +203,14 @@ def setup_cuda_lib(lib_name, link=True, extra_flags=""):
|
|||
|
||||
if lib_name == "cublas" and nvcc_version[0] >= 10:
|
||||
# manual link libcublasLt.so
|
||||
cublas_lt_lib_path = search_file([cuda_lib, extra_lib_path, "/usr/lib/x86_64-linux-gnu", "/usr/lib"], f"libcublasLt.so", nvcc_version)
|
||||
ctypes.CDLL(cublas_lt_lib_path, dlopen_flags)
|
||||
try:
|
||||
cublas_lt_lib_path = search_file([cuda_lib, extra_lib_path, "/usr/lib/x86_64-linux-gnu", "/usr/lib"], f"libcublasLt.so", nvcc_version)
|
||||
ctypes.CDLL(cublas_lt_lib_path, dlopen_flags)
|
||||
except:
|
||||
# some aarch64 os, such as uos with FT2000 cpu,
|
||||
# it's cuda 10 doesn't have libcublasLt.so
|
||||
pass
|
||||
|
||||
|
||||
|
||||
if lib_name == "cudnn":
|
||||
|
|
|
@ -12,6 +12,7 @@ import inspect
|
|||
import datetime
|
||||
import threading
|
||||
import ctypes
|
||||
import platform
|
||||
from ctypes import cdll
|
||||
from ctypes.util import find_library
|
||||
|
||||
|
@ -634,7 +635,7 @@ def compile_custom_ops(
|
|||
if gen_name_ != "":
|
||||
gen_name = gen_name_
|
||||
if len(gen_name) > 100:
|
||||
gen_name = gen_name[:80] + "___hash" + str(hash(gen_name))
|
||||
gen_name = gen_name[:80] + "___hash" + str(abs(hash(gen_name)))
|
||||
|
||||
includes = sorted(list(set(includes)))
|
||||
includes = "".join(map(lambda x: f" -I'{x}' ", includes))
|
||||
|
@ -1038,6 +1039,8 @@ if os.path.isfile(version_file) and not os.path.isdir(os.path.join(jittor_path,
|
|||
os_key = os_type.get(os_id, "ubuntu")
|
||||
if "os_key" in os.environ:
|
||||
os_key = os.environ['os_key']
|
||||
if platform.machine()=='aarch64':
|
||||
os_key += '-aarch64'
|
||||
LOG.i("OS type:", os_id, " OS key:", os_key)
|
||||
key += '-' + os_key + '.o'
|
||||
# TODO: open the website
|
||||
|
@ -1049,7 +1052,7 @@ if os.path.isfile(version_file) and not os.path.isdir(os.path.join(jittor_path,
|
|||
compile(cc_path, cc_flags+opt_flags, files, 'jittor_core'+extension_suffix)
|
||||
|
||||
# TODO: move to compile_extern.py
|
||||
compile_extern()
|
||||
# compile_extern()
|
||||
|
||||
with jit_utils.import_scope(import_flags):
|
||||
import jittor_core as core
|
||||
|
|
|
@ -1,5 +1,6 @@
|
|||
|
||||
from .dataset import Dataset, ImageFolder
|
||||
from .dataset import Dataset, ImageFolder, dataset_root
|
||||
from .mnist import MNIST
|
||||
from .cifar import CIFAR10, CIFAR100
|
||||
from .voc import VOC
|
||||
from .sampler import *
|
|
@ -0,0 +1,189 @@
|
|||
|
||||
import os
|
||||
from jittor_utils.misc import download_and_extract_archive, check_integrity
|
||||
from PIL import Image
|
||||
import sys, pickle
|
||||
import numpy as np
|
||||
from jittor.dataset import Dataset, dataset_root
|
||||
|
||||
class CIFAR10(Dataset):
|
||||
"""`CIFAR10 <https://www.cs.toronto.edu/~kriz/cifar.html>`_ Dataset.
|
||||
|
||||
Args:
|
||||
root (string): Root directory of dataset where directory
|
||||
``cifar-10-batches-py`` exists or will be saved to if download is set to True.
|
||||
train (bool, optional): If True, creates dataset from training set, otherwise
|
||||
creates from test set.
|
||||
transform (callable, optional): A function/transform that takes in an PIL image
|
||||
and returns a transformed version. E.g, ``transforms.RandomCrop``
|
||||
target_transform (callable, optional): A function/transform that takes in the
|
||||
target and transforms it.
|
||||
download (bool, optional): If true, downloads the dataset from the internet and
|
||||
puts it in root directory. If dataset is already downloaded, it is not
|
||||
downloaded again.
|
||||
|
||||
Example::
|
||||
|
||||
|
||||
from jittor.dataset.cifar import CIFAR10
|
||||
a = CIFAR10()
|
||||
a.set_attrs(batch_size=16)
|
||||
for imgs, labels in a:
|
||||
print(imgs.shape, labels.shape)
|
||||
break
|
||||
|
||||
"""
|
||||
base_folder = 'cifar-10-batches-py'
|
||||
url = "https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz"
|
||||
filename = "cifar-10-python.tar.gz"
|
||||
tgz_md5 = 'c58f30108f718f92721af3b95e74349a'
|
||||
train_list = [
|
||||
['data_batch_1', 'c99cafc152244af753f735de768cd75f'],
|
||||
['data_batch_2', 'd4bba439e000b95fd0a9bffe97cbabec'],
|
||||
['data_batch_3', '54ebc095f3ab1f0389bbae665268c751'],
|
||||
['data_batch_4', '634d18415352ddfa80567beed471001a'],
|
||||
['data_batch_5', '482c414d41f54cd18b22e5b47cb7c3cb'],
|
||||
]
|
||||
|
||||
test_list = [
|
||||
['test_batch', '40351d587109b95175f43aff81a1287e'],
|
||||
]
|
||||
meta = {
|
||||
'filename': 'batches.meta',
|
||||
'key': 'label_names',
|
||||
'md5': '5ff9c542aee3614f3951f8cda6e48888',
|
||||
}
|
||||
|
||||
def __init__(self, root=dataset_root+"/cifar_data/", train=True, transform=None, target_transform=None,
|
||||
download=True):
|
||||
|
||||
super(CIFAR10, self).__init__()
|
||||
self.root = root
|
||||
self.transform=transform
|
||||
self.target_transform=target_transform
|
||||
|
||||
self.train = train # training set or test set
|
||||
|
||||
if download:
|
||||
self.download()
|
||||
|
||||
if not self._check_integrity():
|
||||
raise RuntimeError('Dataset not found or corrupted.' +
|
||||
' You can use download=True to download it')
|
||||
|
||||
if self.train:
|
||||
downloaded_list = self.train_list
|
||||
else:
|
||||
downloaded_list = self.test_list
|
||||
|
||||
self.data = []
|
||||
self.targets = []
|
||||
|
||||
# now load the picked numpy arrays
|
||||
for file_name, checksum in downloaded_list:
|
||||
file_path = os.path.join(self.root, self.base_folder, file_name)
|
||||
with open(file_path, 'rb') as f:
|
||||
if sys.version_info[0] == 2:
|
||||
entry = pickle.load(f)
|
||||
else:
|
||||
entry = pickle.load(f, encoding='latin1')
|
||||
self.data.append(entry['data'])
|
||||
if 'labels' in entry:
|
||||
self.targets.extend(entry['labels'])
|
||||
else:
|
||||
self.targets.extend(entry['fine_labels'])
|
||||
|
||||
self.data = np.vstack(self.data).reshape(-1, 3, 32, 32)
|
||||
self.data = self.data.transpose((0, 2, 3, 1)) # convert to HWC
|
||||
|
||||
self._load_meta()
|
||||
|
||||
def _load_meta(self):
|
||||
path = os.path.join(self.root, self.base_folder, self.meta['filename'])
|
||||
if not check_integrity(path, self.meta['md5']):
|
||||
raise RuntimeError('Dataset metadata file not found or corrupted.' +
|
||||
' You can use download=True to download it')
|
||||
with open(path, 'rb') as infile:
|
||||
if sys.version_info[0] == 2:
|
||||
data = pickle.load(infile)
|
||||
else:
|
||||
data = pickle.load(infile, encoding='latin1')
|
||||
self.classes = data[self.meta['key']]
|
||||
self.class_to_idx = {_class: i for i, _class in enumerate(self.classes)}
|
||||
|
||||
def __getitem__(self, index):
|
||||
"""
|
||||
Args:
|
||||
index (int): Index
|
||||
|
||||
Returns:
|
||||
tuple: (image, target) where target is index of the target class.
|
||||
"""
|
||||
img, target = self.data[index], self.targets[index]
|
||||
|
||||
# doing this so that it is consistent with all other datasets
|
||||
# to return a PIL Image
|
||||
img = Image.fromarray(img)
|
||||
|
||||
if self.transform is not None:
|
||||
img = self.transform(img)
|
||||
|
||||
if self.target_transform is not None:
|
||||
target = self.target_transform(target)
|
||||
|
||||
return img, target
|
||||
|
||||
def __len__(self):
|
||||
return len(self.data)
|
||||
|
||||
def _check_integrity(self):
|
||||
root = self.root
|
||||
for fentry in (self.train_list + self.test_list):
|
||||
filename, md5 = fentry[0], fentry[1]
|
||||
fpath = os.path.join(root, self.base_folder, filename)
|
||||
if not check_integrity(fpath, md5):
|
||||
return False
|
||||
return True
|
||||
|
||||
def download(self):
|
||||
if self._check_integrity():
|
||||
print('Files already downloaded and verified')
|
||||
return
|
||||
download_and_extract_archive(self.url, self.root, filename=self.filename, md5=self.tgz_md5)
|
||||
|
||||
def extra_repr(self):
|
||||
return "Split: {}".format("Train" if self.train is True else "Test")
|
||||
|
||||
|
||||
class CIFAR100(CIFAR10):
|
||||
"""`CIFAR100 <https://www.cs.toronto.edu/~kriz/cifar.html>`_ Dataset.
|
||||
|
||||
This is a subclass of the `CIFAR10` Dataset.
|
||||
|
||||
|
||||
Example::
|
||||
|
||||
|
||||
from jittor.dataset.cifar import CIFAR100
|
||||
a = CIFAR100()
|
||||
a.set_attrs(batch_size=16)
|
||||
for imgs, labels in a:
|
||||
print(imgs.shape, labels.shape)
|
||||
break
|
||||
"""
|
||||
base_folder = 'cifar-100-python'
|
||||
url = "https://www.cs.toronto.edu/~kriz/cifar-100-python.tar.gz"
|
||||
filename = "cifar-100-python.tar.gz"
|
||||
tgz_md5 = 'eb9058c3a382ffc7106e4002c42a8d85'
|
||||
train_list = [
|
||||
['train', '16019d7e3df5f24257cddd939b257f8d'],
|
||||
]
|
||||
|
||||
test_list = [
|
||||
['test', 'f0ef6b0ae62326f3e7ffdfab6717acfc'],
|
||||
]
|
||||
meta = {
|
||||
'filename': 'meta',
|
||||
'key': 'fine_label_names',
|
||||
'md5': '7973b15100ade9c7d40fb424638fde48',
|
||||
}
|
|
@ -29,18 +29,7 @@ kernel(in0->num/in0->shape[in0->shape.size()-1], 0, in0_p, out0_p, in0->shape[in
|
|||
|
||||
class OneHotCategorical:
|
||||
def __init__(self, probs=None, logits=None):
|
||||
assert not (probs is None and logits is None)
|
||||
if probs is None:
|
||||
# cannot align to pytorch
|
||||
probs = jt.sigmoid(logits)
|
||||
elif logits is None:
|
||||
logits = jt.log(probs)
|
||||
with jt.no_grad():
|
||||
self.probs = probs / probs.sum(-1, True)
|
||||
self.cum_probs = simple_presum(self.probs)
|
||||
self.cum_probs_l = self.cum_probs[..., :-1]
|
||||
self.cum_probs_r = self.cum_probs[..., 1:]
|
||||
self.logits = logits
|
||||
Categorical.__init__(self, probs, logits)
|
||||
|
||||
def sample(self, sample_shape=[]):
|
||||
shape = sample_shape + self.probs.shape[:-1] + (1,)
|
||||
|
@ -48,17 +37,12 @@ class OneHotCategorical:
|
|||
one_hot = jt.logical_and(self.cum_probs_l < rand, rand <= self.cum_probs_r).float()
|
||||
return one_hot
|
||||
|
||||
def log_prob(self,x):
|
||||
if len(x.shape) == 1:
|
||||
x = x.unsqueeze(0)
|
||||
logits = self.logits.broadcast(x.shape)
|
||||
indices = jt.argmax(x, dim=-1)[0]
|
||||
return logits.gather(1, indices.unsqueeze(-1)).reshape(-1)
|
||||
def log_prob(self, x):
|
||||
x = jt.argmax(x, dim=-1)[0]
|
||||
return Categorical.log_prob(self, x)
|
||||
|
||||
def entropy(self):
|
||||
min_real = -(math.pow(2,23)-1) / math.pow(2,22) * math.pow(2,127)
|
||||
logits = jt.clamp(self.logits,min_v=min_real)
|
||||
p_log_p = logits * self.probs
|
||||
p_log_p = self.logits * self.probs
|
||||
return -p_log_p.sum(-1)
|
||||
|
||||
|
||||
|
@ -68,29 +52,32 @@ class Categorical:
|
|||
if probs is None:
|
||||
# cannot align to pytorch
|
||||
probs = jt.sigmoid(logits)
|
||||
elif logits is None:
|
||||
logits = jt.log(probs)
|
||||
probs = probs / probs.sum(-1, True)
|
||||
if logits is None:
|
||||
logits = jt.safe_log(probs)
|
||||
with jt.no_grad():
|
||||
self.probs = probs / probs.sum(-1, True)
|
||||
self.probs = probs
|
||||
self.logits = logits
|
||||
self.cum_probs = simple_presum(probs)
|
||||
self.cum_probs = simple_presum(self.probs)
|
||||
self.cum_probs_l = self.cum_probs[..., :-1]
|
||||
self.cum_probs_r = self.cum_probs[..., 1:]
|
||||
|
||||
def sample(self, sample_shape=[]):
|
||||
def sample(self, sample_shape=()):
|
||||
shape = sample_shape + self.probs.shape[:-1] + (1,)
|
||||
rand = jt.rand(shape)
|
||||
one_hot = jt.logical_and(self.cum_probs_l < rand, rand <= self.cum_probs_r)
|
||||
index = one_hot.index(one_hot.ndim-1)
|
||||
index = one_hot.index(one_hot.ndim - 1)
|
||||
return (one_hot * index).sum(-1)
|
||||
|
||||
|
||||
def log_prob(self, x):
|
||||
return jt.log(self.probs)[0,x]
|
||||
|
||||
a = self.probs.ndim
|
||||
b = x.ndim
|
||||
indexes = tuple( f'i{i}' for i in range(b-a+1, b) )
|
||||
indexes = indexes + (x,)
|
||||
return jt.safe_log(self.probs).getitem(indexes)
|
||||
|
||||
def entropy(self):
|
||||
min_real = -(math.pow(2,23)-1) / math.pow(2,22) * math.pow(2,127)
|
||||
logits = jt.clamp(self.logits,min_v=min_real)
|
||||
p_log_p = logits * self.probs
|
||||
p_log_p = self.logits * self.probs
|
||||
return -p_log_p.sum(-1)
|
||||
|
||||
|
||||
|
@ -104,11 +91,11 @@ class Normal:
|
|||
|
||||
def log_prob(self, x):
|
||||
var = self.sigma**2
|
||||
log_scale = jt.log(self.sigma)
|
||||
log_scale = jt.safe_log(self.sigma)
|
||||
return -((x-self.mu)**2) / (2*var) - log_scale-np.log(np.sqrt(2*np.pi))
|
||||
|
||||
def entropy(self):
|
||||
return 0.5+0.5*np.log(2*np.pi)+jt.log(self.sigma)
|
||||
return 0.5+0.5*np.log(2*np.pi)+jt.safe_log(self.sigma)
|
||||
|
||||
|
||||
class Uniform:
|
||||
|
@ -123,10 +110,10 @@ class Uniform:
|
|||
def log_prob(self,x):
|
||||
if x < self.low or x >= self.high:
|
||||
return math.inf
|
||||
return -jt.log(self.high - self.low)
|
||||
return -jt.safe_log(self.high - self.low)
|
||||
|
||||
def entropy(self):
|
||||
return jt.log(self.high - self.low)
|
||||
return jt.safe_log(self.high - self.low)
|
||||
|
||||
|
||||
class Geometric:
|
||||
|
@ -138,15 +125,14 @@ class Geometric:
|
|||
self.logits = logits
|
||||
elif logits is None:
|
||||
self.prob = p
|
||||
self.logits = -jt.log(1. / p - 1)
|
||||
self.logits = -jt.safe_log(1. / p - 1)
|
||||
|
||||
def sample(self, sample_shape):
|
||||
tiny = jt.info(self.probs.dtype).tiny
|
||||
u = jt.clamp(jt.rand(sample_shape),min_v=tiny)
|
||||
return (jt.log(u) / (jt.log(-self.probs+1))).floor()
|
||||
u = jt.rand(sample_shape)
|
||||
return (jt.safe_log(u) / (jt.safe_log(-self.probs+1))).floor()
|
||||
|
||||
def log_prob(self, x):
|
||||
return x*jt.log(-self.prob+1)+jt.log(self.prob)
|
||||
return x*jt.safe_log(-self.prob+1)+jt.safe_log(self.prob)
|
||||
|
||||
def entropy(self):
|
||||
return binary_cross_entropy_with_logits(jt.array(self.logits),jt.array(self.prob)) / self.prob
|
||||
|
@ -157,16 +143,14 @@ def kl_divergence(cur_dist, old_dist):
|
|||
if isinstance(cur_dist, Normal):
|
||||
vr = (cur_dist.sigma / old_dist.sigma)**2
|
||||
t1 = ((cur_dist.mu - old_dist.mu) / old_dist.sigma)**2
|
||||
return 0.5*(vr+t1-1-jt.log(vr))
|
||||
return 0.5*(vr+t1-1-jt.safe_log(vr))
|
||||
if isinstance(cur_dist, Categorical) or isinstance(cur_dist,OneHotCategorical):
|
||||
t = cur_dist.probs * (cur_dist.logits-old_dist.logits)
|
||||
t[jt.array((old_dist.probs == 0))] = math.inf
|
||||
t[jt.array((cur_dist.probs == 0))] = 0
|
||||
return t.sum(-1)
|
||||
if isinstance(cur_dist, Uniform):
|
||||
res = jt.log((old_dist.high - old_dist.low) / (cur_dist.high - cur_dist.low))
|
||||
res = jt.safe_log((old_dist.high - old_dist.low) / (cur_dist.high - cur_dist.low))
|
||||
if old_dist.low > cur_dist.low or old_dist.high < cur_dist.high:
|
||||
res = math.inf
|
||||
return res
|
||||
if isinstance(cur_dist, Geometric):
|
||||
return -cur_dist.entropy() - jt.log(-old_dist.prob+1) / cur_dist.prob - old_dist.logits
|
||||
return -cur_dist.entropy() - jt.safe_log(-old_dist.prob+1) / cur_dist.prob - old_dist.logits
|
||||
|
|
|
@ -0,0 +1,288 @@
|
|||
// ***************************************************************
|
||||
// Copyright (c) 2021 Jittor. All Rights Reserved.
|
||||
// Maintainers:
|
||||
// Dun Liang <randonlang@gmail.com>
|
||||
// Guowei Yang <471184555@qq.com>
|
||||
//
|
||||
// This file is subject to the terms and conditions defined in
|
||||
// file 'LICENSE.txt', which is part of this source code package.
|
||||
// ***************************************************************
|
||||
#include "mem/allocator.h"
|
||||
#include "var.h"
|
||||
#include "cudnn_conv3d_backward_w_op.h"
|
||||
#include "cudnn_warper.h"
|
||||
#include "executor.h"
|
||||
#include "ops/op_register.h"
|
||||
|
||||
using namespace std;
|
||||
|
||||
namespace jittor {
|
||||
|
||||
#pragma GCC diagnostic ignored "-Wunused-variable"
|
||||
|
||||
#ifndef JIT
|
||||
|
||||
CudnnConv3dBackwardWOp::CudnnConv3dBackwardWOp(Var* x, Var* dy, int kd, int kh, int kw, int strided, int strideh, int stridew, int paddingd, int paddingh, int paddingw, int dilationd, int dilationh, int dilationw, int groups, string xformat)
|
||||
: x(x), dy(dy), kd(kd), kh(kh), kw(kw), strided(strided), strideh(strideh), stridew(stridew), paddingd(paddingd), paddingh(paddingh), paddingw(paddingw), dilationd(dilationd), dilationh(dilationh), dilationw(dilationw), groups(groups),
|
||||
xformat(move(xformat)) {
|
||||
flags.set(NodeFlags::_cuda, 1);
|
||||
flags.set(NodeFlags::_cpu, 0);
|
||||
dw = create_output(nullptr, dtype_infer(dy->ns, x->ns));
|
||||
}
|
||||
|
||||
void CudnnConv3dBackwardWOp::infer_shape() {
|
||||
ASSERTop(x->shape.size(),==,5);
|
||||
ASSERTop(dy->shape.size(),==,5);
|
||||
int xn, xc, xd, xh, xw, wd, wh, ww, wci, wco, yn, yc, yd, yh, yw;
|
||||
|
||||
if (xformat == "ncdhw") {
|
||||
x->shape.unpack(xn, xc, xd, xh, xw);
|
||||
dy->shape.unpack(yn, yc, yd, yh, yw);
|
||||
} else {
|
||||
x->shape.unpack(xn, xd, xh, xw, xc);
|
||||
dy->shape.unpack(yn, yd, yh, yw, yc);
|
||||
}
|
||||
wco = yc, wci = xc / groups;
|
||||
wh = kh;
|
||||
ww = kw;
|
||||
wd = kd;
|
||||
dw->set_shape(NanoVector(wco, wci, wd, wh, ww));
|
||||
}
|
||||
|
||||
void CudnnConv3dBackwardWOp::jit_prepare(JK& jk) {
|
||||
jk << _CS("[Tx:") << x->dtype();
|
||||
jk << _CS("][Ty:") << dy->dtype();
|
||||
jk << _CS("][Tw:") << dw->dtype();
|
||||
jk << ']';
|
||||
}
|
||||
|
||||
static auto make_conv3d = get_op_info("cudnn_conv3d")
|
||||
.get_constructor<VarPtr, Var*, Var*, int, int, int, int, int, int, int, int, int, int, string>();
|
||||
static auto make_backwardx = get_op_info("cudnn_conv3d_backward_x")
|
||||
.get_constructor<VarPtr, Var*, Var*, int, int, int, int, int, int, int, int, int, int, int, int, int, string>();
|
||||
|
||||
|
||||
VarPtr CudnnConv3dBackwardWOp::grad(Var* out, Var* dout, Var* v, int v_index) {
|
||||
int xn, xc, xd, xh, xw, wd, wh, ww, wci, wco, yn, yc, yd, yh, yw;
|
||||
|
||||
if (xformat == "ncdhw") {
|
||||
x->shape.unpack(xn, xc, xd, xh, xw);
|
||||
dy->shape.unpack(yn, yc, yd, yh, yw);
|
||||
} else {
|
||||
x->shape.unpack(xn, xd, xh, xw, xc);
|
||||
dy->shape.unpack(yn, yd, yh, yw, yc);
|
||||
}
|
||||
|
||||
if (v_index == 0) {
|
||||
return make_backwardx(dout, dy, xd, xh, xw, strided, strideh, stridew, paddingd, paddingh, paddingw, dilationd, dilationh, dilationw, groups, xformat);
|
||||
} else {
|
||||
return make_conv3d(x, dout, strided, strideh, stridew, paddingd, paddingh, paddingw, dilationd, dilationh, dilationw, groups, xformat);
|
||||
}
|
||||
}
|
||||
|
||||
// unordered_map<string, cudnnConvolutionBwdFilterAlgo_t> bwdw_algo_cache;
|
||||
|
||||
#else // JIT
|
||||
#ifdef JIT_cuda
|
||||
|
||||
#pragma clang diagnostic ignored "-Wtautological-compare"
|
||||
|
||||
extern unordered_map<string, cudnnConvolutionBwdFilterAlgo_t> bwdw_algo_cache;
|
||||
|
||||
template <typename T_ELEM> __inline__ cudnnDataType_t getDataType();
|
||||
template <> __inline__ cudnnDataType_t getDataType<half1>() { return CUDNN_DATA_HALF; }
|
||||
template <> __inline__ cudnnDataType_t getDataType<float>() { return CUDNN_DATA_FLOAT; }
|
||||
|
||||
void CudnnConv3dBackwardWOp::jit_run() {
|
||||
auto w = dw;
|
||||
auto y = dy;
|
||||
cudnnHandle_t& handle_ = cudnn_handle;
|
||||
|
||||
cudnnTensorDescriptor_t cudnnIdesc;
|
||||
cudnnFilterDescriptor_t cudnnFdesc;
|
||||
cudnnTensorDescriptor_t cudnnOdesc;
|
||||
cudnnConvolutionDescriptor_t cudnnConvDesc;
|
||||
|
||||
checkCudaErrors(cudnnCreateTensorDescriptor( &cudnnIdesc ));
|
||||
checkCudaErrors(cudnnCreateFilterDescriptor( &cudnnFdesc ));
|
||||
checkCudaErrors(cudnnCreateTensorDescriptor( &cudnnOdesc ));
|
||||
checkCudaErrors(cudnnCreateConvolutionDescriptor( &cudnnConvDesc ));
|
||||
checkCudaErrors(cudnnSetConvolutionGroupCount( cudnnConvDesc, groups ));
|
||||
|
||||
int xn, xc, xd, xh, xw, wd, wh, ww, wci, wco, yn, yc, yd, yh, yw;
|
||||
int sx[] = {0,0,0,0,1};
|
||||
for (int i=3; i>=0; i--) sx[i] = sx[i+1] * x->shape[i+1];
|
||||
int strideX[5];
|
||||
if (xformat == "ncdhw") {
|
||||
x->shape.unpack(xn, xc, xd, xh, xw);
|
||||
int tmp[5] = {sx[0],sx[1],sx[2],sx[3],sx[4]};
|
||||
memcpy(strideX, tmp, sizeof(tmp));
|
||||
} else {
|
||||
x->shape.unpack(xn, xd, xh, xw, xc);
|
||||
int tmp[5] = {sx[0],sx[2],sx[3],sx[4],sx[1]};
|
||||
memcpy(strideX, tmp, sizeof(tmp));
|
||||
}
|
||||
int dimX[] = {xn, xc, xd, xh, xw};
|
||||
// dimX: ncdhw
|
||||
checkCudaErrors(cudnnSetTensorNdDescriptor(
|
||||
cudnnIdesc, getDataType<Tx>(),
|
||||
5, dimX, strideX
|
||||
));
|
||||
|
||||
auto ws = w->shape;
|
||||
int dimW[] = {(int)ws[0],(int)ws[1],(int)ws[2],(int)ws[3],(int)ws[4]};
|
||||
// cudnn only support this two format
|
||||
// https://docs.nvidia.com/deeplearning/sdk/cudnn-api/index.html#cudnnSetFilterNdDescriptor
|
||||
#define filterFormat_oihw CUDNN_TENSOR_NCHW
|
||||
#define filterFormat_ohwi CUDNN_TENSOR_NHWC
|
||||
|
||||
// dimW: KCRS(oihw)
|
||||
checkCudaErrors(cudnnSetFilterNdDescriptor(
|
||||
cudnnFdesc, getDataType<Tw>(),
|
||||
// filterFormat_@WFORMAT, 5, dimW
|
||||
filterFormat_oihw, 5, dimW
|
||||
));
|
||||
|
||||
int padA[] = {paddingd, paddingh, paddingw};
|
||||
int convstrideA[] = {strided, strideh, stridew};
|
||||
int dilationA[] = {dilationd, dilationh, dilationw};
|
||||
// difference between
|
||||
// CUDNN_CONVOLUTION and CUDNN_CROSS_CORRELATION
|
||||
// is the kernel rc order
|
||||
// currently, No perf difference is observed between
|
||||
// this two mode
|
||||
checkCudaErrors(cudnnSetConvolutionNdDescriptor(
|
||||
cudnnConvDesc, 3,
|
||||
padA, convstrideA, dilationA,
|
||||
CUDNN_CROSS_CORRELATION, getDataType<Ty>()
|
||||
));
|
||||
|
||||
// using tensor core
|
||||
// checkCudaErrors( cudnnSetConvolutionMathType(cudnnConvDesc, CUDNN_TENSOR_OP_MATH) );
|
||||
|
||||
|
||||
int sy[] = {0,0,0,0,1};
|
||||
for (int i=3; i>=0; i--) sy[i] = sy[i+1] * y->shape[i+1];
|
||||
int strideY[5];
|
||||
if (xformat == "ncdhw") {
|
||||
y->shape.unpack(yn, yc, yd, yh, yw);
|
||||
int tmp[5] = {sy[0],sy[1],sy[2],sy[3],sy[4]};
|
||||
memcpy(strideY, tmp, sizeof(tmp));
|
||||
} else {
|
||||
y->shape.unpack(yn, yd, yh, yw, yc);
|
||||
int tmp[5] = {sy[0],sy[2],sy[3],sy[4],sy[1]};
|
||||
memcpy(strideY, tmp, sizeof(tmp));
|
||||
}
|
||||
int dimY[] = {yn, yc, yd, yh, yw};
|
||||
|
||||
checkCudaErrors(cudnnSetTensorNdDescriptor(
|
||||
cudnnOdesc, getDataType<Ty>(),
|
||||
5, dimY, strideY
|
||||
));
|
||||
|
||||
cudnnConvolutionBwdFilterAlgo_t algos[] = {
|
||||
CUDNN_CONVOLUTION_BWD_FILTER_ALGO_0,
|
||||
CUDNN_CONVOLUTION_BWD_FILTER_ALGO_1,
|
||||
CUDNN_CONVOLUTION_BWD_FILTER_ALGO_FFT,
|
||||
CUDNN_CONVOLUTION_BWD_FILTER_ALGO_3,
|
||||
CUDNN_CONVOLUTION_BWD_FILTER_ALGO_WINOGRAD_NONFUSED,
|
||||
CUDNN_CONVOLUTION_BWD_FILTER_ALGO_FFT_TILING,
|
||||
};
|
||||
int num_algos = CUDNN_CONVOLUTION_BWD_FILTER_ALGO_COUNT;
|
||||
int perf_count;
|
||||
cudnnConvolutionBwdFilterAlgoPerf_t perf_results[num_algos];
|
||||
cudnnConvolutionBwdFilterAlgo_t algo;
|
||||
bool benchmark=true;
|
||||
|
||||
jk.clear();
|
||||
jk << dimX[0] << "," << dimX[1] << "," << dimX[2] << "," << dimX[3] << "," << dimX[4] << ",";
|
||||
jk << dimW[0] << "," << dimW[1] << "," << dimW[2] << "," << dimW[3] << "," << dimW[4] << ",";
|
||||
jk << paddingd << paddingh << paddingw << "," << strided << strideh <<stridew << "," << dilationd << dilationh << dilationw << "," << groups << ".";
|
||||
auto iter = bwdw_algo_cache.find(jk.to_string());
|
||||
|
||||
if (iter!=bwdw_algo_cache.end()) algo = iter->second;
|
||||
else {
|
||||
if (bwdw_algo_cache.size()>=max_cache_size) benchmark = false;
|
||||
if (benchmark) {
|
||||
size_t max_ws_size = 0;
|
||||
for (int i = 0; i < num_algos; i++) {
|
||||
size_t sz;
|
||||
cudnnStatus_t ret = cudnnGetConvolutionBackwardFilterWorkspaceSize(handle_, cudnnIdesc, cudnnOdesc, cudnnConvDesc, cudnnFdesc, algos[i], &sz);
|
||||
// continue if use too much workspace
|
||||
if (sz > mem_info.total_cuda_ram * max_workspace_ratio) continue;
|
||||
if (CUDNN_STATUS_SUCCESS == ret && sz > max_ws_size) max_ws_size = sz;
|
||||
}
|
||||
size_t allocation;
|
||||
void* ws = exe.temp_allocator->alloc(max_ws_size, allocation);
|
||||
checkCudaErrors(cudnnFindConvolutionBackwardFilterAlgorithmEx(
|
||||
handle_,
|
||||
cudnnIdesc, x->ptr<Tx>(),
|
||||
cudnnOdesc, y->ptr<Ty>(),
|
||||
cudnnConvDesc,
|
||||
cudnnFdesc, w->ptr<Tw>(),
|
||||
num_algos,
|
||||
&perf_count,
|
||||
perf_results,
|
||||
ws,
|
||||
max_ws_size));
|
||||
exe.temp_allocator->free(ws, max_ws_size, allocation);
|
||||
} else {
|
||||
checkCudaErrors(cudnnGetConvolutionBackwardFilterAlgorithm_v7(
|
||||
handle_,
|
||||
cudnnIdesc,
|
||||
cudnnOdesc,
|
||||
cudnnConvDesc,
|
||||
cudnnFdesc,
|
||||
num_algos,
|
||||
&perf_count,
|
||||
perf_results));
|
||||
}
|
||||
int best_algo_idx=-1;
|
||||
for (int i = 0; i < perf_count; i++)
|
||||
if (perf_results[i].status == CUDNN_STATUS_SUCCESS){
|
||||
best_algo_idx=i;
|
||||
break;
|
||||
}
|
||||
ASSERT(best_algo_idx!=-1);
|
||||
algo=perf_results[best_algo_idx].algo;
|
||||
if (benchmark) {
|
||||
bwdw_algo_cache[jk.to_string()] = algo;
|
||||
if (bwdw_algo_cache.size()==max_cache_size)
|
||||
LOGw << "backward w algorithm cache is full";
|
||||
}
|
||||
}
|
||||
|
||||
// TODO: warp work space
|
||||
void *workSpace = 0;
|
||||
size_t workSpaceSize;
|
||||
checkCudaErrors (cudnnGetConvolutionBackwardFilterWorkspaceSize(
|
||||
handle_, cudnnIdesc, cudnnOdesc, cudnnConvDesc,
|
||||
cudnnFdesc, algo, &workSpaceSize));
|
||||
size_t allocation;
|
||||
if (workSpaceSize > 0) {
|
||||
workSpace = exe.temp_allocator->alloc(workSpaceSize, allocation);
|
||||
}
|
||||
float alpha=1, beta=0;
|
||||
checkCudaErrors(cudnnConvolutionBackwardFilter(
|
||||
handle_,
|
||||
(void*)(&alpha),
|
||||
cudnnIdesc, x->ptr<Tx>(),
|
||||
cudnnOdesc, y->ptr<Ty>(),
|
||||
cudnnConvDesc,
|
||||
algo,
|
||||
workSpace, workSpaceSize,
|
||||
(void*)(&beta),
|
||||
cudnnFdesc, w->ptr<Tw>())
|
||||
);
|
||||
if (workSpace)
|
||||
exe.temp_allocator->free(workSpace, workSpaceSize, allocation);
|
||||
|
||||
checkCudaErrors(cudnnDestroyTensorDescriptor( cudnnIdesc ));
|
||||
checkCudaErrors(cudnnDestroyFilterDescriptor( cudnnFdesc ));
|
||||
checkCudaErrors(cudnnDestroyTensorDescriptor( cudnnOdesc ));
|
||||
checkCudaErrors(cudnnDestroyConvolutionDescriptor( cudnnConvDesc ));
|
||||
}
|
||||
#endif
|
||||
#endif // JIT
|
||||
|
||||
} // jittor
|
|
@ -0,0 +1,28 @@
|
|||
// ***************************************************************
|
||||
// Copyright (c) 2021 Jittor. All Rights Reserved.
|
||||
// Maintainers:
|
||||
// Dun Liang <randonlang@gmail.com>
|
||||
// Guowei Yang <471184555@qq.com>
|
||||
//
|
||||
// This file is subject to the terms and conditions defined in
|
||||
// file 'LICENSE.txt', which is part of this source code package.
|
||||
// ***************************************************************
|
||||
#pragma once
|
||||
#include "op.h"
|
||||
|
||||
namespace jittor {
|
||||
|
||||
struct CudnnConv3dBackwardWOp : Op {
|
||||
Var* x, * dy, * dw;
|
||||
int kd, kh, kw, strided, strideh, stridew, paddingd, paddingh, paddingw, dilationd, dilationh, dilationw, groups;
|
||||
string xformat;
|
||||
|
||||
CudnnConv3dBackwardWOp(Var* x, Var* y, int kd, int kh, int kw, int strided, int strideh, int stridew, int paddingd, int paddingh, int paddingw, int dilationd, int dilationh, int dilationw, int groups=1, string xformat="ncdhw");
|
||||
|
||||
const char* name() const override { return "cudnn_conv3d_backward_w"; }
|
||||
VarPtr grad(Var* out, Var* dout, Var* v, int v_index) override;
|
||||
void infer_shape() override;
|
||||
DECLARE_jit_run;
|
||||
};
|
||||
|
||||
} // jittor
|
|
@ -0,0 +1,279 @@
|
|||
// ***************************************************************
|
||||
// Copyright (c) 2021 Jittor. All Rights Reserved.
|
||||
// Maintainers:
|
||||
// Dun Liang <randonlang@gmail.com>
|
||||
// Guowei Yang <471184555@qq.com>
|
||||
//
|
||||
// This file is subject to the terms and conditions defined in
|
||||
// file 'LICENSE.txt', which is part of this source code package.
|
||||
// ***************************************************************
|
||||
#include "mem/allocator.h"
|
||||
#include "var.h"
|
||||
#include "cudnn_conv3d_backward_x_op.h"
|
||||
#include "cudnn_warper.h"
|
||||
#include "executor.h"
|
||||
#include "ops/op_register.h"
|
||||
|
||||
using namespace std;
|
||||
|
||||
namespace jittor {
|
||||
|
||||
#pragma GCC diagnostic ignored "-Wunused-variable"
|
||||
|
||||
#ifndef JIT
|
||||
|
||||
CudnnConv3dBackwardXOp::CudnnConv3dBackwardXOp(Var* w, Var* dy, int depth, int height, int width, int strided, int strideh, int stridew, int paddingd, int paddingh, int paddingw, int dilationd, int dilationh, int dilationw, int groups, string xformat)
|
||||
: w(w), dy(dy), xd(depth), xh(height), xw(width), strided(strided), strideh(strideh), stridew(stridew), paddingd(paddingd), paddingh(paddingh), paddingw(paddingw), dilationd(dilationd), dilationh(dilationh), dilationw(dilationw), groups(groups),
|
||||
xformat(move(xformat)) {
|
||||
flags.set(NodeFlags::_cuda, 1);
|
||||
flags.set(NodeFlags::_cpu, 0);
|
||||
dx = create_output(nullptr, dtype_infer(dy->ns, w->ns));
|
||||
}
|
||||
|
||||
void CudnnConv3dBackwardXOp::infer_shape() {
|
||||
ASSERTop(w->shape.size(),==,5);
|
||||
ASSERTop(dy->shape.size(),==,5);
|
||||
int xn, xc, wd, wh, ww, wci, wco, yn, yc, yd, yh, yw;
|
||||
w->shape.unpack(wco, wci, wd, wh, ww);
|
||||
if (xformat == "ncdhw")
|
||||
dy->shape.unpack(yn, yc, yd, yh, yw);
|
||||
else
|
||||
dy->shape.unpack(yn, yd, yh, yw, yc);
|
||||
xn = yn, xc = wci * groups;
|
||||
if (xformat == "ncdhw")
|
||||
dx->set_shape(NanoVector(xn, xc, xd, xh, xw));
|
||||
else
|
||||
dx->set_shape(NanoVector(xn, xd, xh, xw, xc));
|
||||
}
|
||||
|
||||
void CudnnConv3dBackwardXOp::jit_prepare(JK& jk) {
|
||||
jk << _CS("[Tx:") << dx->dtype();
|
||||
jk << _CS("][Ty:") << dy->dtype();
|
||||
jk << _CS("][Tw:") << w->dtype();
|
||||
jk << ']';
|
||||
}
|
||||
|
||||
|
||||
static auto make_conv3d = get_op_info("cudnn_conv3d")
|
||||
.get_constructor<VarPtr, Var*, Var*, int, int, int, int, int, int, int, int, int, int, string>();
|
||||
static auto make_backwardw = get_op_info("cudnn_conv3d_backward_w")
|
||||
.get_constructor<VarPtr, Var*, Var*, int, int, int, int, int, int, int, int, int, int, int, int, int, string>();
|
||||
|
||||
|
||||
VarPtr CudnnConv3dBackwardXOp::grad(Var* out, Var* dout, Var* v, int v_index) {
|
||||
int xn, xc, wd, wh, ww, wci, wco, yn, yc, yd, yh, yw;
|
||||
w->shape.unpack(wco, wci, wd, wh, ww);
|
||||
|
||||
if (v_index == 0) {
|
||||
return make_backwardw(dout, dy, wd, wh, ww, strided, strideh, stridew, paddingd, paddingh, paddingw, dilationd, dilationh, dilationw, groups, xformat);
|
||||
} else {
|
||||
return make_conv3d(dout, w, strided, strideh, stridew, paddingd, paddingh, paddingw, dilationd, dilationh, dilationw, groups, xformat);
|
||||
}
|
||||
}
|
||||
// unordered_map<string, cudnnConvolutionBwdDataAlgo_t> bwdx_algo_cache;
|
||||
|
||||
#else // JIT
|
||||
#ifdef JIT_cuda
|
||||
|
||||
#pragma clang diagnostic ignored "-Wtautological-compare"
|
||||
|
||||
extern unordered_map<string, cudnnConvolutionBwdDataAlgo_t> bwdx_algo_cache;
|
||||
|
||||
template <typename T_ELEM> __inline__ cudnnDataType_t getDataType();
|
||||
template <> __inline__ cudnnDataType_t getDataType<half1>() { return CUDNN_DATA_HALF; }
|
||||
template <> __inline__ cudnnDataType_t getDataType<float>() { return CUDNN_DATA_FLOAT; }
|
||||
|
||||
void CudnnConv3dBackwardXOp::jit_run() {
|
||||
auto x = dx;
|
||||
auto y = dy;
|
||||
cudnnHandle_t& handle_ = cudnn_handle;
|
||||
|
||||
cudnnTensorDescriptor_t cudnnIdesc;
|
||||
cudnnFilterDescriptor_t cudnnFdesc;
|
||||
cudnnTensorDescriptor_t cudnnOdesc;
|
||||
cudnnConvolutionDescriptor_t cudnnConvDesc;
|
||||
|
||||
checkCudaErrors(cudnnCreateTensorDescriptor( &cudnnIdesc ));
|
||||
checkCudaErrors(cudnnCreateFilterDescriptor( &cudnnFdesc ));
|
||||
checkCudaErrors(cudnnCreateTensorDescriptor( &cudnnOdesc ));
|
||||
checkCudaErrors(cudnnCreateConvolutionDescriptor( &cudnnConvDesc ));
|
||||
checkCudaErrors(cudnnSetConvolutionGroupCount( cudnnConvDesc, groups ));
|
||||
|
||||
|
||||
int xn, xc, xd, xh, xw, wd, wh, ww, wci, wco, yn, yc, yd, yh, yw;
|
||||
int sx[] = {0,0,0,0,1};
|
||||
for (int i=3; i>=0; i--) sx[i] = sx[i+1] * x->shape[i+1];
|
||||
int strideX[5];
|
||||
if (xformat == "ncdhw") {
|
||||
x->shape.unpack(xn, xc, xd, xh, xw);
|
||||
int tmp[5] = {sx[0],sx[1],sx[2],sx[3],sx[4]};
|
||||
memcpy(strideX, tmp, sizeof(tmp));
|
||||
} else {
|
||||
x->shape.unpack(xn, xd, xh, xw, xc);
|
||||
int tmp[5] = {sx[0],sx[2],sx[3],sx[4],sx[1]};
|
||||
memcpy(strideX, tmp, sizeof(tmp));
|
||||
}
|
||||
int dimX[] = {xn, xc, xd, xh, xw};
|
||||
// dimX: ncdhw
|
||||
checkCudaErrors(cudnnSetTensorNdDescriptor(
|
||||
cudnnIdesc, getDataType<Tx>(),
|
||||
5, dimX, strideX
|
||||
));
|
||||
|
||||
auto ws = w->shape;
|
||||
int dimW[] = {(int)ws[0],(int)ws[1],(int)ws[2],(int)ws[3],(int)ws[4]};
|
||||
// cudnn only support this two format
|
||||
// https://docs.nvidia.com/deeplearning/sdk/cudnn-api/index.html#cudnnSetFilterNdDescriptor
|
||||
#define filterFormat_oihw CUDNN_TENSOR_NCHW
|
||||
#define filterFormat_ohwi CUDNN_TENSOR_NHWC
|
||||
|
||||
// dimW: KCRS(oihw)
|
||||
checkCudaErrors(cudnnSetFilterNdDescriptor(
|
||||
cudnnFdesc, getDataType<Tw>(),
|
||||
// filterFormat_@WFORMAT, 5, dimW
|
||||
filterFormat_oihw, 5, dimW
|
||||
));
|
||||
|
||||
int padA[] = {paddingd, paddingh, paddingw};
|
||||
int convstrideA[] = {strided, strideh, stridew};
|
||||
int dilationA[] = {dilationd, dilationh, dilationw};
|
||||
// difference between
|
||||
// CUDNN_CONVOLUTION and CUDNN_CROSS_CORRELATION
|
||||
// is the kernel rc order
|
||||
// currently, No perf difference is observed between
|
||||
// this two mode
|
||||
checkCudaErrors(cudnnSetConvolutionNdDescriptor(
|
||||
cudnnConvDesc, 3,
|
||||
padA, convstrideA, dilationA,
|
||||
CUDNN_CROSS_CORRELATION, getDataType<Ty>()
|
||||
));
|
||||
|
||||
// using tensor core
|
||||
// checkCudaErrors( cudnnSetConvolutionMathType(cudnnConvDesc, CUDNN_TENSOR_OP_MATH) );
|
||||
|
||||
|
||||
int sy[] = {0,0,0,0,1};
|
||||
for (int i=3; i>=0; i--) sy[i] = sy[i+1] * y->shape[i+1];
|
||||
int strideY[5];
|
||||
if (xformat == "ncdhw") {
|
||||
y->shape.unpack(yn, yc, yd, yh, yw);
|
||||
int tmp[5] = {sy[0],sy[1],sy[2],sy[3],sy[4]};
|
||||
memcpy(strideY, tmp, sizeof(tmp));
|
||||
} else {
|
||||
y->shape.unpack(yn, yd, yh, yw, yc);
|
||||
int tmp[5] = {sy[0],sy[2],sy[3],sy[4],sy[1]};
|
||||
memcpy(strideY, tmp, sizeof(tmp));
|
||||
}
|
||||
int dimY[] = {yn, yc, yd, yh, yw};
|
||||
|
||||
checkCudaErrors(cudnnSetTensorNdDescriptor(
|
||||
cudnnOdesc, getDataType<Ty>(),
|
||||
5, dimY, strideY
|
||||
));
|
||||
|
||||
cudnnConvolutionBwdDataAlgo_t algos[] = {
|
||||
CUDNN_CONVOLUTION_BWD_DATA_ALGO_0,
|
||||
CUDNN_CONVOLUTION_BWD_DATA_ALGO_1,
|
||||
CUDNN_CONVOLUTION_BWD_DATA_ALGO_FFT,
|
||||
CUDNN_CONVOLUTION_BWD_DATA_ALGO_FFT_TILING,
|
||||
CUDNN_CONVOLUTION_BWD_DATA_ALGO_WINOGRAD,
|
||||
CUDNN_CONVOLUTION_BWD_DATA_ALGO_WINOGRAD_NONFUSED
|
||||
};
|
||||
int num_algos = CUDNN_CONVOLUTION_BWD_DATA_ALGO_COUNT;
|
||||
int perf_count;
|
||||
cudnnConvolutionBwdDataAlgoPerf_t perf_results[num_algos];
|
||||
cudnnConvolutionBwdDataAlgo_t algo;
|
||||
bool benchmark=true;
|
||||
|
||||
jk.clear();
|
||||
jk << dimX[0] << "," << dimX[1] << "," << dimX[2] << "," << dimX[3] << "," << dimX[4] << ",";
|
||||
jk << dimW[0] << "," << dimW[1] << "," << dimW[2] << "," << dimW[3] << "," << dimW[4] << ",";
|
||||
jk << paddingd << paddingh << paddingw << "," << strided << strideh <<stridew << "," << dilationd << dilationh << dilationw << "," << groups << ".";
|
||||
auto iter = bwdx_algo_cache.find(jk.to_string());
|
||||
|
||||
if (iter!=bwdx_algo_cache.end()) algo = iter->second;
|
||||
else {
|
||||
if (bwdx_algo_cache.size()>=max_cache_size) benchmark = false;
|
||||
if (benchmark) {
|
||||
size_t max_ws_size = 0;
|
||||
for (int i = 0; i < num_algos; i++) {
|
||||
size_t sz;
|
||||
cudnnStatus_t ret = cudnnGetConvolutionBackwardDataWorkspaceSize(handle_, cudnnFdesc, cudnnOdesc, cudnnConvDesc, cudnnIdesc, algos[i], &sz);
|
||||
// continue if use too much workspace
|
||||
if (sz > mem_info.total_cuda_ram * max_workspace_ratio) continue;
|
||||
if (CUDNN_STATUS_SUCCESS == ret && sz > max_ws_size) max_ws_size = sz;
|
||||
}
|
||||
size_t allocation;
|
||||
void* ws = exe.temp_allocator->alloc(max_ws_size, allocation);
|
||||
checkCudaErrors(cudnnFindConvolutionBackwardDataAlgorithmEx(
|
||||
handle_,
|
||||
cudnnFdesc, w->ptr<Tw>(),
|
||||
cudnnOdesc, y->ptr<Ty>(),
|
||||
cudnnConvDesc,
|
||||
cudnnIdesc, x->ptr<Tx>(),
|
||||
num_algos,
|
||||
&perf_count,
|
||||
perf_results,
|
||||
ws,
|
||||
max_ws_size));
|
||||
exe.temp_allocator->free(ws, max_ws_size, allocation);
|
||||
} else {
|
||||
checkCudaErrors(cudnnGetConvolutionBackwardDataAlgorithm_v7(
|
||||
handle_,
|
||||
cudnnFdesc,
|
||||
cudnnOdesc,
|
||||
cudnnConvDesc,
|
||||
cudnnIdesc,
|
||||
num_algos,
|
||||
&perf_count,
|
||||
perf_results));
|
||||
}
|
||||
int best_algo_idx=-1;
|
||||
for (int i = 0; i < perf_count; i++)
|
||||
if (perf_results[i].status == CUDNN_STATUS_SUCCESS){
|
||||
best_algo_idx=i;
|
||||
break;
|
||||
}
|
||||
ASSERT(best_algo_idx!=-1);
|
||||
algo=perf_results[best_algo_idx].algo;
|
||||
if (benchmark) {
|
||||
bwdx_algo_cache[jk.to_string()] = algo;
|
||||
if (bwdx_algo_cache.size()==max_cache_size)
|
||||
LOGw << "backward x algorithm cache is full";
|
||||
}
|
||||
}
|
||||
|
||||
// TODO: warp work space
|
||||
void *workSpace = 0;
|
||||
size_t workSpaceSize;
|
||||
checkCudaErrors (cudnnGetConvolutionBackwardDataWorkspaceSize(
|
||||
handle_, cudnnFdesc, cudnnOdesc, cudnnConvDesc,
|
||||
cudnnIdesc, algo, &workSpaceSize));
|
||||
size_t allocation;
|
||||
if (workSpaceSize > 0) {
|
||||
workSpace = exe.temp_allocator->alloc(workSpaceSize, allocation);
|
||||
}
|
||||
float alpha=1, beta=0;
|
||||
checkCudaErrors(cudnnConvolutionBackwardData(
|
||||
handle_,
|
||||
(void*)(&alpha),
|
||||
cudnnFdesc, w->ptr<Tw>(),
|
||||
cudnnOdesc, y->ptr<Ty>(),
|
||||
cudnnConvDesc,
|
||||
algo,
|
||||
workSpace, workSpaceSize,
|
||||
(void*)(&beta),
|
||||
cudnnIdesc, x->ptr<Tx>())
|
||||
);
|
||||
if (workSpace)
|
||||
exe.temp_allocator->free(workSpace, workSpaceSize, allocation);
|
||||
|
||||
checkCudaErrors(cudnnDestroyTensorDescriptor( cudnnIdesc ));
|
||||
checkCudaErrors(cudnnDestroyFilterDescriptor( cudnnFdesc ));
|
||||
checkCudaErrors(cudnnDestroyTensorDescriptor( cudnnOdesc ));
|
||||
checkCudaErrors(cudnnDestroyConvolutionDescriptor( cudnnConvDesc ));
|
||||
}
|
||||
#endif
|
||||
#endif // JIT
|
||||
|
||||
} // jittor
|
|
@ -0,0 +1,28 @@
|
|||
// ***************************************************************
|
||||
// Copyright (c) 2021 Jittor. All Rights Reserved.
|
||||
// Maintainers:
|
||||
// Dun Liang <randonlang@gmail.com>
|
||||
// Guowei Yang <471184555@qq.com>
|
||||
//
|
||||
// This file is subject to the terms and conditions defined in
|
||||
// file 'LICENSE.txt', which is part of this source code package.
|
||||
// ***************************************************************
|
||||
#pragma once
|
||||
#include "op.h"
|
||||
|
||||
namespace jittor {
|
||||
|
||||
struct CudnnConv3dBackwardXOp : Op {
|
||||
Var* w, * dy, * dx;
|
||||
int xd, xh, xw, strided, strideh, stridew, paddingd, paddingh, paddingw, dilationd, dilationh, dilationw, groups;
|
||||
string xformat;
|
||||
|
||||
CudnnConv3dBackwardXOp(Var* w, Var* y, int depth, int height, int width, int strided, int strideh, int stridew, int paddingd, int paddingh, int paddingw, int dilationd, int dilationh, int dilationw, int groups=1, string xformat="ncdhw");
|
||||
|
||||
const char* name() const override { return "cudnn_conv3d_backward_x"; }
|
||||
VarPtr grad(Var* out, Var* dout, Var* v, int v_index) override;
|
||||
void infer_shape() override;
|
||||
DECLARE_jit_run;
|
||||
};
|
||||
|
||||
} // jittor
|
|
@ -0,0 +1,284 @@
|
|||
// ***************************************************************
|
||||
// Copyright (c) 2021 Jittor. All Rights Reserved.
|
||||
// Maintainers: Dun Liang <randonlang@gmail.com>.
|
||||
//
|
||||
// This file is subject to the terms and conditions defined in
|
||||
// file 'LICENSE.txt', which is part of this source code package.
|
||||
// ***************************************************************
|
||||
#include "var.h"
|
||||
#include "cudnn_conv3d_op.h"
|
||||
#include "cudnn_warper.h"
|
||||
#include "executor.h"
|
||||
#include "ops/op_register.h"
|
||||
|
||||
using namespace std;
|
||||
|
||||
namespace jittor {
|
||||
|
||||
#pragma GCC diagnostic ignored "-Wunused-variable"
|
||||
|
||||
#ifndef JIT
|
||||
|
||||
CudnnConv3dOp::CudnnConv3dOp(Var* x, Var* w, int strided, int strideh, int stridew, int paddingd, int paddingh, int paddingw, int dilationd, int dilationh, int dilationw, int groups, string xformat)
|
||||
: x(x), w(w), strided(strided), strideh(strideh), stridew(stridew), paddingd(paddingd), paddingh(paddingh), paddingw(paddingw), dilationd(dilationd), dilationh(dilationh), dilationw(dilationw), groups(groups),
|
||||
xformat(move(xformat)) {
|
||||
flags.set(NodeFlags::_cuda, 1);
|
||||
flags.set(NodeFlags::_cpu, 0);
|
||||
y = create_output(nullptr, dtype_infer(x->ns, w->ns));
|
||||
}
|
||||
|
||||
void CudnnConv3dOp::infer_shape() {
|
||||
ASSERTop(x->shape.size(),==,5);
|
||||
ASSERTop(w->shape.size(),==,5);
|
||||
int xn, xc, xd, xh, xw, wd, wh, ww, wci, wco, yn, yc, yd, yh, yw;
|
||||
if (xformat == "ncdhw")
|
||||
x->shape.unpack(xn, xc, xd, xh, xw);
|
||||
else
|
||||
x->shape.unpack(xn, xd, xh, xw, xc);
|
||||
w->shape.unpack(wco, wci, wd, wh, ww);
|
||||
ASSERTop(wci * groups,==,xc);
|
||||
yn = xn, yc = wco;
|
||||
yd = (xd+paddingd*2-wd*dilationd+dilationd-1)/strided+1;
|
||||
yh = (xh+paddingh*2-wh*dilationh+dilationh-1)/strideh+1;
|
||||
yw = (xw+paddingw*2-ww*dilationw+dilationw-1)/stridew+1;
|
||||
if (xformat == "ncdhw")
|
||||
y->set_shape(NanoVector(yn, yc, yd, yh, yw));
|
||||
else
|
||||
y->set_shape(NanoVector(yn, yd, yh, yw, yc));
|
||||
}
|
||||
|
||||
void CudnnConv3dOp::jit_prepare(JK& jk) {
|
||||
jk << _CS("[Tx:") << x->dtype();
|
||||
jk << _CS("][Ty:") << y->dtype();
|
||||
jk << _CS("][Tw:") << w->dtype();
|
||||
jk << ']';
|
||||
}
|
||||
|
||||
static auto make_backwardx = get_op_info("cudnn_conv3d_backward_x")
|
||||
.get_constructor<VarPtr, Var*, Var*, int, int, int, int, int, int, int, int, int, int, int, int, int, string>();
|
||||
static auto make_backwardw = get_op_info("cudnn_conv3d_backward_w")
|
||||
.get_constructor<VarPtr, Var*, Var*, int, int, int, int, int, int, int, int, int, int, int, int, int, string>();
|
||||
|
||||
VarPtr CudnnConv3dOp::grad(Var* out, Var* dout, Var* v, int v_index) {
|
||||
int xn, xc, xd, xh, xw, wd, wh, ww, wci, wco, yn, yc, yd, yh, yw;
|
||||
if (xformat == "ncdhw")
|
||||
x->shape.unpack(xn, xc, xd, xh, xw);
|
||||
else
|
||||
x->shape.unpack(xn, xd, xh, xw, xc);
|
||||
w->shape.unpack(wco, wci, wd, wh, ww);
|
||||
if (v_index == 0) {
|
||||
return make_backwardx(w, dout, xd, xh, xw, strided, strideh, stridew, paddingd, paddingh, paddingw, dilationd, dilationh, dilationw, groups, xformat);
|
||||
} else {
|
||||
return make_backwardw(x, dout, wd, wh, ww, strided, strideh, stridew, paddingd, paddingh, paddingw, dilationd, dilationh, dilationw, groups, xformat);
|
||||
}
|
||||
}
|
||||
|
||||
// unordered_map<string, cudnnConvolutionFwdAlgo_t> fwd_algo_cache;
|
||||
|
||||
#else // JIT
|
||||
#ifdef JIT_cuda
|
||||
|
||||
#pragma clang diagnostic ignored "-Wtautological-compare"
|
||||
|
||||
extern unordered_map<string, cudnnConvolutionFwdAlgo_t> fwd_algo_cache;
|
||||
|
||||
template <typename T_ELEM> __inline__ cudnnDataType_t getDataType();
|
||||
template <> __inline__ cudnnDataType_t getDataType<half1>() { return CUDNN_DATA_HALF; }
|
||||
template <> __inline__ cudnnDataType_t getDataType<float>() { return CUDNN_DATA_FLOAT; }
|
||||
|
||||
void CudnnConv3dOp::jit_run() {
|
||||
cudnnHandle_t& handle_ = cudnn_handle;
|
||||
|
||||
cudnnTensorDescriptor_t cudnnIdesc;
|
||||
cudnnFilterDescriptor_t cudnnFdesc;
|
||||
cudnnTensorDescriptor_t cudnnOdesc;
|
||||
cudnnConvolutionDescriptor_t cudnnConvDesc;
|
||||
|
||||
checkCudaErrors(cudnnCreateTensorDescriptor( &cudnnIdesc ));
|
||||
checkCudaErrors(cudnnCreateFilterDescriptor( &cudnnFdesc ));
|
||||
checkCudaErrors(cudnnCreateTensorDescriptor( &cudnnOdesc ));
|
||||
checkCudaErrors(cudnnCreateConvolutionDescriptor( &cudnnConvDesc ));
|
||||
checkCudaErrors(cudnnSetConvolutionGroupCount( cudnnConvDesc, groups ));
|
||||
|
||||
|
||||
int xn, xc, xd, xh, xw, wd, wh, ww, wci, wco, yn, yc, yd, yh, yw;
|
||||
int sx[] = {0,0,0,0,1};
|
||||
for (int i=3; i>=0; i--) sx[i] = sx[i+1] * x->shape[i+1];
|
||||
int strideX[5];
|
||||
if (xformat == "ncdhw") {
|
||||
x->shape.unpack(xn, xc, xd, xh, xw);
|
||||
int tmp[5] = {sx[0],sx[1],sx[2],sx[3],sx[4]};
|
||||
memcpy(strideX, tmp, sizeof(tmp));
|
||||
} else {
|
||||
x->shape.unpack(xn, xd, xh, xw, xc);
|
||||
int tmp[5] = {sx[0],sx[2],sx[3],sx[4],sx[1]};
|
||||
memcpy(strideX, tmp, sizeof(tmp));
|
||||
}
|
||||
int dimX[] = {xn, xc, xd, xh, xw};
|
||||
// dimX: ncdhw
|
||||
checkCudaErrors(cudnnSetTensorNdDescriptor(
|
||||
cudnnIdesc, getDataType<Tx>(),
|
||||
5, dimX, strideX
|
||||
));
|
||||
|
||||
auto ws = w->shape;
|
||||
int dimW[] = {(int)ws[0],(int)ws[1],(int)ws[2],(int)ws[3],(int)ws[4]};
|
||||
// cudnn only support this two format
|
||||
// https://docs.nvidia.com/deeplearning/sdk/cudnn-api/index.html#cudnnSetFilterNdDescriptor
|
||||
#define filterFormat_oihw CUDNN_TENSOR_NCHW
|
||||
#define filterFormat_ohwi CUDNN_TENSOR_NHWC
|
||||
|
||||
// dimW: KCRS(oihw)
|
||||
checkCudaErrors(cudnnSetFilterNdDescriptor(
|
||||
cudnnFdesc, getDataType<Tw>(),
|
||||
// filterFormat_@WFORMAT, 5, dimW
|
||||
filterFormat_oihw, 5, dimW
|
||||
));
|
||||
|
||||
int padA[] = {paddingd, paddingh, paddingw};
|
||||
int convstrideA[] = {strided, strideh, stridew};
|
||||
int dilationA[] = {dilationd, dilationh, dilationw};
|
||||
// difference between
|
||||
// CUDNN_CONVOLUTION and CUDNN_CROSS_CORRELATION
|
||||
// is the kernel rc order
|
||||
// currently, No perf difference is observed between
|
||||
// this two mode
|
||||
checkCudaErrors(cudnnSetConvolutionNdDescriptor(
|
||||
cudnnConvDesc, 3,
|
||||
padA, convstrideA, dilationA,
|
||||
CUDNN_CROSS_CORRELATION, getDataType<Ty>()
|
||||
));
|
||||
|
||||
// using tensor core
|
||||
// checkCudaErrors( cudnnSetConvolutionMathType(cudnnConvDesc, CUDNN_TENSOR_OP_MATH) );
|
||||
|
||||
|
||||
int sy[] = {0,0,0,0,1};
|
||||
for (int i=3; i>=0; i--) sy[i] = sy[i+1] * y->shape[i+1];
|
||||
int strideY[5];
|
||||
if (xformat == "ncdhw") {
|
||||
y->shape.unpack(yn, yc, yd, yh, yw);
|
||||
int tmp[5] = {sy[0],sy[1],sy[2],sy[3],sy[4]};
|
||||
memcpy(strideY, tmp, sizeof(tmp));
|
||||
} else {
|
||||
y->shape.unpack(yn, yd, yh, yw, yc);
|
||||
int tmp[5] = {sy[0],sy[2],sy[3],sy[4],sy[1]};
|
||||
memcpy(strideY, tmp, sizeof(tmp));
|
||||
}
|
||||
int dimY[] = {yn, yc, yd, yh, yw};
|
||||
|
||||
checkCudaErrors(cudnnSetTensorNdDescriptor(
|
||||
cudnnOdesc, getDataType<Ty>(),
|
||||
5, dimY, strideY
|
||||
));
|
||||
|
||||
cudnnConvolutionFwdAlgo_t algos[] = {
|
||||
CUDNN_CONVOLUTION_FWD_ALGO_GEMM,
|
||||
CUDNN_CONVOLUTION_FWD_ALGO_FFT,
|
||||
CUDNN_CONVOLUTION_FWD_ALGO_FFT_TILING,
|
||||
CUDNN_CONVOLUTION_FWD_ALGO_IMPLICIT_GEMM,
|
||||
CUDNN_CONVOLUTION_FWD_ALGO_IMPLICIT_PRECOMP_GEMM,
|
||||
CUDNN_CONVOLUTION_FWD_ALGO_DIRECT,
|
||||
CUDNN_CONVOLUTION_FWD_ALGO_WINOGRAD,
|
||||
CUDNN_CONVOLUTION_FWD_ALGO_WINOGRAD_NONFUSED,
|
||||
};
|
||||
int num_algos = CUDNN_CONVOLUTION_FWD_ALGO_COUNT;
|
||||
int perf_count;
|
||||
cudnnConvolutionFwdAlgoPerf_t perf_results[num_algos];
|
||||
cudnnConvolutionFwdAlgo_t algo;
|
||||
bool benchmark=true;
|
||||
|
||||
jk.clear();
|
||||
jk << dimX[0] << "," << dimX[1] << "," << dimX[2] << "," << dimX[3] << "," << dimX[4] << ",";
|
||||
jk << dimW[0] << "," << dimW[1] << "," << dimW[2] << "," << dimW[3] << "," << dimW[4] << ",";
|
||||
jk << paddingd << paddingh << paddingw << "," << strided << strideh <<stridew << "," << dilationd << dilationh << dilationw << "," << groups << ".";
|
||||
auto iter = fwd_algo_cache.find(jk.to_string());
|
||||
|
||||
if (iter!=fwd_algo_cache.end()) algo = iter->second;
|
||||
else {
|
||||
if (fwd_algo_cache.size()>=max_cache_size) benchmark = false;
|
||||
if (benchmark) {
|
||||
size_t max_ws_size = 0;
|
||||
for (int i = 0; i < num_algos; i++) {
|
||||
size_t sz;
|
||||
cudnnStatus_t ret = cudnnGetConvolutionForwardWorkspaceSize(
|
||||
handle_, cudnnIdesc, cudnnFdesc, cudnnConvDesc,
|
||||
cudnnOdesc, algos[i], &sz);
|
||||
// continue if use too much workspace
|
||||
if (sz > mem_info.total_cuda_ram * max_workspace_ratio) continue;
|
||||
if (CUDNN_STATUS_SUCCESS == ret && sz > max_ws_size) max_ws_size = sz;
|
||||
}
|
||||
size_t allocation;
|
||||
void* ws = exe.temp_allocator->alloc(max_ws_size, allocation);
|
||||
checkCudaErrors(cudnnFindConvolutionForwardAlgorithmEx(
|
||||
handle_,
|
||||
cudnnIdesc, x->ptr<Tx>(),
|
||||
cudnnFdesc, w->ptr<Tw>(),
|
||||
cudnnConvDesc,
|
||||
cudnnOdesc, y->ptr<Ty>(),
|
||||
num_algos,
|
||||
&perf_count,
|
||||
perf_results,
|
||||
ws,
|
||||
max_ws_size));
|
||||
exe.temp_allocator->free(ws, max_ws_size, allocation);
|
||||
} else {
|
||||
checkCudaErrors(cudnnGetConvolutionForwardAlgorithm_v7(
|
||||
handle_,
|
||||
cudnnIdesc,
|
||||
cudnnFdesc,
|
||||
cudnnConvDesc,
|
||||
cudnnOdesc,
|
||||
num_algos,
|
||||
&perf_count,
|
||||
perf_results));
|
||||
}
|
||||
int best_algo_idx=-1;
|
||||
for (int i = 0; i < perf_count; i++)
|
||||
if (perf_results[i].status == CUDNN_STATUS_SUCCESS){
|
||||
best_algo_idx=i;
|
||||
break;
|
||||
}
|
||||
ASSERT(best_algo_idx!=-1);
|
||||
algo=perf_results[best_algo_idx].algo;
|
||||
if (benchmark) {
|
||||
fwd_algo_cache[jk.to_string()] = algo;
|
||||
if (fwd_algo_cache.size()==max_cache_size)
|
||||
LOGw << "forward_ algorithm cache is full";
|
||||
}
|
||||
}
|
||||
|
||||
// TODO: warp work space
|
||||
void *workSpace = 0;
|
||||
size_t workSpaceSize;
|
||||
checkCudaErrors (cudnnGetConvolutionForwardWorkspaceSize(
|
||||
handle_, cudnnIdesc, cudnnFdesc, cudnnConvDesc,
|
||||
cudnnOdesc, algo, &workSpaceSize) );
|
||||
size_t allocation;
|
||||
if (workSpaceSize > 0) {
|
||||
workSpace = exe.temp_allocator->alloc(workSpaceSize, allocation);
|
||||
}
|
||||
float alpha=1, beta=0;
|
||||
checkCudaErrors(cudnnConvolutionForward(
|
||||
handle_,
|
||||
(void*)(&alpha),
|
||||
cudnnIdesc, x->ptr<Tx>(),
|
||||
cudnnFdesc, w->ptr<Tw>(),
|
||||
cudnnConvDesc,
|
||||
algo,
|
||||
workSpace, workSpaceSize,
|
||||
(void*)(&beta),
|
||||
cudnnOdesc, y->ptr<Ty>())
|
||||
);
|
||||
if (workSpace)
|
||||
exe.temp_allocator->free(workSpace, workSpaceSize, allocation);
|
||||
|
||||
checkCudaErrors(cudnnDestroyTensorDescriptor( cudnnIdesc ));
|
||||
checkCudaErrors(cudnnDestroyFilterDescriptor( cudnnFdesc ));
|
||||
checkCudaErrors(cudnnDestroyTensorDescriptor( cudnnOdesc ));
|
||||
checkCudaErrors(cudnnDestroyConvolutionDescriptor( cudnnConvDesc ));
|
||||
}
|
||||
#endif
|
||||
#endif // JIT
|
||||
|
||||
} // jittor
|
|
@ -0,0 +1,24 @@
|
|||
// ***************************************************************
|
||||
// Copyright (c) 2021 Jittor. All Rights Reserved.
|
||||
// Maintainers: Dun Liang <randonlang@gmail.com>.
|
||||
// This file is subject to the terms and conditions defined in
|
||||
// file 'LICENSE.txt', which is part of this source code package.
|
||||
// ***************************************************************
|
||||
#pragma once
|
||||
#include "op.h"
|
||||
|
||||
namespace jittor {
|
||||
|
||||
struct CudnnConv3dOp : Op {
|
||||
Var* x, * w, * y;
|
||||
int strided, strideh, stridew, paddingd, paddingh, paddingw, dilationd, dilationh, dilationw, groups;
|
||||
string xformat;
|
||||
CudnnConv3dOp(Var* x, Var* w, int strided, int strideh, int stridew, int paddingd, int paddingh, int paddingw, int dilationd=1, int dilationh=1, int dilationw=1, int groups=1, string xformat="ncdhw");
|
||||
|
||||
const char* name() const override { return "cudnn_conv3d"; }
|
||||
VarPtr grad(Var* out, Var* dout, Var* v, int v_index) override;
|
||||
void infer_shape() override;
|
||||
DECLARE_jit_run;
|
||||
};
|
||||
|
||||
} // jittor
|
|
@ -634,23 +634,6 @@ def kthvalue(input, k, dim=None, keepdim=False):
|
|||
|
||||
jt.Var.kthvalue = kthvalue
|
||||
|
||||
|
||||
def gather(x,dim,index):
|
||||
if dim<0:
|
||||
dim+=index.ndim
|
||||
x_shape = list(x.shape )
|
||||
i_shape = list(index.shape)
|
||||
assert i_shape[dim]>0
|
||||
assert x.ndim == index.ndim
|
||||
i_shape[dim]=x_shape[dim]
|
||||
assert i_shape == x_shape
|
||||
ins = []
|
||||
for i in range(index.ndim):
|
||||
ins.append(jt.index(index.shape,dim=i))
|
||||
ins[dim]=index
|
||||
return x.reindex(ins)
|
||||
jt.Var.gather = gather
|
||||
|
||||
def _prod(x,dim=0):
|
||||
x = jt.log(x)
|
||||
x = x.sum(dim=dim)
|
||||
|
@ -1255,3 +1238,7 @@ Examples::
|
|||
return x.reindex(x.shape, ids)
|
||||
|
||||
jt.Var.roll = roll
|
||||
|
||||
def safe_log(x):
|
||||
return jt.safe_clip(x, 1e-30, 1e30).log()
|
||||
jt.Var.safe_log = safe_log
|
||||
|
|
|
@ -21,18 +21,19 @@ from collections import OrderedDict
|
|||
from jittor.pool import *
|
||||
from jittor.optim import *
|
||||
from jittor.misc import _pair, _triple
|
||||
from jittor_utils import LOG
|
||||
|
||||
|
||||
def matmul_transpose(a, b):
|
||||
'''
|
||||
returns a * b^T
|
||||
'''
|
||||
assert len(a.shape) >= 2 and len(b.shape) == 2
|
||||
assert a.shape[-1] == b.shape[-1], (a.shape, b.shape)
|
||||
if len(a.shape)>2:
|
||||
if len(a.shape) != 2:
|
||||
aa = a.reshape((-1, a.shape[-1]))
|
||||
cc = matmul_transpose(aa, b)
|
||||
return cc.reshape(a.shape[:-1]+(-1,))
|
||||
assert len(a.shape) == 2 and len(b.shape) == 2
|
||||
|
||||
shape = list(a.shape)[:-1] + list(b.shape)
|
||||
a = a.broadcast(shape, [len(shape)-2])
|
||||
|
@ -639,7 +640,6 @@ class Conv1d(Module):
|
|||
|
||||
class Conv3d(Module):
|
||||
def __init__(self, in_channels, out_channels, kernel_size, stride=1, padding=0, dilation=1, groups=1, bias=True):
|
||||
LOG.w("Optimizations of Conv3d are working in progress, it maybe slow currently.")
|
||||
self.in_channels = in_channels
|
||||
self.out_channels = out_channels
|
||||
self.kernel_size = kernel_size if isinstance(kernel_size, tuple) else (kernel_size, kernel_size, kernel_size)
|
||||
|
@ -665,65 +665,7 @@ class Conv3d(Module):
|
|||
self.bias = None
|
||||
|
||||
def execute(self, x):
|
||||
if self.groups == 1:
|
||||
N,C,H,W,D = x.shape
|
||||
Kh, Kw, Kd = self.kernel_size
|
||||
assert C==self.in_channels
|
||||
oh = (H+self.padding[0]*2-Kh*self.dilation[0]+self.dilation[0]-1)//self.stride[0]+1
|
||||
ow = (W+self.padding[1]*2-Kw*self.dilation[1]+self.dilation[1]-1)//self.stride[1]+1
|
||||
od = (D+self.padding[2]*2-Kd*self.dilation[2]+self.dilation[2]-1)//self.stride[2]+1
|
||||
xx = x.reindex([N,self.out_channels,C,oh,ow,od,Kh,Kw,Kd], [
|
||||
'i0', # Nid
|
||||
'i2', # Cid
|
||||
f'i3*{self.stride[0]}-{self.padding[0]}+i6*{self.dilation[0]}', # Hid+Khid
|
||||
f'i4*{self.stride[1]}-{self.padding[1]}+i7*{self.dilation[1]}', # Wid+KWid
|
||||
f'i5*{self.stride[2]}-{self.padding[2]}+i8*{self.dilation[2]}', # Did+KDid
|
||||
])
|
||||
ww = self.weight.broadcast(xx.shape, [0,3,4,5])
|
||||
yy = xx*ww
|
||||
y = yy.sum([2,6,7,8]) # Kc, Kh, Kw, Kd
|
||||
if self.bias is not None:
|
||||
b = self.bias.broadcast(y.shape, [0,2,3,4])
|
||||
y = y + b
|
||||
return y
|
||||
else:
|
||||
N,C,H,W,D = x.shape
|
||||
Kh, Kw, Kd = self.kernel_size
|
||||
G = self.groups
|
||||
CpG = C // G # channels per group
|
||||
assert C==self.in_channels
|
||||
oc = self.out_channels
|
||||
oh = (H+self.padding[0]*2-Kh*self.dilation[0]+self.dilation[0]-1)//self.stride[0]+1
|
||||
ow = (W+self.padding[1]*2-Kw*self.dilation[1]+self.dilation[1]-1)//self.stride[1]+1
|
||||
od = (D+self.padding[2]*2-Kd*self.dilation[2]+self.dilation[2]-1)//self.stride[2]+1
|
||||
xx = x.reindex([N,G,oc//G,CpG,oh,ow,od,Kh,Kw,Kd], [
|
||||
'i0', # Nid
|
||||
f'i1*{CpG}+i3', # Gid
|
||||
f'i4*{self.stride[0]}-{self.padding[0]}+i7*{self.dilation[0]}', # Hid+Khid
|
||||
f'i5*{self.stride[1]}-{self.padding[1]}+i8*{self.dilation[1]}', # Wid+KWid
|
||||
f'i6*{self.stride[2]}-{self.padding[2]}+i9*{self.dilation[2]}', # Did+KDid
|
||||
])
|
||||
# w: [oc, CpG, Kh, Kw, Kd]
|
||||
ww = self.weight.reindex([N, G, oc//G, CpG, oh, ow, od, Kh, Kw, Kd], [
|
||||
f'i1*{oc//G}+i2',
|
||||
'i3',
|
||||
'i7',
|
||||
'i8',
|
||||
'i9'
|
||||
])
|
||||
ww.compile_options = xx.compile_options = {"G":G,"C":C}
|
||||
yy = xx*ww
|
||||
y = yy.reindex_reduce('add', [N, oc, oh, ow, od], [
|
||||
'i0',
|
||||
f'i1*{oc//G}+i2',
|
||||
'i4',
|
||||
'i5',
|
||||
'i6'
|
||||
])
|
||||
if self.bias is not None:
|
||||
b = self.bias.broadcast(y.shape, [0,2,3,4])
|
||||
y = y + b
|
||||
return y
|
||||
return conv3d(x, self.weight, self.bias, self.stride, self.padding, self.dilation, self.groups)
|
||||
|
||||
def conv2d(x, weight, bias=None, stride=1, padding=0, dilation=1, groups=1):
|
||||
padding = _pair(padding)
|
||||
|
@ -789,13 +731,16 @@ def conv3d(x, weight, bias=None, stride=1, padding=0, dilation=1, groups=1):
|
|||
dilation = _triple(dilation)
|
||||
out_channels = weight.shape[0]
|
||||
|
||||
if jt.flags.use_cuda and jt.cudnn:
|
||||
return jt.cudnn.ops.cudnn_conv3d(x, weight, *stride, *padding, *dilation, groups)
|
||||
|
||||
if groups == 1:
|
||||
N,C,H,W,D = x.shape
|
||||
Kh, Kw, Kd = weight.shape[-3:]
|
||||
oh = (H+padding[0]*2-Kh*dilation[0]+dilation[0]-1)//stride[0]+1
|
||||
ow = (W+padding[1]*2-Kw*dilation[1]+dilation[1]-1)//stride[1]+1
|
||||
od = (D+padding[2]*2-Kd*dilation[2]+dilation[2]-1)//stride[2]+1
|
||||
xx = x.reindex([N,out_channels,C,oh,ow,od,Kh,Kw,Kd], [
|
||||
N,C,D,H,W = x.shape
|
||||
Kd, Kh, Kw = weight.shape[-3:]
|
||||
od = (D+padding[0]*2-Kd*dilation[0]+dilation[0]-1)//stride[0]+1
|
||||
oh = (H+padding[1]*2-Kh*dilation[1]+dilation[1]-1)//stride[1]+1
|
||||
ow = (W+padding[2]*2-Kw*dilation[2]+dilation[2]-1)//stride[2]+1
|
||||
xx = x.reindex([N,out_channels,C,od,oh,ow,Kd,Kh,Kw], [
|
||||
'i0', # Nid
|
||||
'i2', # Cid
|
||||
f'i3*{stride[0]}-{padding[0]}+i6*{dilation[0]}', # Hid+Khid
|
||||
|
@ -810,15 +755,15 @@ def conv3d(x, weight, bias=None, stride=1, padding=0, dilation=1, groups=1):
|
|||
y = y + b
|
||||
return y
|
||||
else:
|
||||
N,C,H,W,D = x.shape
|
||||
Kh, Kw, Kd = weight.shape[-3:]
|
||||
N,C,D,H,W = x.shape
|
||||
Kd, Kh, Kw = weight.shape[-3:]
|
||||
G = groups
|
||||
CpG = C // G # channels per group
|
||||
oc = out_channels
|
||||
oh = (H+padding[0]*2-Kh*dilation[0]+dilation[0]-1)//stride[0]+1
|
||||
ow = (W+padding[1]*2-Kw*dilation[1]+dilation[1]-1)//stride[1]+1
|
||||
od = (D+padding[2]*2-Kd*dilation[2]+dilation[2]-1)//stride[2]+1
|
||||
xx = x.reindex([N,G,oc//G,CpG,oh,ow,od,Kh,Kw,Kd], [
|
||||
od = (D+padding[0]*2-Kd*dilation[0]+dilation[0]-1)//stride[0]+1
|
||||
oh = (H+padding[1]*2-Kh*dilation[1]+dilation[1]-1)//stride[1]+1
|
||||
ow = (W+padding[2]*2-Kw*dilation[2]+dilation[2]-1)//stride[2]+1
|
||||
xx = x.reindex([N,G,oc//G,CpG,od,oh,ow,Kd,Kh,Kw], [
|
||||
'i0', # Nid
|
||||
f'i1*{CpG}+i3', # Gid
|
||||
f'i4*{stride[0]}-{padding[0]}+i7*{dilation[0]}', # Hid+Khid
|
||||
|
@ -835,7 +780,7 @@ def conv3d(x, weight, bias=None, stride=1, padding=0, dilation=1, groups=1):
|
|||
'i9'
|
||||
])
|
||||
yy = xx*ww
|
||||
y = yy.reindex_reduce('add', [N, oc, oh, ow, od], [
|
||||
y = yy.reindex_reduce('add', [N, oc, od, oh, ow], [
|
||||
'i0',
|
||||
f'i1*{oc//G}+i2',
|
||||
'i4',
|
||||
|
@ -906,6 +851,45 @@ class ConvTranspose(Module):
|
|||
y = y + b
|
||||
return y
|
||||
|
||||
class ConvTranspose3d(Module):
|
||||
def __init__(self, in_channels, out_channels, kernel_size, stride=1, \
|
||||
padding=0, output_padding=0, groups=1, bias=True, dilation=1):
|
||||
self.in_channels = in_channels
|
||||
self.out_channels = out_channels
|
||||
|
||||
# added
|
||||
self.dilation = dilation
|
||||
self.group = groups
|
||||
assert groups==1, "Group conv not supported yet."
|
||||
|
||||
self.kernel_size = kernel_size if isinstance(kernel_size, tuple) else (kernel_size, kernel_size, kernel_size)
|
||||
self.stride = stride if isinstance(stride, tuple) else (stride, stride, stride)
|
||||
self.dilation = dilation if isinstance(dilation, tuple) else (dilation, dilation, dilation)
|
||||
# added
|
||||
self.padding = padding if isinstance(padding, tuple) else (padding, padding, padding)
|
||||
self.real_padding = (
|
||||
self.dilation[0] * (self.kernel_size[0] - 1) - self.padding[0],
|
||||
self.dilation[1] * (self.kernel_size[1] - 1) - self.padding[1],
|
||||
self.dilation[2] * (self.kernel_size[2] - 1) - self.padding[2])
|
||||
self.output_padding = output_padding if isinstance (output_padding, tuple) else (output_padding, output_padding, output_padding)
|
||||
assert self.output_padding[0] < max(self.stride[0], self.dilation[0]) and \
|
||||
self.output_padding[1] < max(self.stride[1], self.dilation[1]) and \
|
||||
self.output_padding[2] < max(self.stride[2], self.dilation[2]), \
|
||||
"output padding must be smaller than max(stride, dilation)"
|
||||
|
||||
self.weight = init.invariant_uniform((in_channels, out_channels) + self.kernel_size, dtype="float")
|
||||
if bias:
|
||||
fan=1
|
||||
for i in self.weight.shape[1:]:
|
||||
fan *= i
|
||||
bound = 1 / math.sqrt(fan)
|
||||
self.bias = init.uniform([out_channels], dtype="float", low=-bound, high=bound)
|
||||
else:
|
||||
self.bias = None
|
||||
|
||||
def execute(self, x):
|
||||
return conv_transpose3d(x, self.weight, self.bias, self.stride, self.padding, self.output_padding, self.group, self.dilation)
|
||||
|
||||
def conv_transpose(input, weight, bias=None, stride=1, padding=0, output_padding=0, groups=1, dilation=1):
|
||||
x = input
|
||||
N,C,H,W = x.shape
|
||||
|
@ -944,6 +928,49 @@ def conv_transpose(input, weight, bias=None, stride=1, padding=0, output_padding
|
|||
assert not bias, "Bias should be none or jittor var"
|
||||
return y
|
||||
|
||||
def conv_transpose3d(input, weight, bias=None, stride=1, padding=0, output_padding=0, groups=1, dilation=1):
|
||||
x = input
|
||||
N,C,D,H,W = x.shape
|
||||
i,o,d,h,w = weight.shape
|
||||
assert C==i
|
||||
assert groups==1, "Group conv not supported yet."
|
||||
stride = stride if isinstance(stride, tuple) else (stride, stride, stride)
|
||||
dilation = dilation if isinstance(dilation, tuple) else (dilation, dilation, dilation)
|
||||
# added
|
||||
padding = padding if isinstance(padding, tuple) else (padding, padding, padding)
|
||||
output_padding = output_padding if isinstance (output_padding, tuple) else (output_padding, output_padding, output_padding)
|
||||
assert output_padding[0] < max(stride[0], dilation[0]) and \
|
||||
output_padding[1] < max(stride[1], dilation[1]) and \
|
||||
output_padding[2] < max(stride[2], dilation[2]), \
|
||||
"output padding must be smaller than max(stride, dilation)"
|
||||
|
||||
stride_d, stride_h, stride_w = stride
|
||||
padding_d, padding_h, padding_w = padding
|
||||
dilation_d, dilation_h, dilation_w = dilation
|
||||
|
||||
d_out = (D-1) * stride_d + output_padding[0] - 2*padding_d + 1 + (d-1)*dilation_d
|
||||
h_out = (H-1) * stride_h + output_padding[1] - 2*padding_h + 1 + (h-1)*dilation_h
|
||||
w_out = (W-1) * stride_w + output_padding[2] - 2*padding_w + 1 + (w-1)*dilation_w
|
||||
out_shape = (N, o, d_out, h_out, w_out)
|
||||
if jt.flags.use_cuda and jt.cudnn:
|
||||
return jt.cudnn.ops.cudnn_conv3d_backward_x(weight, x, *out_shape[2:], *stride, *padding, *dilation, groups)
|
||||
shape = (N, i, o, D, H, W, d, h, w)
|
||||
xx = x.broadcast(shape, (2, 6, 7, 8)) # i,h,w
|
||||
ww = weight.broadcast(shape, (0, 3, 4, 5)) # N,H,W
|
||||
y = (ww*xx).reindex_reduce("add", out_shape, [
|
||||
'i0', # N
|
||||
'i2', # o
|
||||
f'i3*{stride_d}-{padding_d}+i6*{dilation_d}', # Did+Kdid
|
||||
f'i4*{stride_h}-{padding_h}+i7*{dilation_h}', # Hid+Khid
|
||||
f'i5*{stride_w}-{padding_w}+i8*{dilation_w}', # Wid+KWid
|
||||
])
|
||||
if isinstance(bias, jt.Var):
|
||||
b = bias.broadcast(y.shape, [0,2,3,4])
|
||||
y = y + b
|
||||
else:
|
||||
assert not bias, "Bias should be none or jittor var"
|
||||
return y
|
||||
|
||||
conv_transpose2d = conv_transpose
|
||||
|
||||
def pad(x,padding, mode='constant', value=0):
|
||||
|
@ -1286,7 +1313,7 @@ def linspace_from_neg_one(grid,num_steps,align_corners):
|
|||
return jt.array(ra,dtype=grid.dtype)
|
||||
|
||||
def make_base_grid_4D(theta,N,C,H,W,align_corners):
|
||||
base_grid = jt.zeros((N, H, W, 3), dtype=theta.dtype);
|
||||
base_grid = jt.zeros((N, H, W, 3), dtype=theta.dtype)
|
||||
base_grid[...,0] = linspace_from_neg_one(theta, W, align_corners)
|
||||
base_grid[...,1] = jt.unsqueeze(linspace_from_neg_one(theta, H, align_corners),-1)
|
||||
base_grid[...,-1] = 1
|
||||
|
|
|
@ -0,0 +1,31 @@
|
|||
# wget https://github.com/oneapi-src/oneDNN/archive/refs/tags/v2.2.zip
|
||||
# extract zip
|
||||
# cd to root folder
|
||||
|
||||
mkdir -p build
|
||||
cd build
|
||||
make clean
|
||||
export CC=aarch64-linux-gnu-gcc-8
|
||||
export CXX=aarch64-linux-gnu-g++-8
|
||||
cmake .. \
|
||||
-DCMAKE_SYSTEM_NAME=Linux \
|
||||
-DCMAKE_SYSTEM_PROCESSOR=AARCH64 \
|
||||
-DCMAKE_LIBRARY_PATH=/usr/aarch64-linux-gnu/lib \
|
||||
-DCMAKE_BUILD_TYPE=Release
|
||||
# -DCMAKE_SHARED_LINKER_FLAGS=' -lm ' \
|
||||
make -j8
|
||||
|
||||
name=dnnl_lnx_2.2.0_cpu_gomp_aarch64
|
||||
mkdir -p $name
|
||||
cp -r ../include ./$name/
|
||||
mkdir -p ./$name/lib
|
||||
cp ./src/libmkldnn.so ./$name/lib/libmkldnn.so
|
||||
cp -r ../examples ./$name/
|
||||
cp ./include/oneapi/dnnl/* ./$name/include/oneapi/dnnl/
|
||||
|
||||
tar -acvf $name.tgz ./$name/
|
||||
|
||||
rsync -avPu $name.tgz jittor-web:Documents/jittor-blog/assets/
|
||||
ssh jittor-web Documents/jittor-blog.git/hooks/post-update
|
||||
echo "https://cg.cs.tsinghua.edu.cn/jittor/assets/$name.tgz"
|
||||
md5sum $name.tgz
|
|
@ -486,9 +486,11 @@ void Executor::run_sync(vector<Var*> vars, bool device_sync) {
|
|||
if (use_cuda)
|
||||
checkCudaErrors(cudaDeviceSynchronize());
|
||||
#endif
|
||||
for (Var* var : op->outputs())
|
||||
check_nan(var);
|
||||
}
|
||||
#ifdef JT_CHECK_NAN
|
||||
for (Var* var : op->outputs())
|
||||
check_nan(var);
|
||||
#endif
|
||||
LOGvvv << "Finished Op(" >> op->name() << rid >>
|
||||
"/" >> queue.size() >> ") output:" << op->outputs();
|
||||
if (is_fused_op) {
|
||||
|
|
|
@ -238,6 +238,21 @@ struct NanoVector {
|
|||
v[i] = at(i);
|
||||
return v;
|
||||
}
|
||||
|
||||
inline void _unpack(int i) {
|
||||
return;
|
||||
}
|
||||
|
||||
template<class... Args>
|
||||
void _unpack(int i, int& x, Args&&... args) {
|
||||
x = this->operator[](i);
|
||||
_unpack(i+1, std::forward<Args>(args)...);
|
||||
}
|
||||
|
||||
template<class... Args>
|
||||
void unpack(Args&&... args) {
|
||||
_unpack(0, std::forward<Args>(args)...);
|
||||
}
|
||||
};
|
||||
|
||||
|
||||
|
|
|
@ -104,6 +104,8 @@ int OpCompiler::total_member_count() {
|
|||
// array need a extra local var
|
||||
if (op->ops[i]->name()==string("array"))
|
||||
member_count += 1;
|
||||
if (op->ops[i]->name()==string("safe_clip"))
|
||||
member_count += 2;
|
||||
member_count += v.size();
|
||||
i += 1;
|
||||
}
|
||||
|
@ -826,11 +828,15 @@ string OpCompiler::__get_fused_src(
|
|||
const unordered_set<string> members = {
|
||||
"x", "y", "z", "cond", "output", "extras"
|
||||
};
|
||||
const unordered_set<string> scalar_members = {
|
||||
"left", "right"
|
||||
};
|
||||
const unordered_set<string> unchanged = {
|
||||
"for", "const", "auto", "get_random_engine",
|
||||
"int", "float", "bool", "CHECK", "STRINGIZE",
|
||||
"void", "__restrict__", "if", "true", "false",
|
||||
"Op", "Var", "Node", "itof", "assert", "ASSERT"
|
||||
"Op", "Var", "Node", "itof", "assert", "ASSERT",
|
||||
"float64"
|
||||
};
|
||||
auto not_change = [&](const string& s) -> bool {
|
||||
if (unchanged.count(s)) return true;
|
||||
|
@ -941,7 +947,8 @@ string OpCompiler::__get_fused_src(
|
|||
while (l<src.size() && isvar(src[l])) l++;
|
||||
auto var = src.substr(j, l-j);
|
||||
if (var[0] == ':' || isdigit(var[0]) || not_change(var) || src[j-1]=='.' || src[j-1]=='>') {} else
|
||||
if (members.count(var)) {
|
||||
if (members.count(var) || scalar_members.count(var)) {
|
||||
bool is_member = members.count(var);
|
||||
string arg_name = "op" + S(oi) + "_" + var;
|
||||
if (l<src.size() && src[l]=='[') {
|
||||
// handle extras[...]
|
||||
|
@ -964,7 +971,8 @@ string OpCompiler::__get_fused_src(
|
|||
" = (("+name3+"Op*)(ops[" + S(oi) + "]))->" + var;
|
||||
fused_kernel_args += ";\n";
|
||||
kernel_args.insert(arg_name);
|
||||
op_members[oi].push_back(arg_name);
|
||||
if (is_member)
|
||||
op_members[oi].push_back(arg_name);
|
||||
}
|
||||
fused_kernel += arg_name;
|
||||
j = l-1;
|
||||
|
|
|
@ -0,0 +1,47 @@
|
|||
// ***************************************************************
|
||||
// Copyright (c) 2021 Jittor. All Rights Reserved.
|
||||
// Maintainers: Dun Liang <randonlang@gmail.com>.
|
||||
// This file is subject to the terms and conditions defined in
|
||||
// file 'LICENSE.txt', which is part of this source code package.
|
||||
// ***************************************************************
|
||||
#include <cmath>
|
||||
#include "var.h"
|
||||
#include "ops/safe_clip_op.h"
|
||||
#include "ops/op_register.h"
|
||||
|
||||
namespace jittor {
|
||||
|
||||
#ifndef JIT
|
||||
|
||||
SafeClipOp::SafeClipOp(Var* x, float64 left, float64 right) : x(x), left(left), right(right) {
|
||||
flags.set(NodeFlags::_cpu);
|
||||
flags.set(NodeFlags::_cuda);
|
||||
set_type(OpType::element);
|
||||
y = create_output(nullptr, x->dtype());
|
||||
}
|
||||
|
||||
VarPtr SafeClipOp::grad(Var* out, Var* dout, Var* v, int v_index) {
|
||||
return dout;
|
||||
}
|
||||
|
||||
void SafeClipOp::infer_shape() {
|
||||
y->set_shape(x->shape);
|
||||
}
|
||||
|
||||
void SafeClipOp::jit_prepare(JK& jk) {
|
||||
jk << _CS("[Tx:") << x->dtype() <<']';
|
||||
}
|
||||
|
||||
#else // JIT
|
||||
void SafeClipOp::jit_run() {
|
||||
auto* __restrict__ xp = x->ptr<Tx>();
|
||||
Tx left_value = (Tx)std::max((float64)std::numeric_limits<Tx>::lowest(), left);
|
||||
Tx right_value = (Tx)std::min((float64)std::numeric_limits<Tx>::max(), right);
|
||||
auto* __restrict__ yp = y->ptr<Tx>();
|
||||
index_t num = y->num;
|
||||
for (index_t i=0; i<num; i++)
|
||||
yp[i] = xp[i] < left_value ? left_value : (xp[i] > right_value ? right_value : xp[i]);
|
||||
}
|
||||
#endif // JIT
|
||||
|
||||
} // jittor
|
|
@ -0,0 +1,33 @@
|
|||
// ***************************************************************
|
||||
// Copyright (c) 2021 Jittor. All Rights Reserved.
|
||||
// Maintainers: Dun Liang <randonlang@gmail.com>.
|
||||
// This file is subject to the terms and conditions defined in
|
||||
// file 'LICENSE.txt', which is part of this source code package.
|
||||
// ***************************************************************
|
||||
#pragma once
|
||||
#include "op.h"
|
||||
|
||||
|
||||
namespace jittor {
|
||||
|
||||
struct SafeClipOp : Op {
|
||||
Var* x, * y;
|
||||
float64 left, right;
|
||||
/** Safe clip value to a range, and keep
|
||||
the gradient pass thought.
|
||||
|
||||
* [in] x: input value
|
||||
* [in] left: float64 clip min value.
|
||||
* [in] right: float64 clip max value.
|
||||
|
||||
*/
|
||||
// @pybind(safe_clip)
|
||||
SafeClipOp(Var* x, float64 left, float64 right);
|
||||
|
||||
const char* name() const override { return "safe_clip"; }
|
||||
VarPtr grad(Var* out, Var* dout, Var* v, int v_index) override;
|
||||
void infer_shape() override;
|
||||
DECLARE_jit_run;
|
||||
};
|
||||
|
||||
} // jittor
|
|
@ -67,6 +67,10 @@ void LoopToFuncPass::run() {
|
|||
args.push_back(d.get());
|
||||
continue;
|
||||
}
|
||||
if (endswith(d->attrs["lvalue"], "_value")) {
|
||||
args.push_back(d.get());
|
||||
continue;
|
||||
}
|
||||
}
|
||||
}
|
||||
func->push_back(d->clone());
|
||||
|
|
|
@ -8,6 +8,7 @@
|
|||
#include <streambuf>
|
||||
#include "misc/hash.h"
|
||||
#include "utils/cache_compile.h"
|
||||
#include "utils/str_utils.h"
|
||||
|
||||
namespace jittor {
|
||||
namespace jit_compiler {
|
||||
|
@ -137,7 +138,7 @@ size_t skip_comments(const string& src, size_t i) {
|
|||
return i;
|
||||
}
|
||||
|
||||
void process(string src, vector<string>& input_names) {
|
||||
void process(string src, vector<string>& input_names, string& cmd) {
|
||||
for (size_t i=0; i<src.size(); i++) {
|
||||
i = skip_comments(src, i);
|
||||
if (i>=src.size()) break;
|
||||
|
@ -159,6 +160,20 @@ void process(string src, vector<string>& input_names) {
|
|||
input_names.push_back(inc);
|
||||
}
|
||||
}
|
||||
if (l-k>2 && src[k] == 'J' && src[k+1] == 'T' && j-i==6 && src.substr(i,j-i) == "#ifdef") {
|
||||
auto inc = src.substr(k, l-k);
|
||||
auto env = getenv(inc.c_str());
|
||||
if (env && string(env)!="0") {
|
||||
string dflag = " -D"+inc+"="+string(env)+" -o ";
|
||||
if (cmd.find(dflag) == string::npos) {
|
||||
// -D flags should insert before -o flag
|
||||
auto cmds = split(cmd, " -o ", 2);
|
||||
if (cmds.size() == 2) {
|
||||
cmd = cmds[0] + dflag + cmds[1];
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
i=l;
|
||||
}
|
||||
}
|
||||
|
@ -173,12 +188,6 @@ bool cache_compile(const string& cmd, const string& cache_path, const string& ji
|
|||
bool ran = false;
|
||||
output_cache_key = read_all(output_name+".key");
|
||||
string cd_cmd = cache_path.size() ? "cd " + cache_path + " && " + cmd : cmd;
|
||||
if (output_cache_key.size() == 0) {
|
||||
LOGvv << "Cache key of" << output_name << "not found.";
|
||||
LOGvvv << "Run cmd:" << cmd;
|
||||
system_with_check(cd_cmd.c_str());
|
||||
ran = true;
|
||||
}
|
||||
string cache_key = cmd;
|
||||
cache_key += "\n";
|
||||
unordered_set<string> processed;
|
||||
|
@ -192,7 +201,7 @@ bool cache_compile(const string& cmd, const string& cache_path, const string& ji
|
|||
ASSERT(src.size()) << "Source read failed:" << input_names[i];
|
||||
auto hash = S(hash64(src));
|
||||
vector<string> new_names;
|
||||
process(src, new_names);
|
||||
process(src, new_names, cd_cmd);
|
||||
for (auto& name : new_names) {
|
||||
string full_name;
|
||||
if (name.substr(0, 4) == "jit/" || name.substr(0, 4) == "gen/")
|
||||
|
@ -222,9 +231,15 @@ bool cache_compile(const string& cmd, const string& cache_path, const string& ji
|
|||
cache_key += hash;
|
||||
cache_key += "\n";
|
||||
}
|
||||
if (output_cache_key.size() == 0) {
|
||||
LOGvv << "Cache key of" << output_name << "not found.";
|
||||
LOGvvv << "Run cmd:" << cd_cmd;
|
||||
system_with_check(cd_cmd.c_str());
|
||||
ran = true;
|
||||
}
|
||||
if (output_cache_key.size() != 0 && output_cache_key != cache_key) {
|
||||
LOGvv << "Cache key of" << output_name << "changed.";
|
||||
LOGvvv << "Run cmd:" << cmd;
|
||||
LOGvvv << "Run cmd:" << cd_cmd;
|
||||
system_with_check(cd_cmd.c_str());
|
||||
ran = true;
|
||||
}
|
||||
|
@ -296,7 +311,8 @@ void test_find_nams_error(string cmd) {
|
|||
|
||||
void test_process(string src, vector<string> files) {
|
||||
vector<string> ifiles;
|
||||
jittor::jit_compiler::process(src, ifiles);
|
||||
string cmd;
|
||||
jittor::jit_compiler::process(src, ifiles, cmd);
|
||||
CHECK(files.size() == ifiles.size());
|
||||
for (size_t i=0; i<files.size(); i++)
|
||||
CHECKop(files[i],==,ifiles[i]);
|
||||
|
|
|
@ -322,6 +322,34 @@ but you can hot fix it by this command:
|
|||
)";
|
||||
}
|
||||
|
||||
static inline void check_cuda_gcc_version(const string& output) {
|
||||
/* if such error occur:
|
||||
error: identifier "__is_assignable" is undefined
|
||||
this means your gcc version is not match with nvcc,
|
||||
for example, nvcc 10 support gcc<=7, nvcc 11 support gcc<=9,
|
||||
|
||||
https://gist.github.com/ax3l/9489132
|
||||
*/
|
||||
string pat = "__is_assignable";
|
||||
auto id = output.find(pat);
|
||||
if (id == string::npos) return;
|
||||
LOGf << output << R"(
|
||||
!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!
|
||||
Dear user, your nvcc and gcc version are still not match
|
||||
after dirty hack, your should install the correct version of g++
|
||||
or nvcc, for example, nvcc 10 support g++<=7, nvcc 11 support g++<=9,
|
||||
here is the NVCC Compatibility Matrix:
|
||||
https://gist.github.com/ax3l/9489132
|
||||
Please install correct version of gcc, for example:
|
||||
>>> sudo apt install g++-7
|
||||
After your g++ is installed, using enviroment variable `cc_path` to
|
||||
tell jittor use the correct version of g++, for example:
|
||||
>>> cc_path='g++-7' python3.7 -m jittor.test.test_core
|
||||
If you still have problems, please contact us:
|
||||
https://github.com/Jittor/jittor/issues
|
||||
)";
|
||||
}
|
||||
|
||||
int system_popen(const char* cmd) {
|
||||
char buf[BUFSIZ];
|
||||
string cmd2;
|
||||
|
@ -342,6 +370,7 @@ int system_popen(const char* cmd) {
|
|||
}
|
||||
if (ret) {
|
||||
check_cuda_unsupport_version(output);
|
||||
check_cuda_gcc_version(output);
|
||||
}
|
||||
return ret;
|
||||
}
|
||||
|
|
|
@ -25,15 +25,15 @@ bool endswith(const string& a, const string& b) {
|
|||
|
||||
vector<string> split(const string& s, const string& sep, int max_split) {
|
||||
vector<string> ret;
|
||||
int pos = -1, pos_next;
|
||||
int pos = 0, pos_next;
|
||||
while (1) {
|
||||
pos_next = s.find(sep, pos+1);
|
||||
pos_next = s.find(sep, pos);
|
||||
if (pos_next == (int)string::npos || (int)ret.size() == max_split-1) {
|
||||
ret.push_back(s.substr(pos+sep.size()));
|
||||
ret.push_back(s.substr(pos));
|
||||
return ret;
|
||||
}
|
||||
ret.push_back(s.substr(pos+sep.size(), pos_next-pos-sep.size()));
|
||||
pos = pos_next;
|
||||
ret.push_back(s.substr(pos, pos_next-pos));
|
||||
pos = pos_next + sep.size();
|
||||
}
|
||||
ASSERT(max_split==0);
|
||||
return ret;
|
||||
|
|
|
@ -128,7 +128,58 @@ class TestCudnnConvOp(unittest.TestCase):
|
|||
check([10,3,100,100], [5,3,3,3], stride=2, padding=0, dilation=1)
|
||||
check([10,4,40,50], [5,4,5,5], stride=1, padding=1, dilation=1)
|
||||
check([10,4,40,50], [5,4,4,4], stride=3, padding=1, dilation=1)
|
||||
|
||||
|
||||
def test_conv3d(self):
|
||||
def check(xshape, wshape, stride=(1,1,1), padding=(0,0,0), dilation=(1,1,1), group=1):
|
||||
with jt.flag_scope(use_cuda=1):
|
||||
x = jt.random(xshape)
|
||||
w = jt.random(wshape)
|
||||
# y = jt.cudnn.ops.cudnn_conv3d(x, w, *stride, *padding, *dilation, group)
|
||||
y = jt.nn.conv3d(x, w, None, stride, padding, dilation, group)
|
||||
masky = jt.rand_like(y)
|
||||
dx, dw = jt.grad(masky*y, [x, w])
|
||||
|
||||
y2 = jt.nn.conv3d(x, w, None, stride, padding, dilation, group)
|
||||
dx2, dw2 = jt.grad(masky*y2, [x, w])
|
||||
np.testing.assert_allclose(y.data, y2.data)
|
||||
np.testing.assert_allclose(dx.data, dx2.data, rtol=1e-5, atol=1e-3)
|
||||
np.testing.assert_allclose(dw.data, dw2.data, rtol=1e-5, atol=1e-3)
|
||||
|
||||
check((2,4,10,10,10), (5,4,3,3,3), (1,1,1), (1,1,1))
|
||||
check((2,4,10,10,10), (5,4,3,3,3), (2,2,2), (1,1,1))
|
||||
check((2,4,10,10,10), (5,4,3,3,3), (2,2,2), (0,0,0))
|
||||
check((2,4,10,10,10), (5,4,3,3,3), (1,2,3), (0,0,0))
|
||||
check((2,4,10,10,10), (5,4,3,4,5), (1,1,1), (1,1,1))
|
||||
check((2,4,10,10,10), (5,4,3,4,5), (1,2,3), (0,0,0))
|
||||
check((2,4,10,10,10), (5,4,3,3,3), (1,1,1), (1,1,1), dilation=(1,2,3))
|
||||
|
||||
def test_conv_transpose3d(self):
|
||||
jt.set_global_seed(10)
|
||||
def check(xshape, wshape, stride=(1,1,1), padding=(0,0,0), dilation=(1,1,1), group=1):
|
||||
with jt.flag_scope(use_cuda=1):
|
||||
x = jt.random(xshape)
|
||||
w = jt.random(wshape)
|
||||
|
||||
y2 = jt.nn.conv_transpose3d(x, w, None, stride, padding, 0, group, dilation)
|
||||
|
||||
with jt.flag_scope(use_cuda=1):
|
||||
# y = jt.cudnn.ops.cudnn_conv3d_backward_x(w, x, *y2.shape[2:], *stride, *padding, *dilation, group)
|
||||
y = jt.nn.conv_transpose3d(x, w, None, stride, padding, 0, group, dilation)
|
||||
masky = jt.rand_like(y)
|
||||
dx, dw = jt.grad(masky*y, [x, w])
|
||||
|
||||
dx2, dw2 = jt.grad(masky*y2, [x, w])
|
||||
np.testing.assert_allclose(y.data, y2.data, rtol=1e-6, atol=1e-4)
|
||||
np.testing.assert_allclose(dx.data, dx2.data, rtol=1e-6, atol=1e-4)
|
||||
np.testing.assert_allclose(dw.data, dw2.data, rtol=1e-5, atol=1e-3)
|
||||
|
||||
check((2,5,10,10,10), (5,4,3,3,3), (1,1,1), (1,1,1))
|
||||
check((2,5,10,10,10), (5,4,3,3,3), (2,2,2), (1,1,1))
|
||||
check((2,5,10,10,10), (5,4,3,3,3), (2,2,2), (0,0,0))
|
||||
check((2,5,10,10,10), (5,4,3,3,3), (1,2,3), (0,0,0))
|
||||
check((2,5,10,10,10), (5,4,3,4,5), (1,1,1), (1,1,1))
|
||||
check((2,5,10,10,10), (5,4,3,4,5), (1,2,3), (0,0,0))
|
||||
check((2,5,10,10,10), (5,4,3,3,3), (1,1,1), (1,1,1), dilation=(1,2,3))
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
|
|
|
@ -161,6 +161,16 @@ class TestDatasetSeed(unittest.TestCase):
|
|||
for i in range(len(d)):
|
||||
for j in range(i+1, len(d)):
|
||||
assert not np.allclose(dd[i], dd[j])
|
||||
|
||||
def test_cifar(self):
|
||||
from jittor.dataset.cifar import CIFAR10
|
||||
a = CIFAR10()
|
||||
a.set_attrs(batch_size=16)
|
||||
for imgs, labels in a:
|
||||
print(imgs.shape, labels.shape)
|
||||
assert imgs.shape == [16,32,32,3,]
|
||||
assert labels.shape == [16,]
|
||||
break
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
|
|
@ -31,14 +31,14 @@ class TestOneHot(unittest.TestCase):
|
|||
probs,probs2 = np.random.uniform(0,1,(10)), np.random.uniform(0,1,(10))
|
||||
probs,probs2 = probs / probs.sum(),probs2 / probs2.sum()
|
||||
import torch
|
||||
jc, jc2 = jd.OneHotCategorical(jt.array(probs).reshape(1,-1)),jd.OneHotCategorical(jt.array(probs2).reshape(1,-1))
|
||||
jc, jc2 = jd.OneHotCategorical(jt.array(probs)),jd.OneHotCategorical(jt.array(probs2))
|
||||
tc, tc2 = torch.distributions.OneHotCategorical(torch.tensor(probs)),torch.distributions.OneHotCategorical(torch.tensor(probs2))
|
||||
assert np.allclose(jc.entropy().data,tc.entropy().numpy())
|
||||
x = np.zeros((4,10))
|
||||
for _ in range(4):
|
||||
nx = np.random.randint(0,9)
|
||||
x[_,nx] = 1
|
||||
assert np.allclose(jc.log_prob(jt.array(x)),tc.log_prob(torch.tensor(x)))
|
||||
np.testing.assert_allclose(jc.log_prob(jt.array(x)),tc.log_prob(torch.tensor(x)), atol=1e-5)
|
||||
assert np.allclose(jd.kl_divergence(jc,jc2),torch.distributions.kl_divergence(tc,tc2))
|
||||
|
||||
def test_cate(self):
|
||||
|
@ -67,17 +67,55 @@ class TestOneHot(unittest.TestCase):
|
|||
tn2 = torch.distributions.Normal(mu2,sigma2)
|
||||
assert np.allclose(jd.kl_divergence(jn,jn2).data,torch.distributions.kl_divergence(tn,tn2).numpy())
|
||||
|
||||
def test_categorical(self):
|
||||
def test_categorical1(self):
|
||||
import torch
|
||||
for _ in range(4):
|
||||
probs,probs2 = np.random.uniform(0,1,(10)), np.random.uniform(0,1,(10))
|
||||
probs,probs2 = probs / probs.sum(),probs2 / probs2.sum()
|
||||
jc, jc2 = jd.Categorical(jt.array(probs).reshape(1,-1)),jd.Categorical(jt.array(probs2).reshape(1,-1))
|
||||
jc, jc2 = jd.Categorical(jt.array(probs)),jd.Categorical(jt.array(probs2))
|
||||
tc, tc2 = torch.distributions.Categorical(torch.tensor(probs)),torch.distributions.Categorical(torch.tensor(probs2))
|
||||
assert np.allclose(jc.entropy().data, tc.entropy().numpy()), (jc.entropy().data, tc.entropy().numpy())
|
||||
x = np.random.randint(0,10,(4))
|
||||
assert np.allclose(jc.log_prob(x), tc.log_prob(torch.tensor(x)))
|
||||
np.testing.assert_allclose(jc.log_prob(x), tc.log_prob(torch.tensor(x)), atol=1e-5)
|
||||
assert np.allclose(jd.kl_divergence(jc,jc2),torch.distributions.kl_divergence(tc,tc2))
|
||||
|
||||
def test_categorical2(self):
|
||||
def check(prob_shape, sample_shape):
|
||||
import torch
|
||||
for _ in range(4):
|
||||
probs,probs2 = np.random.uniform(0,1,prob_shape), np.random.uniform(0,1, prob_shape)
|
||||
|
||||
jc, jc2 = jd.Categorical(jt.array(probs)),jd.Categorical(jt.array(probs2))
|
||||
tc, tc2 = torch.distributions.Categorical(torch.tensor(probs)),torch.distributions.Categorical(torch.tensor(probs2))
|
||||
assert np.allclose(jc.entropy().data, tc.entropy().numpy()), (jc.entropy().data, tc.entropy().numpy())
|
||||
x1 = jc.sample(sample_shape)
|
||||
x2 = tc.sample(sample_shape)
|
||||
assert tuple(x1.shape) == tuple(x2.shape)
|
||||
x = np.random.randint(0,prob_shape[-1], tuple(x1.shape))
|
||||
np.testing.assert_allclose(jc.log_prob(x), tc.log_prob(torch.tensor(x)), atol=1e-5)
|
||||
np.testing.assert_allclose(jd.kl_divergence(jc,jc2), torch.distributions.kl_divergence(tc,tc2), atol=1e-5)
|
||||
check((10,), (4,))
|
||||
check((2,3), (4,))
|
||||
check((3,4,5,6), (2,))
|
||||
|
||||
def test_one_hot_categorical2(self):
|
||||
def check(prob_shape, sample_shape):
|
||||
import torch
|
||||
for _ in range(4):
|
||||
probs,probs2 = np.random.uniform(0,1,prob_shape), np.random.uniform(0,1, prob_shape)
|
||||
|
||||
jc, jc2 = jd.OneHotCategorical(jt.array(probs)),jd.OneHotCategorical(jt.array(probs2))
|
||||
tc, tc2 = torch.distributions.OneHotCategorical(torch.tensor(probs)),torch.distributions.OneHotCategorical(torch.tensor(probs2))
|
||||
assert np.allclose(jc.entropy().data, tc.entropy().numpy()), (jc.entropy().data, tc.entropy().numpy())
|
||||
x1 = jc.sample(sample_shape)
|
||||
x2 = tc.sample(sample_shape)
|
||||
assert tuple(x1.shape) == tuple(x2.shape)
|
||||
x = np.random.randint(0,prob_shape[-1], tuple(x1.shape))
|
||||
np.testing.assert_allclose(jc.log_prob(x), tc.log_prob(torch.tensor(x)), atol=1e-5)
|
||||
np.testing.assert_allclose(jd.kl_divergence(jc,jc2), torch.distributions.kl_divergence(tc,tc2), atol=1e-5)
|
||||
check((10,), (4,))
|
||||
check((2,3), (4,))
|
||||
check((3,4,5,6), (2,))
|
||||
|
||||
def test_uniform(self):
|
||||
import torch
|
||||
|
@ -98,11 +136,11 @@ class TestOneHot(unittest.TestCase):
|
|||
prob, prob2 = np.random.uniform(0,1), np.random.uniform(0,1)
|
||||
jg, jg2 = jd.Geometric(prob),jd.Geometric(prob2)
|
||||
tg, tg2 = torch.distributions.Geometric(prob),torch.distributions.Geometric(prob2)
|
||||
assert np.allclose(jg.entropy().data,tg.entropy().numpy())
|
||||
np.testing.assert_allclose(jg.entropy().data,tg.entropy().numpy(), atol=1e-4)
|
||||
x = np.random.randint(1,10)
|
||||
assert np.allclose(jg.log_prob(x),tg.log_prob(torch.tensor(x)))
|
||||
np.testing.assert_allclose(jg.log_prob(x),tg.log_prob(torch.tensor(x)), atol=1e-4)
|
||||
# print(jd.kl_divergence(jg,jg2),torch.distributions.kl_divergence(tg,tg2))
|
||||
assert np.allclose(jd.kl_divergence(jg,jg2),torch.distributions.kl_divergence(tg,tg2))
|
||||
np.testing.assert_allclose(jd.kl_divergence(jg,jg2),torch.distributions.kl_divergence(tg,tg2), atol=1e-4)
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
|
@ -347,5 +347,11 @@ class TestMatmul(unittest.TestCase):
|
|||
def test_matmul_example2_cuda(self):
|
||||
self.test_matmul_example2()
|
||||
|
||||
def test_linear1d(self):
|
||||
linear = jt.nn.Linear(10,20)
|
||||
a = jt.random((10,))
|
||||
b = linear(a)
|
||||
assert b.shape == (20,)
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
|
@ -38,7 +38,7 @@ class TestResnet(unittest.TestCase):
|
|||
@classmethod
|
||||
def setUpClass(self):
|
||||
# hyper-parameters
|
||||
self.batch_size = 100
|
||||
self.batch_size = int(os.environ.get("TEST_BATCH_SIZE", "100"))
|
||||
self.weight_decay = 0.0001
|
||||
self.momentum = 0.9
|
||||
self.learning_rate = 0.1
|
||||
|
|
|
@ -257,14 +257,14 @@ class Tester(unittest.TestCase):
|
|||
expect = input_data.transpose(2,0,1)
|
||||
self.assertTrue(np.allclose(expect, output), f"{expect.shape}\n{output.shape}")
|
||||
|
||||
ndarray = np.random.randint(low=0, high=255, size=(height, width, channels)).astype(np.uint8)
|
||||
ndarray = np.random.randint(low=0, high=255, size=(channels, height, width)).astype(np.uint8)
|
||||
output = trans(ndarray)
|
||||
expected_output = ndarray.transpose((2, 0, 1)) / 255.0
|
||||
self.assertTrue(np.allclose(output, expected_output))
|
||||
expected_output = ndarray / 255.0
|
||||
np.testing.assert_allclose(output, expected_output)
|
||||
|
||||
ndarray = np.random.rand(height, width, channels).astype(np.float32)
|
||||
ndarray = np.random.rand(channels, height, width).astype(np.float32)
|
||||
output = trans(ndarray)
|
||||
expected_output = ndarray.transpose((2, 0, 1))
|
||||
expected_output = ndarray
|
||||
self.assertTrue(np.allclose(output, expected_output))
|
||||
|
||||
# separate test for mode '1' PIL images
|
||||
|
|
|
@ -69,6 +69,13 @@ class TestUnaryOp(unittest.TestCase):
|
|||
b1 = b.sigmoid().numpy()
|
||||
assert np.isnan(b1).any() == False
|
||||
|
||||
def test_safe_clip(self):
|
||||
a = jt.array([-1.0,0,0.4,1,2,3])
|
||||
b = a.safe_clip(0.1, 0.5)
|
||||
assert np.allclose(b.data, [0.1,0.1,0.4,0.5,0.5,0.5])
|
||||
da = jt.grad(b, a)
|
||||
assert (da.data == 1).all()
|
||||
|
||||
class TestUnaryOpCuda(TestUnaryOp, test_cuda(2)):
|
||||
pass
|
||||
|
||||
|
|
|
@ -389,7 +389,7 @@ class CenterCrop:
|
|||
|
||||
def to_tensor(pic):
|
||||
"""
|
||||
Function for turning Image.Image to np.array.
|
||||
Function for turning Image.Image to np.array with CHW format.
|
||||
|
||||
Args::
|
||||
|
||||
|
@ -414,14 +414,13 @@ def to_tensor(pic):
|
|||
if _is_numpy(pic):
|
||||
# handle numpy array
|
||||
if pic.ndim == 2:
|
||||
pic = pic[:, :, None]
|
||||
pic = pic[None, :, :]
|
||||
|
||||
img = pic.transpose((2, 0, 1))
|
||||
# backward compatibility
|
||||
if img.dtype == 'uint8':
|
||||
return np.float32(img) * np.float32(1/255.0)
|
||||
if pic.dtype == 'uint8':
|
||||
return np.float32(pic) * np.float32(1/255.0)
|
||||
else:
|
||||
return img
|
||||
return pic
|
||||
|
||||
# handle PIL Image
|
||||
if pic.mode == 'I':
|
||||
|
@ -499,7 +498,7 @@ def _to_jittor_array(pic):
|
|||
def to_pil_image(pic, mode=None):
|
||||
"""Convert a tensor or an ndarray to PIL Image.
|
||||
Args:
|
||||
pic (Tensor or numpy.ndarray): Image to be converted to PIL Image.
|
||||
pic (Tensor or numpy.ndarray): Image(HWC format) to be converted to PIL Image.
|
||||
mode (`PIL.Image mode`_): color space and pixel depth of input data (optional).
|
||||
.. _PIL.Image mode: https://pillow.readthedocs.io/en/latest/handbook/concepts.html#concept-modes
|
||||
Returns:
|
||||
|
@ -694,7 +693,7 @@ class Gray:
|
|||
transform = transform.Gray()
|
||||
img_ = transform(img)
|
||||
'''
|
||||
def __init__(self, num_output_channels):
|
||||
def __init__(self, num_output_channels=1):
|
||||
self.num_output_channels = num_output_channels
|
||||
|
||||
def __call__(self, img:Image.Image):
|
||||
|
|
|
@ -12,6 +12,9 @@ import hashlib
|
|||
import urllib.request
|
||||
from tqdm import tqdm
|
||||
from jittor_utils import lock
|
||||
import gzip
|
||||
import tarfile
|
||||
import zipfile
|
||||
|
||||
def ensure_dir(dir_path):
|
||||
if not os.path.isdir(dir_path):
|
||||
|
@ -69,3 +72,77 @@ def calculate_md5(file_path, chunk_size=1024 * 1024):
|
|||
def check_md5(file_path, md5, **kwargs):
|
||||
return md5 == calculate_md5(file_path, **kwargs)
|
||||
|
||||
|
||||
def check_integrity(fpath, md5=None):
|
||||
if not os.path.isfile(fpath):
|
||||
return False
|
||||
if md5 is None:
|
||||
return True
|
||||
return check_md5(fpath, md5)
|
||||
|
||||
|
||||
def _is_tarxz(filename):
|
||||
return filename.endswith(".tar.xz")
|
||||
|
||||
|
||||
def _is_tar(filename):
|
||||
return filename.endswith(".tar")
|
||||
|
||||
|
||||
def _is_targz(filename):
|
||||
return filename.endswith(".tar.gz")
|
||||
|
||||
|
||||
def _is_tgz(filename):
|
||||
return filename.endswith(".tgz")
|
||||
|
||||
|
||||
def _is_gzip(filename):
|
||||
return filename.endswith(".gz") and not filename.endswith(".tar.gz")
|
||||
|
||||
|
||||
def _is_zip(filename):
|
||||
return filename.endswith(".zip")
|
||||
|
||||
|
||||
def extract_archive(from_path, to_path=None, remove_finished=False):
|
||||
if to_path is None:
|
||||
to_path = os.path.dirname(from_path)
|
||||
|
||||
if _is_tar(from_path):
|
||||
with tarfile.open(from_path, 'r') as tar:
|
||||
tar.extractall(path=to_path)
|
||||
elif _is_targz(from_path) or _is_tgz(from_path):
|
||||
with tarfile.open(from_path, 'r:gz') as tar:
|
||||
tar.extractall(path=to_path)
|
||||
elif _is_tarxz(from_path):
|
||||
# .tar.xz archive only supported in Python 3.x
|
||||
with tarfile.open(from_path, 'r:xz') as tar:
|
||||
tar.extractall(path=to_path)
|
||||
elif _is_gzip(from_path):
|
||||
to_path = os.path.join(to_path, os.path.splitext(os.path.basename(from_path))[0])
|
||||
with open(to_path, "wb") as out_f, gzip.GzipFile(from_path) as zip_f:
|
||||
out_f.write(zip_f.read())
|
||||
elif _is_zip(from_path):
|
||||
with zipfile.ZipFile(from_path, 'r') as z:
|
||||
z.extractall(to_path)
|
||||
else:
|
||||
raise ValueError("Extraction of {} not supported".format(from_path))
|
||||
|
||||
if remove_finished:
|
||||
os.remove(from_path)
|
||||
|
||||
|
||||
def download_and_extract_archive(url, download_root, extract_root=None, filename=None,
|
||||
md5=None, remove_finished=False):
|
||||
download_root = os.path.expanduser(download_root)
|
||||
if extract_root is None:
|
||||
extract_root = download_root
|
||||
if not filename:
|
||||
filename = os.path.basename(url)
|
||||
|
||||
download_url_to_local(url, filename, download_root, md5)
|
||||
|
||||
archive = os.path.join(download_root, filename)
|
||||
print("Extracting {} to {}".format(archive, extract_root))
|
||||
extract_archive(archive, extract_root, remove_finished)
|
||||
|
|
Loading…
Reference in New Issue