Merge branch 'master' into macOS

This commit is contained in:
lzhengning 2021-06-18 10:30:57 +08:00
commit df0ea12d7e
21 changed files with 520 additions and 112 deletions

View File

@ -9,7 +9,7 @@
# file 'LICENSE.txt', which is part of this source code package.
# ***************************************************************
__version__ = '1.2.3.24'
__version__ = '1.2.3.32'
from jittor_utils import lock
with lock.lock_scope():
ori_int = int
@ -437,11 +437,11 @@ def pow(x, y):
Var.pow = Var.__pow__ = pow
def argmax(x, dim, keepdims:bool=False):
return x.arg_reduce("max", dim, keepdims)
return jt.arg_reduce(x, "max", dim, keepdims)
Var.argmax = argmax
def argmin(x, dim, keepdims:bool=False):
return x.arg_reduce("min", dim, keepdims)
return jt.arg_reduce(x, "min", dim, keepdims)
Var.argmin = argmin
def randn(*size, dtype="float32", requires_grad=True) -> Var:

View File

@ -1090,7 +1090,7 @@ if os.path.isfile(version_file) and not os.path.isdir(os.path.join(jittor_path,
compile(cc_path, cc_flags+opt_flags, files, 'jittor_core'+extension_suffix)
# TODO: move to compile_extern.py
compile_extern()
# compile_extern()
with jit_utils.import_scope(import_flags):
import jittor_core as core

View File

@ -1,5 +1,6 @@
from .dataset import Dataset, ImageFolder
from .dataset import Dataset, ImageFolder, dataset_root
from .mnist import MNIST
from .cifar import CIFAR10, CIFAR100
from .voc import VOC
from .sampler import *

View File

@ -0,0 +1,189 @@
import os
from jittor_utils.misc import download_and_extract_archive, check_integrity
from PIL import Image
import sys, pickle
import numpy as np
from jittor.dataset import Dataset, dataset_root
class CIFAR10(Dataset):
"""`CIFAR10 <https://www.cs.toronto.edu/~kriz/cifar.html>`_ Dataset.
Args:
root (string): Root directory of dataset where directory
``cifar-10-batches-py`` exists or will be saved to if download is set to True.
train (bool, optional): If True, creates dataset from training set, otherwise
creates from test set.
transform (callable, optional): A function/transform that takes in an PIL image
and returns a transformed version. E.g, ``transforms.RandomCrop``
target_transform (callable, optional): A function/transform that takes in the
target and transforms it.
download (bool, optional): If true, downloads the dataset from the internet and
puts it in root directory. If dataset is already downloaded, it is not
downloaded again.
Example::
from jittor.dataset.cifar import CIFAR10
a = CIFAR10()
a.set_attrs(batch_size=16)
for imgs, labels in a:
print(imgs.shape, labels.shape)
break
"""
base_folder = 'cifar-10-batches-py'
url = "https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz"
filename = "cifar-10-python.tar.gz"
tgz_md5 = 'c58f30108f718f92721af3b95e74349a'
train_list = [
['data_batch_1', 'c99cafc152244af753f735de768cd75f'],
['data_batch_2', 'd4bba439e000b95fd0a9bffe97cbabec'],
['data_batch_3', '54ebc095f3ab1f0389bbae665268c751'],
['data_batch_4', '634d18415352ddfa80567beed471001a'],
['data_batch_5', '482c414d41f54cd18b22e5b47cb7c3cb'],
]
test_list = [
['test_batch', '40351d587109b95175f43aff81a1287e'],
]
meta = {
'filename': 'batches.meta',
'key': 'label_names',
'md5': '5ff9c542aee3614f3951f8cda6e48888',
}
def __init__(self, root=dataset_root+"/cifar_data/", train=True, transform=None, target_transform=None,
download=True):
super(CIFAR10, self).__init__()
self.root = root
self.transform=transform
self.target_transform=target_transform
self.train = train # training set or test set
if download:
self.download()
if not self._check_integrity():
raise RuntimeError('Dataset not found or corrupted.' +
' You can use download=True to download it')
if self.train:
downloaded_list = self.train_list
else:
downloaded_list = self.test_list
self.data = []
self.targets = []
# now load the picked numpy arrays
for file_name, checksum in downloaded_list:
file_path = os.path.join(self.root, self.base_folder, file_name)
with open(file_path, 'rb') as f:
if sys.version_info[0] == 2:
entry = pickle.load(f)
else:
entry = pickle.load(f, encoding='latin1')
self.data.append(entry['data'])
if 'labels' in entry:
self.targets.extend(entry['labels'])
else:
self.targets.extend(entry['fine_labels'])
self.data = np.vstack(self.data).reshape(-1, 3, 32, 32)
self.data = self.data.transpose((0, 2, 3, 1)) # convert to HWC
self._load_meta()
def _load_meta(self):
path = os.path.join(self.root, self.base_folder, self.meta['filename'])
if not check_integrity(path, self.meta['md5']):
raise RuntimeError('Dataset metadata file not found or corrupted.' +
' You can use download=True to download it')
with open(path, 'rb') as infile:
if sys.version_info[0] == 2:
data = pickle.load(infile)
else:
data = pickle.load(infile, encoding='latin1')
self.classes = data[self.meta['key']]
self.class_to_idx = {_class: i for i, _class in enumerate(self.classes)}
def __getitem__(self, index):
"""
Args:
index (int): Index
Returns:
tuple: (image, target) where target is index of the target class.
"""
img, target = self.data[index], self.targets[index]
# doing this so that it is consistent with all other datasets
# to return a PIL Image
img = Image.fromarray(img)
if self.transform is not None:
img = self.transform(img)
if self.target_transform is not None:
target = self.target_transform(target)
return img, target
def __len__(self):
return len(self.data)
def _check_integrity(self):
root = self.root
for fentry in (self.train_list + self.test_list):
filename, md5 = fentry[0], fentry[1]
fpath = os.path.join(root, self.base_folder, filename)
if not check_integrity(fpath, md5):
return False
return True
def download(self):
if self._check_integrity():
print('Files already downloaded and verified')
return
download_and_extract_archive(self.url, self.root, filename=self.filename, md5=self.tgz_md5)
def extra_repr(self):
return "Split: {}".format("Train" if self.train is True else "Test")
class CIFAR100(CIFAR10):
"""`CIFAR100 <https://www.cs.toronto.edu/~kriz/cifar.html>`_ Dataset.
This is a subclass of the `CIFAR10` Dataset.
Example::
from jittor.dataset.cifar import CIFAR100
a = CIFAR100()
a.set_attrs(batch_size=16)
for imgs, labels in a:
print(imgs.shape, labels.shape)
break
"""
base_folder = 'cifar-100-python'
url = "https://www.cs.toronto.edu/~kriz/cifar-100-python.tar.gz"
filename = "cifar-100-python.tar.gz"
tgz_md5 = 'eb9058c3a382ffc7106e4002c42a8d85'
train_list = [
['train', '16019d7e3df5f24257cddd939b257f8d'],
]
test_list = [
['test', 'f0ef6b0ae62326f3e7ffdfab6717acfc'],
]
meta = {
'filename': 'meta',
'key': 'fine_label_names',
'md5': '7973b15100ade9c7d40fb424638fde48',
}

View File

@ -29,18 +29,7 @@ kernel(in0->num/in0->shape[in0->shape.size()-1], 0, in0_p, out0_p, in0->shape[in
class OneHotCategorical:
def __init__(self, probs=None, logits=None):
assert not (probs is None and logits is None)
if probs is None:
# cannot align to pytorch
probs = jt.sigmoid(logits)
elif logits is None:
logits = jt.log(probs)
with jt.no_grad():
self.probs = probs / probs.sum(-1, True)
self.cum_probs = simple_presum(self.probs)
self.cum_probs_l = self.cum_probs[..., :-1]
self.cum_probs_r = self.cum_probs[..., 1:]
self.logits = logits
Categorical.__init__(self, probs, logits)
def sample(self, sample_shape=[]):
shape = sample_shape + self.probs.shape[:-1] + (1,)
@ -48,17 +37,12 @@ class OneHotCategorical:
one_hot = jt.logical_and(self.cum_probs_l < rand, rand <= self.cum_probs_r).float()
return one_hot
def log_prob(self,x):
if len(x.shape) == 1:
x = x.unsqueeze(0)
logits = self.logits.broadcast(x.shape)
indices = jt.argmax(x, dim=-1)[0]
return logits.gather(1, indices.unsqueeze(-1)).reshape(-1)
def log_prob(self, x):
x = jt.argmax(x, dim=-1)[0]
return Categorical.log_prob(self, x)
def entropy(self):
min_real = -(math.pow(2,23)-1) / math.pow(2,22) * math.pow(2,127)
logits = jt.clamp(self.logits,min_v=min_real)
p_log_p = logits * self.probs
p_log_p = self.logits * self.probs
return -p_log_p.sum(-1)
@ -68,29 +52,32 @@ class Categorical:
if probs is None:
# cannot align to pytorch
probs = jt.sigmoid(logits)
elif logits is None:
logits = jt.log(probs)
probs = probs / probs.sum(-1, True)
if logits is None:
logits = jt.safe_log(probs)
with jt.no_grad():
self.probs = probs / probs.sum(-1, True)
self.probs = probs
self.logits = logits
self.cum_probs = simple_presum(probs)
self.cum_probs = simple_presum(self.probs)
self.cum_probs_l = self.cum_probs[..., :-1]
self.cum_probs_r = self.cum_probs[..., 1:]
def sample(self, sample_shape=[]):
def sample(self, sample_shape=()):
shape = sample_shape + self.probs.shape[:-1] + (1,)
rand = jt.rand(shape)
one_hot = jt.logical_and(self.cum_probs_l < rand, rand <= self.cum_probs_r)
index = one_hot.index(one_hot.ndim-1)
index = one_hot.index(one_hot.ndim - 1)
return (one_hot * index).sum(-1)
def log_prob(self, x):
return jt.log(self.probs)[0,x]
a = self.probs.ndim
b = x.ndim
indexes = tuple( f'i{i}' for i in range(b-a+1, b) )
indexes = indexes + (x,)
return jt.safe_log(self.probs).getitem(indexes)
def entropy(self):
min_real = -(math.pow(2,23)-1) / math.pow(2,22) * math.pow(2,127)
logits = jt.clamp(self.logits,min_v=min_real)
p_log_p = logits * self.probs
p_log_p = self.logits * self.probs
return -p_log_p.sum(-1)
@ -104,11 +91,11 @@ class Normal:
def log_prob(self, x):
var = self.sigma**2
log_scale = jt.log(self.sigma)
log_scale = jt.safe_log(self.sigma)
return -((x-self.mu)**2) / (2*var) - log_scale-np.log(np.sqrt(2*np.pi))
def entropy(self):
return 0.5+0.5*np.log(2*np.pi)+jt.log(self.sigma)
return 0.5+0.5*np.log(2*np.pi)+jt.safe_log(self.sigma)
class Uniform:
@ -123,10 +110,10 @@ class Uniform:
def log_prob(self,x):
if x < self.low or x >= self.high:
return math.inf
return -jt.log(self.high - self.low)
return -jt.safe_log(self.high - self.low)
def entropy(self):
return jt.log(self.high - self.low)
return jt.safe_log(self.high - self.low)
class Geometric:
@ -138,15 +125,14 @@ class Geometric:
self.logits = logits
elif logits is None:
self.prob = p
self.logits = -jt.log(1. / p - 1)
self.logits = -jt.safe_log(1. / p - 1)
def sample(self, sample_shape):
tiny = jt.info(self.probs.dtype).tiny
u = jt.clamp(jt.rand(sample_shape),min_v=tiny)
return (jt.log(u) / (jt.log(-self.probs+1))).floor()
u = jt.rand(sample_shape)
return (jt.safe_log(u) / (jt.safe_log(-self.probs+1))).floor()
def log_prob(self, x):
return x*jt.log(-self.prob+1)+jt.log(self.prob)
return x*jt.safe_log(-self.prob+1)+jt.safe_log(self.prob)
def entropy(self):
return binary_cross_entropy_with_logits(jt.array(self.logits),jt.array(self.prob)) / self.prob
@ -157,16 +143,14 @@ def kl_divergence(cur_dist, old_dist):
if isinstance(cur_dist, Normal):
vr = (cur_dist.sigma / old_dist.sigma)**2
t1 = ((cur_dist.mu - old_dist.mu) / old_dist.sigma)**2
return 0.5*(vr+t1-1-jt.log(vr))
return 0.5*(vr+t1-1-jt.safe_log(vr))
if isinstance(cur_dist, Categorical) or isinstance(cur_dist,OneHotCategorical):
t = cur_dist.probs * (cur_dist.logits-old_dist.logits)
t[jt.array((old_dist.probs == 0))] = math.inf
t[jt.array((cur_dist.probs == 0))] = 0
return t.sum(-1)
if isinstance(cur_dist, Uniform):
res = jt.log((old_dist.high - old_dist.low) / (cur_dist.high - cur_dist.low))
res = jt.safe_log((old_dist.high - old_dist.low) / (cur_dist.high - cur_dist.low))
if old_dist.low > cur_dist.low or old_dist.high < cur_dist.high:
res = math.inf
return res
if isinstance(cur_dist, Geometric):
return -cur_dist.entropy() - jt.log(-old_dist.prob+1) / cur_dist.prob - old_dist.logits
return -cur_dist.entropy() - jt.safe_log(-old_dist.prob+1) / cur_dist.prob - old_dist.logits

View File

@ -634,23 +634,6 @@ def kthvalue(input, k, dim=None, keepdim=False):
jt.Var.kthvalue = kthvalue
def gather(x,dim,index):
if dim<0:
dim+=index.ndim
x_shape = list(x.shape )
i_shape = list(index.shape)
assert i_shape[dim]>0
assert x.ndim == index.ndim
i_shape[dim]=x_shape[dim]
assert i_shape == x_shape
ins = []
for i in range(index.ndim):
ins.append(jt.index(index.shape,dim=i))
ins[dim]=index
return x.reindex(ins)
jt.Var.gather = gather
def _prod(x,dim=0):
x = jt.log(x)
x = x.sum(dim=dim)
@ -1255,3 +1238,7 @@ Examples::
return x.reindex(x.shape, ids)
jt.Var.roll = roll
def safe_log(x):
return jt.safe_clip(x, 1e-30, 1e30).log()
jt.Var.safe_log = safe_log

View File

@ -28,12 +28,12 @@ def matmul_transpose(a, b):
'''
returns a * b^T
'''
assert len(a.shape) >= 2 and len(b.shape) == 2
assert a.shape[-1] == b.shape[-1], (a.shape, b.shape)
if len(a.shape)>2:
if len(a.shape) != 2:
aa = a.reshape((-1, a.shape[-1]))
cc = matmul_transpose(aa, b)
return cc.reshape(a.shape[:-1]+(-1,))
assert len(a.shape) == 2 and len(b.shape) == 2
shape = list(a.shape)[:-1] + list(b.shape)
a = a.broadcast(shape, [len(shape)-2])

View File

@ -486,9 +486,11 @@ void Executor::run_sync(vector<Var*> vars, bool device_sync) {
if (use_cuda)
checkCudaErrors(cudaDeviceSynchronize());
#endif
for (Var* var : op->outputs())
check_nan(var);
}
#ifdef JT_CHECK_NAN
for (Var* var : op->outputs())
check_nan(var);
#endif
LOGvvv << "Finished Op(" >> op->name() << rid >>
"/" >> queue.size() >> ") output:" << op->outputs();
if (is_fused_op) {

View File

@ -104,6 +104,8 @@ int OpCompiler::total_member_count() {
// array need a extra local var
if (op->ops[i]->name()==string("array"))
member_count += 1;
if (op->ops[i]->name()==string("safe_clip"))
member_count += 2;
member_count += v.size();
i += 1;
}
@ -826,11 +828,15 @@ string OpCompiler::__get_fused_src(
const unordered_set<string> members = {
"x", "y", "z", "cond", "output", "extras"
};
const unordered_set<string> scalar_members = {
"left", "right"
};
const unordered_set<string> unchanged = {
"for", "const", "auto", "get_random_engine",
"int", "float", "bool", "CHECK", "STRINGIZE",
"void", "__restrict__", "if", "true", "false",
"Op", "Var", "Node", "itof", "assert", "ASSERT"
"Op", "Var", "Node", "itof", "assert", "ASSERT",
"float64"
};
auto not_change = [&](const string& s) -> bool {
if (unchanged.count(s)) return true;
@ -941,7 +947,8 @@ string OpCompiler::__get_fused_src(
while (l<src.size() && isvar(src[l])) l++;
auto var = src.substr(j, l-j);
if (var[0] == ':' || isdigit(var[0]) || not_change(var) || src[j-1]=='.' || src[j-1]=='>') {} else
if (members.count(var)) {
if (members.count(var) || scalar_members.count(var)) {
bool is_member = members.count(var);
string arg_name = "op" + S(oi) + "_" + var;
if (l<src.size() && src[l]=='[') {
// handle extras[...]
@ -964,7 +971,8 @@ string OpCompiler::__get_fused_src(
" = (("+name3+"Op*)(ops[" + S(oi) + "]))->" + var;
fused_kernel_args += ";\n";
kernel_args.insert(arg_name);
op_members[oi].push_back(arg_name);
if (is_member)
op_members[oi].push_back(arg_name);
}
fused_kernel += arg_name;
j = l-1;

View File

@ -0,0 +1,47 @@
// ***************************************************************
// Copyright (c) 2021 Jittor. All Rights Reserved.
// Maintainers: Dun Liang <randonlang@gmail.com>.
// This file is subject to the terms and conditions defined in
// file 'LICENSE.txt', which is part of this source code package.
// ***************************************************************
#include <cmath>
#include "var.h"
#include "ops/safe_clip_op.h"
#include "ops/op_register.h"
namespace jittor {
#ifndef JIT
SafeClipOp::SafeClipOp(Var* x, float64 left, float64 right) : x(x), left(left), right(right) {
flags.set(NodeFlags::_cpu);
flags.set(NodeFlags::_cuda);
set_type(OpType::element);
y = create_output(nullptr, x->dtype());
}
VarPtr SafeClipOp::grad(Var* out, Var* dout, Var* v, int v_index) {
return dout;
}
void SafeClipOp::infer_shape() {
y->set_shape(x->shape);
}
void SafeClipOp::jit_prepare(JK& jk) {
jk << _CS("[Tx:") << x->dtype() <<']';
}
#else // JIT
void SafeClipOp::jit_run() {
auto* __restrict__ xp = x->ptr<Tx>();
Tx left_value = (Tx)std::max((float64)std::numeric_limits<Tx>::lowest(), left);
Tx right_value = (Tx)std::min((float64)std::numeric_limits<Tx>::max(), right);
auto* __restrict__ yp = y->ptr<Tx>();
index_t num = y->num;
for (index_t i=0; i<num; i++)
yp[i] = xp[i] < left_value ? left_value : (xp[i] > right_value ? right_value : xp[i]);
}
#endif // JIT
} // jittor

View File

@ -0,0 +1,33 @@
// ***************************************************************
// Copyright (c) 2021 Jittor. All Rights Reserved.
// Maintainers: Dun Liang <randonlang@gmail.com>.
// This file is subject to the terms and conditions defined in
// file 'LICENSE.txt', which is part of this source code package.
// ***************************************************************
#pragma once
#include "op.h"
namespace jittor {
struct SafeClipOp : Op {
Var* x, * y;
float64 left, right;
/** Safe clip value to a range, and keep
the gradient pass thought.
* [in] x: input value
* [in] left: float64 clip min value.
* [in] right: float64 clip max value.
*/
// @pybind(safe_clip)
SafeClipOp(Var* x, float64 left, float64 right);
const char* name() const override { return "safe_clip"; }
VarPtr grad(Var* out, Var* dout, Var* v, int v_index) override;
void infer_shape() override;
DECLARE_jit_run;
};
} // jittor

View File

@ -67,6 +67,10 @@ void LoopToFuncPass::run() {
args.push_back(d.get());
continue;
}
if (endswith(d->attrs["lvalue"], "_value")) {
args.push_back(d.get());
continue;
}
}
}
func->push_back(d->clone());

View File

@ -8,6 +8,7 @@
#include <streambuf>
#include "misc/hash.h"
#include "utils/cache_compile.h"
#include "utils/str_utils.h"
namespace jittor {
namespace jit_compiler {
@ -137,7 +138,7 @@ size_t skip_comments(const string& src, size_t i) {
return i;
}
void process(string src, vector<string>& input_names) {
void process(string src, vector<string>& input_names, string& cmd) {
for (size_t i=0; i<src.size(); i++) {
i = skip_comments(src, i);
if (i>=src.size()) break;
@ -159,6 +160,20 @@ void process(string src, vector<string>& input_names) {
input_names.push_back(inc);
}
}
if (l-k>2 && src[k] == 'J' && src[k+1] == 'T' && j-i==6 && src.substr(i,j-i) == "#ifdef") {
auto inc = src.substr(k, l-k);
auto env = getenv(inc.c_str());
if (env && string(env)!="0") {
string dflag = " -D"+inc+"="+string(env)+" -o ";
if (cmd.find(dflag) == string::npos) {
// -D flags should insert before -o flag
auto cmds = split(cmd, " -o ", 2);
if (cmds.size() == 2) {
cmd = cmds[0] + dflag + cmds[1];
}
}
}
}
i=l;
}
}
@ -173,12 +188,6 @@ bool cache_compile(const string& cmd, const string& cache_path, const string& ji
bool ran = false;
output_cache_key = read_all(output_name+".key");
string cd_cmd = cache_path.size() ? "cd " + cache_path + " && " + cmd : cmd;
if (output_cache_key.size() == 0) {
LOGvv << "Cache key of" << output_name << "not found.";
LOGvvv << "Run cmd:" << cmd;
system_with_check(cd_cmd.c_str());
ran = true;
}
string cache_key = cmd;
cache_key += "\n";
unordered_set<string> processed;
@ -194,7 +203,7 @@ bool cache_compile(const string& cmd, const string& cache_path, const string& ji
ASSERT(src.size()) << "Source read failed:" << input_names[i];
auto hash = S(hash64(src));
vector<string> new_names;
process(src, new_names);
process(src, new_names, cd_cmd);
for (auto& name : new_names) {
string full_name;
if (name.substr(0, 4) == "jit/" || name.substr(0, 4) == "gen/")
@ -224,9 +233,15 @@ bool cache_compile(const string& cmd, const string& cache_path, const string& ji
cache_key += hash;
cache_key += "\n";
}
if (output_cache_key.size() == 0) {
LOGvv << "Cache key of" << output_name << "not found.";
LOGvvv << "Run cmd:" << cd_cmd;
system_with_check(cd_cmd.c_str());
ran = true;
}
if (output_cache_key.size() != 0 && output_cache_key != cache_key) {
LOGvv << "Cache key of" << output_name << "changed.";
LOGvvv << "Run cmd:" << cmd;
LOGvvv << "Run cmd:" << cd_cmd;
system_with_check(cd_cmd.c_str());
ran = true;
}
@ -298,7 +313,8 @@ void test_find_nams_error(string cmd) {
void test_process(string src, vector<string> files) {
vector<string> ifiles;
jittor::jit_compiler::process(src, ifiles);
string cmd;
jittor::jit_compiler::process(src, ifiles, cmd);
CHECK(files.size() == ifiles.size());
for (size_t i=0; i<files.size(); i++)
CHECKop(files[i],==,ifiles[i]);

View File

@ -25,15 +25,15 @@ bool endswith(const string& a, const string& b) {
vector<string> split(const string& s, const string& sep, int max_split) {
vector<string> ret;
int pos = -1, pos_next;
int pos = 0, pos_next;
while (1) {
pos_next = s.find(sep, pos+1);
pos_next = s.find(sep, pos);
if (pos_next == (int)string::npos || (int)ret.size() == max_split-1) {
ret.push_back(s.substr(pos+sep.size()));
ret.push_back(s.substr(pos));
return ret;
}
ret.push_back(s.substr(pos+sep.size(), pos_next-pos-sep.size()));
pos = pos_next;
ret.push_back(s.substr(pos, pos_next-pos));
pos = pos_next + sep.size();
}
ASSERT(max_split==0);
return ret;

View File

@ -161,6 +161,16 @@ class TestDatasetSeed(unittest.TestCase):
for i in range(len(d)):
for j in range(i+1, len(d)):
assert not np.allclose(dd[i], dd[j])
def test_cifar(self):
from jittor.dataset.cifar import CIFAR10
a = CIFAR10()
a.set_attrs(batch_size=16)
for imgs, labels in a:
print(imgs.shape, labels.shape)
assert imgs.shape == [16,32,32,3,]
assert labels.shape == [16,]
break
if __name__ == "__main__":

View File

@ -31,14 +31,14 @@ class TestOneHot(unittest.TestCase):
probs,probs2 = np.random.uniform(0,1,(10)), np.random.uniform(0,1,(10))
probs,probs2 = probs / probs.sum(),probs2 / probs2.sum()
import torch
jc, jc2 = jd.OneHotCategorical(jt.array(probs).reshape(1,-1)),jd.OneHotCategorical(jt.array(probs2).reshape(1,-1))
jc, jc2 = jd.OneHotCategorical(jt.array(probs)),jd.OneHotCategorical(jt.array(probs2))
tc, tc2 = torch.distributions.OneHotCategorical(torch.tensor(probs)),torch.distributions.OneHotCategorical(torch.tensor(probs2))
assert np.allclose(jc.entropy().data,tc.entropy().numpy())
x = np.zeros((4,10))
for _ in range(4):
nx = np.random.randint(0,9)
x[_,nx] = 1
assert np.allclose(jc.log_prob(jt.array(x)),tc.log_prob(torch.tensor(x)))
np.testing.assert_allclose(jc.log_prob(jt.array(x)),tc.log_prob(torch.tensor(x)), atol=1e-5)
assert np.allclose(jd.kl_divergence(jc,jc2),torch.distributions.kl_divergence(tc,tc2))
def test_cate(self):
@ -67,17 +67,55 @@ class TestOneHot(unittest.TestCase):
tn2 = torch.distributions.Normal(mu2,sigma2)
assert np.allclose(jd.kl_divergence(jn,jn2).data,torch.distributions.kl_divergence(tn,tn2).numpy())
def test_categorical(self):
def test_categorical1(self):
import torch
for _ in range(4):
probs,probs2 = np.random.uniform(0,1,(10)), np.random.uniform(0,1,(10))
probs,probs2 = probs / probs.sum(),probs2 / probs2.sum()
jc, jc2 = jd.Categorical(jt.array(probs).reshape(1,-1)),jd.Categorical(jt.array(probs2).reshape(1,-1))
jc, jc2 = jd.Categorical(jt.array(probs)),jd.Categorical(jt.array(probs2))
tc, tc2 = torch.distributions.Categorical(torch.tensor(probs)),torch.distributions.Categorical(torch.tensor(probs2))
assert np.allclose(jc.entropy().data, tc.entropy().numpy()), (jc.entropy().data, tc.entropy().numpy())
x = np.random.randint(0,10,(4))
assert np.allclose(jc.log_prob(x), tc.log_prob(torch.tensor(x)))
np.testing.assert_allclose(jc.log_prob(x), tc.log_prob(torch.tensor(x)), atol=1e-5)
assert np.allclose(jd.kl_divergence(jc,jc2),torch.distributions.kl_divergence(tc,tc2))
def test_categorical2(self):
def check(prob_shape, sample_shape):
import torch
for _ in range(4):
probs,probs2 = np.random.uniform(0,1,prob_shape), np.random.uniform(0,1, prob_shape)
jc, jc2 = jd.Categorical(jt.array(probs)),jd.Categorical(jt.array(probs2))
tc, tc2 = torch.distributions.Categorical(torch.tensor(probs)),torch.distributions.Categorical(torch.tensor(probs2))
assert np.allclose(jc.entropy().data, tc.entropy().numpy()), (jc.entropy().data, tc.entropy().numpy())
x1 = jc.sample(sample_shape)
x2 = tc.sample(sample_shape)
assert tuple(x1.shape) == tuple(x2.shape)
x = np.random.randint(0,prob_shape[-1], tuple(x1.shape))
np.testing.assert_allclose(jc.log_prob(x), tc.log_prob(torch.tensor(x)), atol=1e-5)
np.testing.assert_allclose(jd.kl_divergence(jc,jc2), torch.distributions.kl_divergence(tc,tc2), atol=1e-5)
check((10,), (4,))
check((2,3), (4,))
check((3,4,5,6), (2,))
def test_one_hot_categorical2(self):
def check(prob_shape, sample_shape):
import torch
for _ in range(4):
probs,probs2 = np.random.uniform(0,1,prob_shape), np.random.uniform(0,1, prob_shape)
jc, jc2 = jd.OneHotCategorical(jt.array(probs)),jd.OneHotCategorical(jt.array(probs2))
tc, tc2 = torch.distributions.OneHotCategorical(torch.tensor(probs)),torch.distributions.OneHotCategorical(torch.tensor(probs2))
assert np.allclose(jc.entropy().data, tc.entropy().numpy()), (jc.entropy().data, tc.entropy().numpy())
x1 = jc.sample(sample_shape)
x2 = tc.sample(sample_shape)
assert tuple(x1.shape) == tuple(x2.shape)
x = np.random.randint(0,prob_shape[-1], tuple(x1.shape))
np.testing.assert_allclose(jc.log_prob(x), tc.log_prob(torch.tensor(x)), atol=1e-5)
np.testing.assert_allclose(jd.kl_divergence(jc,jc2), torch.distributions.kl_divergence(tc,tc2), atol=1e-5)
check((10,), (4,))
check((2,3), (4,))
check((3,4,5,6), (2,))
def test_uniform(self):
import torch
@ -98,11 +136,11 @@ class TestOneHot(unittest.TestCase):
prob, prob2 = np.random.uniform(0,1), np.random.uniform(0,1)
jg, jg2 = jd.Geometric(prob),jd.Geometric(prob2)
tg, tg2 = torch.distributions.Geometric(prob),torch.distributions.Geometric(prob2)
assert np.allclose(jg.entropy().data,tg.entropy().numpy())
np.testing.assert_allclose(jg.entropy().data,tg.entropy().numpy(), atol=1e-4)
x = np.random.randint(1,10)
assert np.allclose(jg.log_prob(x),tg.log_prob(torch.tensor(x)))
np.testing.assert_allclose(jg.log_prob(x),tg.log_prob(torch.tensor(x)), atol=1e-4)
# print(jd.kl_divergence(jg,jg2),torch.distributions.kl_divergence(tg,tg2))
assert np.allclose(jd.kl_divergence(jg,jg2),torch.distributions.kl_divergence(tg,tg2))
np.testing.assert_allclose(jd.kl_divergence(jg,jg2),torch.distributions.kl_divergence(tg,tg2), atol=1e-4)
if __name__ == "__main__":
unittest.main()

View File

@ -347,5 +347,11 @@ class TestMatmul(unittest.TestCase):
def test_matmul_example2_cuda(self):
self.test_matmul_example2()
def test_linear1d(self):
linear = jt.nn.Linear(10,20)
a = jt.random((10,))
b = linear(a)
assert b.shape == (20,)
if __name__ == "__main__":
unittest.main()

View File

@ -257,14 +257,14 @@ class Tester(unittest.TestCase):
expect = input_data.transpose(2,0,1)
self.assertTrue(np.allclose(expect, output), f"{expect.shape}\n{output.shape}")
ndarray = np.random.randint(low=0, high=255, size=(height, width, channels)).astype(np.uint8)
ndarray = np.random.randint(low=0, high=255, size=(channels, height, width)).astype(np.uint8)
output = trans(ndarray)
expected_output = ndarray.transpose((2, 0, 1)) / 255.0
self.assertTrue(np.allclose(output, expected_output))
expected_output = ndarray / 255.0
np.testing.assert_allclose(output, expected_output)
ndarray = np.random.rand(height, width, channels).astype(np.float32)
ndarray = np.random.rand(channels, height, width).astype(np.float32)
output = trans(ndarray)
expected_output = ndarray.transpose((2, 0, 1))
expected_output = ndarray
self.assertTrue(np.allclose(output, expected_output))
# separate test for mode '1' PIL images

View File

@ -69,6 +69,13 @@ class TestUnaryOp(unittest.TestCase):
b1 = b.sigmoid().numpy()
assert np.isnan(b1).any() == False
def test_safe_clip(self):
a = jt.array([-1.0,0,0.4,1,2,3])
b = a.safe_clip(0.1, 0.5)
assert np.allclose(b.data, [0.1,0.1,0.4,0.5,0.5,0.5])
da = jt.grad(b, a)
assert (da.data == 1).all()
class TestUnaryOpCuda(TestUnaryOp, test_cuda(2)):
pass

View File

@ -389,7 +389,7 @@ class CenterCrop:
def to_tensor(pic):
"""
Function for turning Image.Image to np.array.
Function for turning Image.Image to np.array with CHW format.
Args::
@ -414,14 +414,13 @@ def to_tensor(pic):
if _is_numpy(pic):
# handle numpy array
if pic.ndim == 2:
pic = pic[:, :, None]
pic = pic[None, :, :]
img = pic.transpose((2, 0, 1))
# backward compatibility
if img.dtype == 'uint8':
return np.float32(img) * np.float32(1/255.0)
if pic.dtype == 'uint8':
return np.float32(pic) * np.float32(1/255.0)
else:
return img
return pic
# handle PIL Image
if pic.mode == 'I':
@ -499,7 +498,7 @@ def _to_jittor_array(pic):
def to_pil_image(pic, mode=None):
"""Convert a tensor or an ndarray to PIL Image.
Args:
pic (Tensor or numpy.ndarray): Image to be converted to PIL Image.
pic (Tensor or numpy.ndarray): Image(HWC format) to be converted to PIL Image.
mode (`PIL.Image mode`_): color space and pixel depth of input data (optional).
.. _PIL.Image mode: https://pillow.readthedocs.io/en/latest/handbook/concepts.html#concept-modes
Returns:
@ -694,7 +693,7 @@ class Gray:
transform = transform.Gray()
img_ = transform(img)
'''
def __init__(self, num_output_channels):
def __init__(self, num_output_channels=1):
self.num_output_channels = num_output_channels
def __call__(self, img:Image.Image):

View File

@ -12,6 +12,9 @@ import hashlib
import urllib.request
from tqdm import tqdm
from jittor_utils import lock
import gzip
import tarfile
import zipfile
def ensure_dir(dir_path):
if not os.path.isdir(dir_path):
@ -69,3 +72,77 @@ def calculate_md5(file_path, chunk_size=1024 * 1024):
def check_md5(file_path, md5, **kwargs):
return md5 == calculate_md5(file_path, **kwargs)
def check_integrity(fpath, md5=None):
if not os.path.isfile(fpath):
return False
if md5 is None:
return True
return check_md5(fpath, md5)
def _is_tarxz(filename):
return filename.endswith(".tar.xz")
def _is_tar(filename):
return filename.endswith(".tar")
def _is_targz(filename):
return filename.endswith(".tar.gz")
def _is_tgz(filename):
return filename.endswith(".tgz")
def _is_gzip(filename):
return filename.endswith(".gz") and not filename.endswith(".tar.gz")
def _is_zip(filename):
return filename.endswith(".zip")
def extract_archive(from_path, to_path=None, remove_finished=False):
if to_path is None:
to_path = os.path.dirname(from_path)
if _is_tar(from_path):
with tarfile.open(from_path, 'r') as tar:
tar.extractall(path=to_path)
elif _is_targz(from_path) or _is_tgz(from_path):
with tarfile.open(from_path, 'r:gz') as tar:
tar.extractall(path=to_path)
elif _is_tarxz(from_path):
# .tar.xz archive only supported in Python 3.x
with tarfile.open(from_path, 'r:xz') as tar:
tar.extractall(path=to_path)
elif _is_gzip(from_path):
to_path = os.path.join(to_path, os.path.splitext(os.path.basename(from_path))[0])
with open(to_path, "wb") as out_f, gzip.GzipFile(from_path) as zip_f:
out_f.write(zip_f.read())
elif _is_zip(from_path):
with zipfile.ZipFile(from_path, 'r') as z:
z.extractall(to_path)
else:
raise ValueError("Extraction of {} not supported".format(from_path))
if remove_finished:
os.remove(from_path)
def download_and_extract_archive(url, download_root, extract_root=None, filename=None,
md5=None, remove_finished=False):
download_root = os.path.expanduser(download_root)
if extract_root is None:
extract_root = download_root
if not filename:
filename = os.path.basename(url)
download_url_to_local(url, filename, download_root, md5)
archive = os.path.join(download_root, filename)
print("Extracting {} to {}".format(archive, extract_root))
extract_archive(archive, extract_root, remove_finished)