Merge branch 'nerf' of https://github.com/Jittor/jittor into nerf

Dun Liang 2021-01-17 13:53:58 +08:00
commit 10359f02fa
10 changed files with 201 additions and 5 deletions

View File

@@ -300,12 +300,12 @@ def std(x):
    return out
Var.std = std
def norm(x, k, dim):
def norm(x, k, dim, keepdim=False):
    assert k==1 or k==2
    if k==1:
        return x.abs().sum(dim)
        return x.abs().sum(dim, keepdim)
    if k==2:
        return (x.sqr()).sum(dim).maximum(1e-6).sqrt()
        return (x.sqr()).sum(dim, keepdim).maximum(1e-6).sqrt()
Var.norm = norm
origin_reshape = reshape
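
The new keepdim flag follows the NumPy/PyTorch convention: the reduced axis is kept with size 1, so the result still broadcasts against the input. A minimal usage sketch (illustrative, not part of the diff):

import jittor as jt

x = jt.random((3, 4))
n  = x.norm(2, 1)                  # reduced: shape (3,)
nk = x.norm(2, 1, keepdim=True)    # kept:    shape (3, 1)
x_unit = x / nk                    # broadcasts against x thanks to keepdim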

View File

@@ -11,6 +11,12 @@ import jittor as jt
import numpy as np
import math
def eye(shape, dtype):
    return jt.array(np.identity(shape[0])).unary(dtype)
def eye_(var):
    var.assign(eye(var.shape, var.dtype))
def constant(shape, dtype, value=0.0):
    return jt.array(value).unary(dtype).broadcast(shape)
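
These appear to extend jittor's init helpers: eye builds an identity Var, eye_ writes one into an existing Var in place, and constant broadcasts a scalar to a shape. A short sketch of intended usage (illustrative only; assumes a square weight so the identity fits):

import jittor as jt
from jittor import init, nn

linear = nn.Linear(4, 4)
init.eye_(linear.weight)                       # in-place identity init
bias0  = init.constant((4,), "float32", 0.0)   # constant-filled Var
ident  = init.eye((4, 4), "float32")           # standalone identity matrix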

View File

@@ -18,10 +18,12 @@ if has_cupy:
import jittor as jt
import os
import ctypes
cupy_device = cp.cuda.Device(jt.mpi.local_rank())
cupy_device.__enter__()
def cvt(a):
    a_pointer, read_only_flag = a.__array_interface__['data']
    aptr=cp.cuda.MemoryPointer(cp.cuda.memory.UnownedMemory(a_pointer,a.size*a.itemsize,a,0),0)
    aptr=cp.cuda.MemoryPointer(cp.cuda.memory.UnownedMemory(a_pointer,a.size*a.itemsize,a, jt.mpi.local_rank()),0)
    a = cp.ndarray(a.shape,a.dtype,aptr)
    return a
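
The changed UnownedMemory argument matters for multi-GPU MPI runs: the wrapped pointer must be tagged with the device that actually owns it, which is the process's local rank rather than always GPU 0. A standalone CuPy sketch of the same zero-copy wrap (hypothetical helper, not part of the diff):

import cupy as cp

def wrap_device_pointer(ptr, nbytes, shape, dtype, device_id, owner=None):
    # Wrap an existing CUDA allocation as a CuPy ndarray without copying.
    # device_id must name the GPU that owns ptr; hard-coding 0 (the old
    # behaviour) breaks ranks that run on other GPUs.
    mem = cp.cuda.memory.UnownedMemory(ptr, nbytes, owner, device_id)
    return cp.ndarray(shape, dtype, cp.cuda.MemoryPointer(mem, 0))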

View File

@@ -839,6 +839,7 @@ def get_max_memory_treemap(build_by=0, do_print=True):
    if (do_print):
        print(out)
    return tree, out
def python_pass_warper(mod_func, args, kw):
    import importlib
    mod, func = mod_func.rsplit(".", 1)
@@ -934,6 +935,84 @@ inline static void {func_name}({",".join(pargs+oargs)}) {{
"""
    return new_src
def cumprod(a, dim):
    class CumprodFunc(jt.Function):
        def forward_code(self, np, data):
            a = data["inputs"][0]
            b = data["outputs"][0]
            out = np.cumprod(a, self.dim)
            np.copyto(b, out)
        def backward_code(self, np, data):
            a, b, dout = data["inputs"]
            out = data["outputs"][0]
            sdim = a.shape[self.dim]
            dim = (len(a.shape)+1)*[1]
            dim[self.dim+1] = sdim
            res = np.tile(np.expand_dims(b, self.dim+1), dim)
            dout = np.tile(np.expand_dims(dout, self.dim+1), dim)
            dim[self.dim]=sdim
            dim[self.dim+1]=1
            a = np.tile(np.expand_dims(a, self.dim), dim)
            res = res/a
            mask = np.tril(np.ones((sdim, sdim)))
            for i in range(self.dim):
                mask = np.expand_dims(mask, 0)
            for i in range(len(a.shape)-self.dim-2):
                mask = np.expand_dims(mask, -1)
            res = np.sum(mask*res*dout, self.dim)
            np.copyto(out, res)
        def execute(self, a, dim):
            self.save_vars = a
            self.dim = dim
            self.res = jt.numpy_code(
                a.shape,
                a.dtype,
                [a],
                self.forward_code,
            )
            return self.res
        def grad(self, grad_a):
            a = self.save_vars
            b = self.res
            return jt.numpy_code(
                a.shape,
                a.dtype,
                [a, b, grad_a],
                self.backward_code,
            )
    func = CumprodFunc()
    if dim<0:
        dim+=len(a.shape)
    return func(a, dim)
def linspace(start, end, steps):
    res = jt.index((steps,))[0]
    res = res*(end-start)/float(steps-1)+start
    return res
def randperm(n):
    # TODO: use jt.random
    idx = np.arange(n)
    return jt.array(np.random.permutation(idx))
def set_global_seed(seed):
    jt.set_seed(seed)
    np.random.seed(seed)
    try:
        import cupy
        cupy.random.seed(seed)
    except:
        pass
def searchsorted(sorted, values, right=False):
    """
    Find the indices from the innermost dimension of `sorted` for each `values`.
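
For reference, backward_code builds the Jacobian of the cumulative product explicitly: with b = cumprod(a) along dim, db_i/da_j = b_i / a_j for j <= i and 0 otherwise, and np.tril supplies exactly that j <= i mask before summing over i. A standalone NumPy check of the same identity (not part of the commit; last axis only for brevity):

import numpy as np

def cumprod_grad_ref(a, dout):
    # d(sum_i dout_i * b_i)/d(a_j) with b = cumprod(a) along the last axis,
    # using db_i/da_j = b_i / a_j for j <= i -- the identity the diff's
    # backward_code encodes with its tril mask.
    b = np.cumprod(a, -1)
    n = a.shape[-1]
    jac = b[..., :, None] / a[..., None, :] * np.tril(np.ones((n, n)))
    return (jac * dout[..., :, None]).sum(-2)

a = np.random.rand(5) + 0.1
eps = 1e-6
numeric = np.array([
    (np.cumprod(a + eps * np.eye(5)[j], -1).sum()
     - np.cumprod(a - eps * np.eye(5)[j], -1).sum()) / (2 * eps)
    for j in range(5)])
assert np.allclose(cumprod_grad_ref(a, np.ones(5)), numeric, atol=1e-4)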

View File

@@ -0,0 +1,51 @@
# ***************************************************************
# Copyright (c) 2020 Jittor. Authors:
# Guowei Yang <471184555@qq.com>
# Dun Liang <randonlang@gmail.com>.
# All Rights Reserved.
# This file is subject to the terms and conditions defined in
# file 'LICENSE.txt', which is part of this source code package.
# ***************************************************************
import unittest
import jittor as jt
import numpy as np
import ctypes
import sys
import torch
from torch.autograd import Variable

class TestCumprod(unittest.TestCase):
    def test_cumprod_cpu(self):
        jt.flags.use_cuda = 0
        for i in range(1,6):
            for j in range(i):
                x = np.random.rand(*((10,)*i))
                x_jt = jt.array(x)
                y_jt = jt.cumprod(x_jt, j).sqr()
                g_jt = jt.grad(y_jt.sum(), x_jt)
                x_tc = Variable(torch.from_numpy(x), requires_grad=True)
                y_tc = torch.cumprod(x_tc, j)**2
                y_tc.sum().backward()
                g_tc = x_tc.grad
                assert np.allclose(y_jt.numpy(), y_tc.data)
                assert np.allclose(g_jt.numpy(), g_tc.data)

    def test_cumprod_gpu(self):
        jt.flags.use_cuda = 1
        for i in range(1,6):
            for j in range(i):
                x = np.random.rand(*((10,)*i))
                x_jt = jt.array(x)
                y_jt = jt.cumprod(x_jt, j).sqr()
                g_jt = jt.grad(y_jt.sum(), x_jt)
                x_tc = Variable(torch.from_numpy(x), requires_grad=True)
                y_tc = torch.cumprod(x_tc, j)**2
                y_tc.sum().backward()
                g_tc = x_tc.grad
                assert np.allclose(y_jt.numpy(), y_tc.data)
                assert np.allclose(g_jt.numpy(), g_tc.data)

if __name__ == "__main__":
    unittest.main()

View File

@@ -0,0 +1,55 @@
# ***************************************************************
# Copyright (c) 2020 Jittor. Authors:
# Guowei Yang <471184555@qq.com>
# Dun Liang <randonlang@gmail.com>.
# All Rights Reserved.
# This file is subject to the terms and conditions defined in
# file 'LICENSE.txt', which is part of this source code package.
# ***************************************************************
import unittest
import jittor as jt
import numpy as np
import ctypes
import sys
import torch
from torch.autograd import Variable

class TestSearchsorted(unittest.TestCase):
    def test_searchsorted_cpu(self):
        jt.flags.use_cuda = 0
        for i in range(1,3):
            s = np.sort(np.random.rand(*((10,)*i)),-1)
            v = np.random.rand(*((10,)*i))
            s_jt = jt.array(s)
            v_jt = jt.array(v)
            s_tc = torch.from_numpy(s)
            v_tc = torch.from_numpy(v)
            y_jt = jt.searchsorted(s_jt, v_jt, right=True)
            y_tc = torch.searchsorted(s_tc, v_tc, right=True)
            assert np.allclose(y_jt.numpy(), y_tc.data)
            y_jt = jt.searchsorted(s_jt, v_jt, right=False)
            y_tc = torch.searchsorted(s_tc, v_tc, right=False)
            assert np.allclose(y_jt.numpy(), y_tc.data)

    def test_searchsorted_gpu(self):
        jt.flags.use_cuda = 1
        for i in range(1,3):
            s = np.sort(np.random.rand(*((10,)*i)),-1)
            v = np.random.rand(*((10,)*i))
            s_jt = jt.array(s)
            v_jt = jt.array(v)
            s_tc = torch.from_numpy(s)
            v_tc = torch.from_numpy(v)
            y_jt = jt.searchsorted(s_jt, v_jt, right=True)
            y_tc = torch.searchsorted(s_tc, v_tc, right=True)
            assert np.allclose(y_jt.numpy(), y_tc.data)
            y_jt = jt.searchsorted(s_jt, v_jt, right=False)
            y_tc = torch.searchsorted(s_tc, v_tc, right=False)
            assert np.allclose(y_jt.numpy(), y_tc.data)

if __name__ == "__main__":
    unittest.main()
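
The tests pin jt.searchsorted to torch.searchsorted semantics along the innermost dimension: right=False returns the leftmost insertion point, right=True the position after any equal entries. Assuming that behaviour, a 1-D call should also agree with NumPy (illustrative sketch, not from the commit):

import numpy as np
import jittor as jt

s = jt.array(np.array([1., 3., 5., 7.]))
v = jt.array(np.array([3., 6.]))

print(jt.searchsorted(s, v).numpy())              # [1 3], like np.searchsorted(..., side='left')
print(jt.searchsorted(s, v, right=True).numpy())  # [2 3], like np.searchsorted(..., side='right')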

View File

@@ -27,6 +27,7 @@
#include "parallel_compiler.h"
#include "memory_profiler.h"
#include "misc/nan_checker.h"
#include "memory_profiler.h"
namespace jittor {

View File

@@ -211,6 +211,7 @@ vector<VarPtr> grad(Var* loss, vector<Var*> targets) {
        if (id>=0)
            grad = move(grads[id]);
        if (!grad) {
            // TODO: better warning message
            LOGw << "grads[">>i>>"] '">> var->name>>"' doesn't have gradient. It will be set to zero:" << var;
            grad = make_number(0.f, var);
            assign_attrs(grad.ptr, var);
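
On the Python side this branch is reached when a target handed to jt.grad is not connected to the loss: the warning above fires and the returned gradient is simply an all-zero Var. A small sketch (assuming jt.grad behaviour matches this code path):

import jittor as jt

x = jt.random((3,))
unused = jt.random((3,))          # never participates in the loss
loss = (x * x).sum()

gx, gu = jt.grad(loss, [x, unused])
print(gu.numpy())                 # zeros, with the LOGw warning shown above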

View File

@@ -28,6 +28,7 @@ struct FloatOutput_ {
    string suffix;
    int p=4;
};
inline std::ostream& operator<<(std::ostream& os, const FloatOutput_& o) {
    int w = 8;
    os << std::setw(w-2-o.suffix.size());

View File

@@ -211,7 +211,7 @@ void SetitemOp::graph_optimize() {
void GetitemOp::graph_optimize() {
    // This optimize is still WIP
    // LOGir << "hello getitem graph_optimize";
    // setitem_grad_opt(this);
    setitem_grad_opt(this);
    (void)setitem_grad_opt;
    // (void)getitem_inplace;
    getitem_inplace(this);