auto diff hook optimizer and rand

Dun Liang 2020-09-25 14:14:28 +08:00
parent c07d85dade
commit d6b9e32428
18 changed files with 287 additions and 56 deletions

View File

@ -502,9 +502,10 @@ class Module:
ms = []
stack = []
def callback(parents, k, v, n):
stack.append(str(k))
name = ".".join(stack[1:])
ms.append((name, v))
if isinstance(v, Module):
stack.append(str(k))
name = ".".join(stack[1:])
ms.append((name, v))
def callback_leave(parents, k, v, n):
stack.pop()
self.dfs([], "", callback, callback_leave)
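The hunk above changes the name-collecting callback so that only Module instances are recorded. A minimal stand-alone sketch of the name-stack DFS idea, using a hypothetical ToyModule rather than Jittor's real Module:

# ToyModule and named_modules are illustrative names, not Jittor's API.
class ToyModule:
    def __init__(self, **children):
        self.children = children

def named_modules(root):
    ms, stack = [], []
    def dfs(k, v):
        if isinstance(v, ToyModule):              # after this commit: only record real sub-modules
            stack.append(str(k))
            ms.append((".".join(stack[1:]), v))   # the root gets the empty name
            for ck, cv in v.children.items():
                dfs(ck, cv)
            stack.pop()                           # leave: pop the name on the way back up
    dfs("", root)
    return [name for name, _ in ms]

net = ToyModule(conv=ToyModule(), head=ToyModule(fc=ToyModule()))
print(named_modules(net))   # ['', 'conv', 'head', 'head.fc']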
@ -815,6 +816,28 @@ Var.__str__ = lambda x: str(x.data)
Var.__repr__ = lambda x: str(x.data)
Var.peek = lambda x: f"{x.dtype}{x.shape}"
def item(v):
return v.data.item()
def to_int(v):
dtype = str(v.dtype)
assert dtype.startswith("int")
return v.item()
def to_float(v):
dtype = str(v.dtype)
assert dtype.startswith("float")
return v.item()
def to_bool(v):
dtype = str(v.dtype)
assert dtype.startswith("int") or dtype=="bool"
return bool(v.item())
Var.item = item
Var.__int__ = to_int
Var.__float__ = to_float
Var.__bool__ = to_bool
ori_int = int
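A short usage sketch of the conversions added above (assuming a working jittor install): item() returns a Python scalar, while int()/float()/bool() first assert a matching dtype.

import jittor as jt

a = jt.array([3])          # an int Var
b = jt.array([2.5])        # a float Var

print(a.item(), int(a))    # 3 3
print(b.item(), float(b))  # 2.5 2.5
print(bool(a))             # True -- bool() requires an int/bool dtype
# float(a) or int(b) would trip the dtype assertions defined above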

View File

@ -89,6 +89,11 @@ def chunk(x, chunks, dim=0):
return res
jt.Var.chunk = chunk
def expand(x, shape):
return x.broadcast(shape)
jt.Var.expand = expand
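A quick usage sketch of the new expand alias, which is just a thin wrapper over broadcast; the shapes are illustrative.

import jittor as jt

x = jt.array([[1.0], [2.0]])   # shape (2, 1)
y = x.expand((2, 3))           # broadcast the singleton dim to 3
print(y.shape)                 # [2,3]
assert (y.data == [[1., 1., 1.], [2., 2., 2.]]).all()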
def stack(x, dim=0):
r'''
Concatenates sequence of vars along a new dimension.

View File

@ -261,19 +261,20 @@ class Linear(Module):
return x
class BatchNorm(Module):
def __init__(self, num_features, eps=1e-5, momentum=0.1, affine=None, is_train=True, sync=True):
assert affine == None
def __init__(self, num_features, eps=1e-5, momentum=0.1, affine=True, is_train=True, sync=True):
self.sync = sync
self.num_features = num_features
self.is_train = is_train
self.eps = eps
self.momentum = momentum
self.weight = init.constant((num_features,), "float32", 1.0)
self.bias = init.constant((num_features,), "float32", 0.0)
self.running_mean = init.constant((num_features,), "float32", 0.0).stop_grad()
self.running_var = init.constant((num_features,), "float32", 1.0).stop_grad()
self.affine = affine
if self.affine:
self.weight = init.constant((num_features,), "float32", 1.0)
self.bias = init.constant((num_features,), "float32", 0.0)
def execute(self, x):
if self.is_train:
xmean = jt.mean(x, dims=[0,2,3], keepdims=1)
@ -292,23 +293,27 @@ class BatchNorm(Module):
running_mean = self.running_mean.broadcast(x, [0,2,3])
running_var = self.running_var.broadcast(x, [0,2,3])
norm_x = (x-running_mean)/jt.sqrt(running_var+self.eps)
if not self.affine:
return norm_x
w = self.weight.broadcast(x, [0,2,3])
b = self.bias.broadcast(x, [0,2,3])
return norm_x * w + b
class BatchNorm1d(Module):
def __init__(self, num_features, eps=1e-5, momentum=0.1, affine=None, is_train=True, sync=True):
assert affine == None
def __init__(self, num_features, eps=1e-5, momentum=0.1, affine=True, is_train=True, sync=True):
self.sync = sync
self.num_features = num_features
self.is_train = is_train
self.eps = eps
self.momentum = momentum
self.weight = init.constant((num_features,), "float32", 1.0)
self.bias = init.constant((num_features,), "float32", 0.0)
self.running_mean = init.constant((num_features,), "float32", 0.0).stop_grad()
self.running_var = init.constant((num_features,), "float32", 1.0).stop_grad()
self.affine = affine
if self.affine:
self.weight = init.constant((num_features,), "float32", 1.0)
self.bias = init.constant((num_features,), "float32", 0.0)
def execute(self, x):
if self.is_train:
xmean = jt.mean(x, dims=[0], keepdims=1)
@ -328,20 +333,24 @@ class BatchNorm1d(Module):
running_mean = self.running_mean.broadcast(x, [0])
running_var = self.running_var.broadcast(x, [0])
norm_x = (x-running_mean)/jt.sqrt(running_var+self.eps)
if not self.affine:
return norm_x
w = self.weight.broadcast(x, [0])
b = self.bias.broadcast(x, [0])
return norm_x * w + b
class InstanceNorm2d(Module):
def __init__(self, num_features, eps=1e-05, momentum=0.1, affine=None, is_train=True, sync=True):
assert affine == None
def __init__(self, num_features, eps=1e-05, momentum=0.1, affine=True, is_train=True, sync=True):
self.sync = sync
self.num_features = num_features
self.is_train = is_train
self.eps = eps
self.momentum = momentum
self.weight = init.constant((num_features,), "float32", 1.0)
self.bias = init.constant((num_features,), "float32", 0.0)
self.affine = affine
if self.affine:
self.weight = init.constant((num_features,), "float32", 1.0)
self.bias = init.constant((num_features,), "float32", 0.0)
def execute(self, x):
xmean = jt.mean(x, dims=[2,3], keepdims=1)
@ -352,18 +361,22 @@ class InstanceNorm2d(Module):
xvar = jt.maximum(x2mean-xmean*xmean, 0)
norm_x = (x-xmean)/jt.sqrt(xvar+self.eps)
if not self.affine:
return norm_x
w = self.weight.broadcast(x, [0,2,3])
b = self.bias.broadcast(x, [0,2,3])
return norm_x * w + b
class GroupNorm(Module):
def __init__(self, num_groups, num_channels, eps=1e-05, affine=None, is_train=True):
assert affine == None
def __init__(self, num_groups, num_channels, eps=1e-05, affine=True, is_train=True):
self.num_groups = num_groups
self.num_channels = num_channels
self.eps = eps
self.weight = init.constant((num_channels,), "float32", 1.0)
self.bias = init.constant((num_channels,), "float32", 0.0)
self.affine = affine
if self.affine:
self.weight = init.constant((num_channels,), "float32", 1.0)
self.bias = init.constant((num_channels,), "float32", 0.0)
def execute(self, x):
N,C,H,W = x.shape
@ -374,6 +387,8 @@ class GroupNorm(Module):
x2mean = jt.mean(x*x, dims=[2,3], keepdims=1)
xvar = jt.maximum(x2mean-xmean*xmean, 0)
norm_x = (x-xmean)/jt.sqrt(xvar+self.eps)
if not self.affine:
return norm_x
w = self.weight.reshape((1,self.num_groups,C//self.num_groups,1))
b = self.bias.reshape((1,self.num_groups,C//self.num_groups,1))
return (norm_x * w + b).reshape((N,C,H,W))
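The same pattern is applied to BatchNorm, BatchNorm1d, InstanceNorm2d and GroupNorm above: affine now defaults to True, and with affine=False no weight/bias are created and the plain normalized output is returned. A hedged usage sketch:

import jittor as jt
from jittor import nn

bn = nn.BatchNorm(16)                       # default affine=True: learnable weight and bias
bn_plain = nn.BatchNorm(16, affine=False)   # no weight/bias; execute() returns norm_x directly

x = jt.random((8, 16, 4, 4))
print(bn(x).shape, bn_plain(x).shape)       # both [8,16,4,4]
print(hasattr(bn_plain, "weight"))          # False -- nothing extra to learn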
@ -493,7 +508,11 @@ class ConvTranspose(Module):
self.weight = init.invariant_uniform((in_channels, out_channels) + self.kernel_size, dtype="float")
if bias:
self.bias = init.uniform([out_channels], dtype="float", low=-1, high=1)
fan=1
for i in self.weight.shape[1:]:
fan *= i
bound = 1 / math.sqrt(fan)
self.bias = init.uniform([out_channels], dtype="float", low=-bound, high=bound)
else:
self.bias = None
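A small arithmetic sketch of the bias bound introduced above: fan is the product of the weight dimensions after the first, and the bias is now drawn from U(-1/sqrt(fan), 1/sqrt(fan)) instead of U(-1, 1). The weight shape below is illustrative.

import math

# hypothetical ConvTranspose weight shape: (in_channels, out_channels, kh, kw)
weight_shape = (16, 32, 3, 3)
fan = 1
for d in weight_shape[1:]:
    fan *= d                     # 32 * 3 * 3 = 288
bound = 1 / math.sqrt(fan)
print(fan, round(bound, 4))      # 288 0.0589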

View File

@ -33,6 +33,11 @@ class Optimizer(object):
assert isinstance(pg, dict)
self.param_groups.append(pg)
self.n_step = 0
@property
def defaults(self):
exclude = set(("defaults", "param_groups", "n_step"))
return { k:v for k, v in self.__dict__.items() if k[0] != '_' and k not in exclude }
def pre_step(self, loss):
""" something should be done before step, such as calc gradients, mpi sync, and so on.
@ -76,8 +81,12 @@ class Optimizer(object):
pg_grads[i] = grads[pid].stop_grad()
pid += 1
def step(self, loss):
def backward(self, loss):
self.pre_step(loss)
def step(self, loss=None):
if loss is not None:
self.pre_step(loss)
for pg in self.param_groups:
lr = pg.get("lr", self.lr)
for p, g in zip(pg["params"], pg["grads"]):
@ -106,8 +115,9 @@ class SGD(Optimizer):
for p in pg["params"]:
values.append(jt.zeros(p.shape, p.dtype).stop_grad())
def step(self, loss):
self.pre_step(loss)
def step(self, loss=None):
if loss is not None:
self.pre_step(loss)
for pg in self.param_groups:
# get arguments from each param_groups
lr = pg.get("lr", self.lr)
@ -149,8 +159,9 @@ class RMSprop(Optimizer):
for p in pg["params"]:
values.append(jt.zeros(p.shape, p.dtype).stop_grad())
def step(self, loss):
self.pre_step(loss)
def step(self, loss=None):
if loss is not None:
self.pre_step(loss)
for pg in self.param_groups:
# get arguments from each param_groups
lr = pg.get("lr", self.lr)
@ -184,8 +195,9 @@ class Adam(Optimizer):
values.append(jt.zeros(p.shape, p.dtype).stop_grad())
m.append(jt.zeros(p.shape, p.dtype).stop_grad())
def step(self, loss):
self.pre_step(loss)
def step(self, loss=None):
if loss is not None:
self.pre_step(loss)
n = float(self.n_step)
for pg in self.param_groups:
# get arguments from each param_groups
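A hedged usage sketch of the reworked optimizer interface: step() still accepts a loss, but the work can now be split into backward(loss) followed by step(), and the new defaults property exposes the hyper-parameters (which the auto_diff hook below records). The model and loss are illustrative.

import jittor as jt
from jittor import nn

net = nn.Linear(4, 2)
opt = nn.SGD(net.parameters(), lr=0.1, momentum=0.9)

x = jt.random((8, 4))
loss = (net(x) ** 2).mean()

# old style still works:  opt.step(loss)
# new split style:
opt.backward(loss)     # pre_step: compute and stash gradients
opt.step()             # apply the update with the stashed gradients

print(opt.defaults)    # e.g. {'lr': 0.1, 'momentum': 0.9, ...}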

View File

@ -21,7 +21,8 @@ class Linear(Module):
self.b = jt.random((out_features,))-0.5 if bias else None
def execute(self, x):
x = matmul(x, self.w)
if self.b: return x+self.b
if self.b is not None:
return x+self.b
return x
def relu(x):
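Why the `is not None` check above matters, sketched under the assumption that Var truthiness now goes through the __bool__ added earlier in this commit:

import jittor as jt

b = jt.array([0.0, 0.1])    # a float bias Var

print(b is not None)        # True: a presence check that never touches the data
# "if b:" would call Var.__bool__ (to_bool above), which asserts an int/bool dtype,
# and for a scalar int Var it would reflect the value, not the presence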

View File

@ -140,10 +140,10 @@ class TestMklConvOp(unittest.TestCase):
dw_jt_tune=gs_tune[1].data
logs = find_log_with_re(rawlogs,
"Run tuner conv: confidence\\((20)\\) candidates\\((.*)\\)$")
assert len(logs) == 1
assert len(logs) == 2, len(logs)
assert logs[0][0] == "20", "confidence of reorder should be 20"
candidates = simple_parser(logs[0][1])
assert candidates == {"relay0":[1,0],"relay1":[1,0]}, candidates
assert candidates == {"relay0":[1,0]}, candidates
logs = find_log_with_re(rawlogs, r"get_relay_src([\s\S]*)")
assert len(logs)==2
@ -186,10 +186,11 @@ class TestMklConvOp(unittest.TestCase):
dw_jt_tune=gs_tune[1].data
logs = find_log_with_re(rawlogs,
"Run tuner conv: confidence\\((20)\\) candidates\\((.*)\\)$")
assert len(logs) == 1
assert len(logs) == 2
assert logs[0][0] == "20", "confidence of reorder should be 20"
candidates = simple_parser(logs[0][1])
assert candidates == {"relay0":[1,0],"relay1":[1,0]}, candidates
assert candidates == {"relay0":[1,0]}, candidates
# assert candidates == {"relay0":[1,0],"relay1":[1,0]}, candidates
logs = find_log_with_re(rawlogs, r"get_relay_src([\s\S]*)")
assert len(logs)==2

View File

@ -26,10 +26,10 @@ class FakeMpiBatchNorm(nn.Module):
self.is_train = is_train
self.eps = eps
self.momentum = momentum
self.weight = init.constant((num_features,), "float32", 1.0)
self.bias = init.constant((num_features,), "float32", 0.0)
self.running_mean = init.constant((num_features,), "float32", 0.0).stop_grad()
self.running_var = init.constant((num_features,), "float32", 1.0).stop_grad()
self.weight = init.constant((num_features,), "float32", 1.0)
self.bias = init.constant((num_features,), "float32", 0.0)
def execute(self, x, global_x):
if self.is_train:
@ -91,6 +91,7 @@ class TestMpiBatchnorm(unittest.TestCase):
gs2 = jt.grad(y2,bn2.parameters())
assert np.allclose(y1.data, y2.data, atol=1e-5),(mpi.world_rank(),y1.data, y2.data, y1.data-y2.data)
assert len(gs1) == len(gs2)
for i in range(len(gs1)):
assert np.allclose(gs1[i].data, gs2[i].data, rtol=1e-2),(mpi.world_rank(),gs1[i].data, gs2[i].data,gs1[i].data-gs2[i].data)

View File

@ -24,7 +24,7 @@ class TestNanoString(unittest.TestCase):
# t is about 0.01s for 1,000,000 loops
# about 92ns per loop
print("nanostring time", t)
assert t < [1.5e-7, 1.7e-7][mid], t
assert t < [1.5e-7, 1.9e-7][mid], t
assert (jt.hash("asdasd") == 4152566416)
assert str(jt.NanoString("float"))=="float32"

View File

@ -4,6 +4,7 @@ import pickle
import numpy as np
import jittor_utils
from jittor_utils import LOG
import sys
jittor_utils.try_import_jit_utils_core()
@ -20,11 +21,58 @@ def convert(data):
if isinstance(data, dict):
return {k:convert(data[k]) for k in data}
if hasattr(data, "numpy"):
return data.detach().numpy()
if "Var" in data.__class__.__name__:
return data.numpy()
else:
return data.detach().cpu().numpy()
return data
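A simplified stand-alone sketch of what the updated convert() does: anything exposing a numpy method becomes an ndarray, with Jittor Vars calling numpy() directly and PyTorch tensors being detached and moved to CPU first (the class name is checked as a string so neither framework has to be imported). convert_sketch is a hypothetical name.

import numpy as np

def convert_sketch(data):
    # simplified stand-alone version of convert() above
    if isinstance(data, (list, tuple)):
        return type(data)(convert_sketch(d) for d in data)
    if isinstance(data, dict):
        return {k: convert_sketch(v) for k, v in data.items()}
    if hasattr(data, "numpy"):
        if "Var" in data.__class__.__name__:      # jittor.Var
            return data.numpy()
        return data.detach().cpu().numpy()        # torch.Tensor
    return data

print(convert_sketch({"a": [1, 2], "b": np.float32(3)}))   # {'a': [1, 2], 'b': 3.0}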
rand_hooked = False
def hook_pt_rand(*shape):
import torch
if isinstance(shape, tuple) and len(shape)==1 and isinstance(shape[0], torch.Size):
shape = tuple(shape[0])
np.random.seed(0)
return torch.from_numpy(np.random.rand(*tuple(shape)).astype("float32"))
def hook_pt_normal(mean, std):
import torch
shape = tuple(mean.shape)
np.random.seed(0)
return torch.from_numpy(np.random.normal(size=shape).astype("float32")).to(std.device) * std + mean
def hook_jt_rand(shape, dtype="float32", rtype="uniform"):
import jittor
np.random.seed(0)
if rtype == "normal":
return jittor.array(np.random.normal(size=shape).astype(str(dtype)))
return jittor.array(np.random.rand(*shape).astype(str(dtype)))
def hook_rand():
global rand_hooked
if rand_hooked: return
rand_hooked = True
np.random.seed(0)
if "torch" in sys.modules:
LOG.i("Hook torch.rand")
torch = sys.modules["torch"]
torch.rand = hook_pt_rand
torch.normal = hook_pt_normal
torch.manual_seed(0)
if "jittor" in sys.modules:
jittor = sys.modules["jittor"]
LOG.i("Hook jittor.random")
jittor.random = hook_jt_rand
jittor.seed(0)
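The hooks above all rely on the same trick: reseed numpy on every call so that torch.rand/torch.normal and jittor.random return identical tensors in both runs. A minimal illustration of that trick in plain numpy:

import numpy as np

def deterministic_rand(*shape):
    # reseed on every call, so "both frameworks" draw the same values
    np.random.seed(0)
    return np.random.rand(*shape).astype("float32")

a = deterministic_rand(2, 3)
b = deterministic_rand(2, 3)
print(np.allclose(a, b))     # True: identical draws on every call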
class Hook:
def __init__(self, base_name, rtol=5e-2, atol=1e-3):
if os.environ.get("use_auto_diff", '1') == '0':
return
hook_rand()
self.rid = 0
self.base_name = base_name
self.base_path = os.path.join(str(Path.home()), ".cache", "jittor", "auto_diff", base_name)
@ -47,16 +95,19 @@ class Hook:
return
has_error += 1
LOG.e(f"Ndarray <{name}> not match, shape:{a.shape}")
LOG.w(f"Ndarray <{name}> not match, shape:{a.shape}")
i = tuple( i[0] for i in index )
err_rate = is_error.mean()
LOG.e(f"error index at [{i}], a({a[i]}) b({b[i]}) err({err[i]}) > tol({tol[i]}), err_rate:{err_rate*100:.3f}%")
LOG.w(f"error index at [{i}], a({a[i]}) b({b[i]}) err({err[i]}) > tol({tol[i]}), err_rate:{err_rate*100:.3f}% amean({a.mean()}) bmean({b.mean()}) astd({a.std()}) bstd({b.std()}) ")
if err_rate > 0.01:
LOG.e("!"*10+"Very HIGH err rate"+"!"*10)
def check(self, name, pre_data, data):
global has_error
assert type(pre_data) == type(data)
if type(pre_data) != type(data):
LOG.e(f"type not match, {pre_data.__class__.__name__}!={data.__class__.__name__}, name: {name}")
has_error += 1
return
if isinstance(pre_data, (list, tuple)):
if len(pre_data) != len(data):
has_error += 1
@ -75,19 +126,26 @@ class Hook:
elif isinstance(pre_data, dict):
if len(pre_data) != len(data):
has_error += 1
LOG.e(f"Dict Name <{name}> len not match, {len(pre_data)} != {len(data)}")
LOG.w(f"Dict Name <{name}> len not match, {len(pre_data)} != {len(data)}")
for k in pre_data:
pv = pre_data[k]
if k not in data:
has_error += 1
LOG.e(f"Key <{k}> not in data, Name <{name}>")
msg = f"Key <{k}> not in data, Name <{name}>"
if isinstance(pv, np.ndarray):
LOG.e(msg)
else:
LOG.w(msg)
continue
self.check(name+f".{i}", pre_data[k], data[k])
self.check(name+f".{k}", pre_data[k], data[k])
else:
if pre_data != data:
has_error += 1
LOG.e(f"Type: {type(pre_data).__name__} Name <{name}> not match {pre_data} != {data}")
def record(self, name, data):
def record(self, name, data, ex_name=""):
if os.environ.get("use_auto_diff", '1') == '0':
return
rid = self.rid
self.rid += 1
fpath = os.path.join(self.base_path, f"{rid}.pkl")
@ -99,15 +157,18 @@ class Hook:
global has_error
has_error += 1
LOG.e(f"The {rid} result name not match, {pre_name} != {name}")
self.rid -= 1
return
LOG.i(f"check {rid}:<{name}> ...")
self.check(name, pre_data, data)
LOG.i(f"check {rid}:<{ex_name}{name}> ...")
self.check(ex_name+name, pre_data, data)
else:
with open(fpath, 'wb') as f:
pickle.dump((name, data), f)
LOG.i(f"save {rid}:<{name}> ok")
def record_params(self, parameters_dict):
if os.environ.get("use_auto_diff", '1') == '0':
return
rid = self.rid
self.rid += 1
global has_error
@ -162,23 +223,75 @@ class Hook:
LOG.i(f"save params ok")
def hook_module(self, mod):
def hook_module(self, mod, mod_name=""):
if os.environ.get("use_auto_diff", '1') == '0':
return
if mod_name != "":
mod_name = "<" + mod_name + ">"
def forward_hook(self2, input, output):
ex_name = '[' + self2.__class__.__name__ + ']'
if "relu" not in self2.__class__.__name__.lower():
# skip relu, because its input may be modified in place
self.record(self2.__ad_mod_name__+".input", input)
self.record(self2.__ad_mod_name__+".output", output)
self.record(self2.__ad_mod_name__+".input", input, ex_name)
self.record(self2.__ad_mod_name__+".output", output, ex_name)
names = []
for name, module in mod.named_modules():
name = mod_name + name
module.__ad_mod_name__ = name
names.append(name)
module.register_forward_hook(forward_hook)
self.record_params(mod.state_dict())
mod_class_name = module.__class__.__name__.lower()
# put dropout into eval mode and record dropout.p
if "dropout" in mod_class_name:
self.record(name+'.p', module.p, "["+mod_class_name+"]")
module.eval()
ps = { mod_name+k:v for k, v in mod.state_dict().items() }
self.record_params(ps)
self.record("module names", names)
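A hedged usage sketch of hook_module on the Jittor side (the PyTorch side is analogous); the import path is an assumption based on the jittor_utils import at the top of this file, and the model is illustrative.

import jittor as jt
from jittor import nn
from jittor_utils import auto_diff   # assuming this file ships as jittor_utils.auto_diff

net = nn.Sequential(nn.Linear(4, 8), nn.ReLU(), nn.Linear(8, 2))
hook = auto_diff.Hook("my_model")    # records land in ~/.cache/jittor/auto_diff/my_model
hook.hook_module(net)

x = jt.random((5, 4))
y = net(x)                           # each non-relu layer's input/output gets recorded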
def hook_optimizer(self, opt, opt_name=""):
'''
net = Model()
opt = optim.SGD(net.parameters(), 0.1)
hook.hook_optimizer(opt)
'''
if os.environ.get("use_auto_diff", '1') == '0':
return
origin_step = opt.step
ex_name = '['+opt.__class__.__name__+']'
def step_hook(*args, **kw):
origin_step(*args, **kw)
self.record(opt_name+".default", opt.defaults, ex_name)
gid = 0
n_params = 0
for pg in opt.param_groups:
for p in pg["params"]:
if hasattr(p, "is_stop_grad"):
if p.is_stop_grad():
continue
n_params += 1
else:
n_params += 1
self.record(opt_name+".n_params", n_params, ex_name)
for pg in opt.param_groups:
for i, p in reversed(list(enumerate(pg["params"]))):
if hasattr(p, "is_stop_grad"):
if p.is_stop_grad():
continue
self.record(f"{opt_name}.grads[{gid}]", pg["grads"][i], "["+p.name()+"]")
self.record(f"{opt_name}.params[{gid}]", p, "["+p.name()+"]")
gid += 1
else:
self.record(f"{opt_name}.grads[{gid}]", p.grad)
self.record(f"{opt_name}.params[{gid}]", p)
gid += 1
opt.step = step_hook
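A hedged sketch of using the new optimizer hook together with hook_module, mirroring the docstring above; after each step() the hook records the defaults, the number of trainable parameters, and every param/grad pair, so two runs recorded under the same base_name can be diffed across frameworks.

import jittor as jt
from jittor import nn
from jittor_utils import auto_diff   # same import-path assumption as in the previous sketch

net = nn.Linear(4, 2)
opt = nn.SGD(net.parameters(), lr=0.1)
hook = auto_diff.Hook("my_model")    # use the same base_name on the PyTorch side
hook.hook_module(net)
hook.hook_optimizer(opt)

x = jt.random((8, 4))
loss = (net(x) ** 2).mean()
opt.step(loss)                       # step_hook records defaults, n_params, params and grads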

View File

@ -464,7 +464,7 @@ void Executor::run_sync(vector<Var*> vars, bool device_sync) {
checkCudaErrors(cudaDeviceSynchronize());
}));
}
LOGvv << "cudaDeviceSynchronize times:" << sync_times << "/" <<queue.size();
LOGvv << "cudaDeviceSynchronize times:" << sync_times << "/" <<queue.size() << "device_sync:" << device_sync;
#endif
}

View File

@ -41,6 +41,10 @@ void Op::forward(Var* input) {
outputs_holder.emplace_back(input);
}
VarPtr Op::duplicate() {
return nullptr;
}
VarPtr Op::grad(Var* out, Var* dout, Var* v, int v_index) {
LOGw << "Grad of" << name() << "return zeros";
return nullptr;

View File

@ -47,6 +47,7 @@ struct Op : Node {
virtual void do_prepare();
virtual void do_run_after_prepare();
virtual void do_run();
virtual VarPtr duplicate();
void jit_run();
string name_ex() const;

View File

@ -87,6 +87,18 @@ BinaryOp::BinaryOp(Var* x, Var* y, NanoString op) : x(x), y(y) {
z = create_output(nullptr, binary_dtype_infer(op, x, y));
}
VarPtr dirty_clone_broadcast(Var* v) {
Op* op = v->input();
// dirty fix conv duplicated
if (op && !v->is_finished() && v->shape.size() > 4 && op->type() == OpType::broadcast) {
auto vp = op->duplicate();
if (vp) {
return vp;
}
}
return v;
}
VarPtr BinaryOp::grad(Var* out, Var* dout, Var* v, int v_index) {
if (ns == ns_add) return dout;
if (ns == ns_subtract) {
@ -97,9 +109,9 @@ VarPtr BinaryOp::grad(Var* out, Var* dout, Var* v, int v_index) {
}
if (ns == ns_multiply) {
if (v_index == 0)
return make_binary(y, dout, ns_multiply);
return make_binary(dirty_clone_broadcast(y), dirty_clone_broadcast(dout), ns_multiply);
else
return make_binary(x, dout, ns_multiply);
return make_binary(dirty_clone_broadcast(x), dirty_clone_broadcast(dout), ns_multiply);
}
if (ns == ns_divide) {
if (v_index == 0)

View File

@ -14,6 +14,10 @@ namespace jittor {
#ifndef JIT
static auto make_reduce = get_op_info("reduce")
.get_constructor<VarPtr, Var*, NanoString, uint, uint>();
static auto make_broadcast = get_op_info("broadcast_to")
.get_constructor<VarPtr, Var*, Var*, uint, uint>();
static auto make_broadcast2 = get_op_info("broadcast_to")
.get_constructor<VarPtr, Var*, NanoVector, uint, uint>();
BroadcastToOp::BroadcastToOp(Var* x, Var* y, NanoVector dims) : x(x), y(y) {
auto count = dims.size();
@ -51,6 +55,20 @@ BroadcastToOp::BroadcastToOp(Var* x, Var* y, uint dims_mask, uint keepdims_mask)
this->keepdims_mask = keepdims_mask;
}
BroadcastToOp::BroadcastToOp(Var* x, NanoVector shape, uint dims_mask, uint keepdims_mask) : x(x), y(nullptr), shape(shape) {
auto count = __builtin_popcount(dims_mask);
if (!count) {
forward(x);
return;
}
flags.set(NodeFlags::_cpu);
flags.set(NodeFlags::_cuda);
set_type(OpType::broadcast);
z = create_output(NanoVector(), x->dtype());
bcast_mask = dims_mask;
this->keepdims_mask = keepdims_mask;
}
BroadcastToOp::BroadcastToOp(Var* x, NanoVector shape, NanoVector dims) : x(x), y(nullptr), shape(shape) {
auto count = dims.size();
// forward x if don't need broadcast
@ -82,6 +100,13 @@ bool BroadcastToOp::need_broadcast(const Var* x, const NanoVector& shape) {
return false;
}
VarPtr BroadcastToOp::duplicate() {
if (y)
return make_broadcast(x, y, bcast_mask, keepdims_mask);
else
return make_broadcast2(x, shape, bcast_mask, keepdims_mask);
}
VarPtr BroadcastToOp::grad(Var* out, Var* dout, Var* v, int v_index) {
if (v_index==1) return nullptr;
if (bcast_mask==0) return dout;
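As context for the grad path these constructors feed: the gradient of broadcast_to with respect to x is a reduction (sum) over the broadcast dimensions. A quick check in Jittor, with illustrative shapes:

import jittor as jt

x = jt.random((3, 1))
y = x.broadcast((3, 4))           # replicate along the last dim
g = jt.grad(y.sum(), [x])[0]      # d(sum(y))/dx sums over the broadcast dim
print(g.shape, g.data[0, 0])      # [3,1] and 4.0 -- one for each replicated copy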

View File

@ -21,12 +21,15 @@ struct BroadcastToOp : Op {
BroadcastToOp(Var* x, Var* y, NanoVector dims=NanoVector());
// @pybind(None)
BroadcastToOp(Var* x, Var* y, uint dims_mask, uint keepdims_mask);
// @pybind(None)
BroadcastToOp(Var* x, NanoVector shape, uint dims_mask, uint keepdims_mask);
bool need_broadcast(const Var* x, const NanoVector& shape);
const char* name() const override { return "broadcast_to"; }
VarPtr grad(Var* out, Var* dout, Var* v, int v_index) override;
void infer_shape() override;
VarPtr duplicate() override;
DECLARE_jit_run;
};

View File

@ -13,6 +13,8 @@ namespace jittor {
#ifndef JIT
static auto make_reindex_reduce = get_op_info("reindex_reduce")
.get_constructor<VarPtr, Var*, NanoString, NanoVector, vector<string>&&, vector<string>&&, vector<Var*>&&>();
static auto make_reindex = get_op_info("reindex")
.get_constructor<VarPtr, Var*, NanoVector, vector<string>&&, float64, vector<string>&&, vector<Var*>&&>();
ReindexOp::ReindexOp(Var* x, NanoVector shape, vector<string>&& indexes, float64 overflow_value, vector<string>&& overflow_conditions, vector<Var*>&& extras)
: x(x),
@ -63,6 +65,10 @@ ReindexOp::ReindexOp(Var* x, vector<Var*>&& indexes, float64 overflow_value, vec
}
}
VarPtr ReindexOp::duplicate() {
return make_reindex(x, shape, clone(indexes), overflow_value, clone(overflow_conditions), clone(extras));
}
VarPtr ReindexOp::grad(Var* out, Var* dout, Var* v, int v_index) {
// Do not have grad to extras input
if (v_index) return nullptr;

View File

@ -98,6 +98,7 @@ struct ReindexOp : Op {
const char* name() const override { return "reindex"; }
VarPtr grad(Var* out, Var* dout, Var* v, int v_index) override;
void infer_shape() override;
VarPtr duplicate() override;
DECLARE_jit_run;
};

View File

@ -50,19 +50,23 @@ void __unregiste_node_trace(Node* node) {
void __registe_node_trace_grad(Node* g, Node* node, int x_index) {
if (!g) return;
string& gname = trace_data.at(g);
string name = "grad(";
if (startswith(gname, "grad("))
return;
if (!node->is_var()) {
name += node->op()->name_ex();
name += ':';
name += S(x_index);
}
name += ":" + gname;
name += "):";
name += trace_data.at(node);
trace_data[g] = name;
gname = name;
std::function<void(Node*)> dfs = [&] (Node* node) {
for (Node* i : node->inputs()) {
string& iname = trace_data[i];
if (startswith(iname, "__init__.py:grad:")) {
if (iname.find("__init__.py:grad:") != string::npos && !startswith(iname, "grad(")) {
iname = name;
dfs(i);
}