mirror of https://github.com/Jittor/Jittor
reduce with f64
commit ada16207be
parent 37df805e57
@@ -258,6 +258,13 @@ static auto make_unary = get_op_info("unary")
 
 ReduceOp::ReduceOp(Var* x, NanoString op, NanoVector dims, bool keepdims)
     : x(x) {
+    if (x->dtype() == ns_float32 && !(amp_reg & amp_keep_reduce) && (op==ns_add || op==ns_mean)) {
+        auto out = make_unary(x, ns_float64);
+        out = make_reduce(out, op, dims, keepdims);
+        out = make_unary(out, ns_float32);
+        forward(out);
+        return;
+    }
     flags.set(NodeFlags::_cpu);
     flags.set(NodeFlags::_cuda);
     set_type(OpType::reduce);
@@ -277,13 +284,6 @@ ReduceOp::ReduceOp(Var* x, NanoString op, NanoVector dims, bool keepdims)
             reduce_mask |= 1<<dim;
         }
     }
-    if (x->dtype() == ns_float32 && !(amp_reg & amp_keep_reduce) && (op==ns_add || op==ns_mean || op==ns_multiply)) {
-        auto out = make_unary(x, ns_float64);
-        out = make_reduce(out, op, dims, keepdims);
-        out = make_unary(out, ns_float32);
-        forward(out);
-        return;
-    }
     // if (x->dtype() == ns_bool && ns == ns_add)
     if (x->dtype() == ns_bool)
         y = create_output(nullptr, ns_int32);
@@ -293,6 +293,13 @@ ReduceOp::ReduceOp(Var* x, NanoString op, NanoVector dims, bool keepdims)
 
 ReduceOp::ReduceOp(Var* x, NanoString op, uint dims_mask, uint keepdims_mask)
     : x(x) {
+    if (x->dtype() == ns_float32 && !(amp_reg & amp_keep_reduce) && (op==ns_add || op==ns_mean)) {
+        auto out1 = make_unary(x, ns_float64);
+        auto out2 = make_reduce2(out1, op, dims_mask, keepdims_mask);
+        auto out3 = make_unary(out2, ns_float32);
+        forward(out3);
+        return;
+    }
     flags.set(NodeFlags::_cpu);
     flags.set(NodeFlags::_cuda);
     set_type(OpType::reduce);
@@ -302,13 +309,6 @@ ReduceOp::ReduceOp(Var* x, NanoString op, uint dims_mask, uint keepdims_mask)
     ASSERT(ns.is_binary());
     reduce_mask = dims_mask;
     this->keepdims_mask = keepdims_mask;
-    if (x->dtype() == ns_float32 && !(amp_reg & amp_keep_reduce) && (op==ns_add || op==ns_mean || op==ns_multiply)) {
-        auto out = make_unary(x, ns_float64);
-        out = make_reduce2(x, op, dims_mask, keepdims_mask);
-        out = make_unary(out, ns_float32);
-        forward(out);
-        return;
-    }
     y = create_output(nullptr, reduce_dtype_infer(ns, x->ns));
 }
 
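The rationale for the constructor changes above, in one illustration: a float32 sum/mean that accumulates in float32 can drift, while accumulating in float64 and casting the result back (what the new path does via make_unary/make_reduce) stays at the expected value. A minimal sketch using only NumPy, not code from this commit:

import numpy as np

# one million float32 copies of 0.1; the exact sum of the stored float32
# values is about 100000.0015
x = np.full(1_000_000, 0.1, dtype=np.float32)

# naive sequential float32 accumulation
acc = np.float32(0.0)
for v in x:
    acc = np.float32(acc + v)

# the rerouted path: accumulate in float64, cast the result back to float32
via_f64 = np.float32(x.astype(np.float64).sum())

print(acc)      # visibly off from 100000 -- float32 rounding error accumulates
print(via_f64)  # ~100000.0 after the final cast back to float32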
@@ -59,7 +59,7 @@ class TestReduceF64Op(unittest.TestCase):
         jt.flags.use_cuda = True
         x = gen_data((3,32,64,64))
 
-        with jt.profile_scope() as report:
+        with jt.profile_scope(log_silent=1) as report:
             x_jt = jt.array(x)
             print(x_jt.sum())
         check_report(report)
@@ -68,7 +68,7 @@ class TestReduceF64Op(unittest.TestCase):
         jt.flags.use_cuda = True
         x = gen_data((3,32,64,64))
 
-        with jt.profile_scope() as report:
+        with jt.profile_scope(log_silent=1) as report:
             x_jt = jt.array(x)
             print(x_jt.mean())
         check_report(report)
@@ -77,7 +77,7 @@ class TestReduceF64Op(unittest.TestCase):
         jt.flags.use_cuda = True
         x = gen_data((3,32,64,64))
 
-        with jt.profile_scope() as report:
+        with jt.profile_scope(log_silent=1) as report:
             x_jt = jt.array(x)
             bn = nn.BatchNorm(32)
             bn.train()
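The three tests above share one pattern; a hedged usage sketch of that pattern follows. gen_data and check_report are helpers defined elsewhere in this test file and are not reproduced here; the dtype check at the end is an added illustration, not an assertion made by the tests.

import numpy as np
import jittor as jt

x = np.random.rand(3, 32, 64, 64).astype(np.float32)

# log_silent=1 as in the updated tests; the report is still collected for
# inspection, presumably with the profiler's console output suppressed
with jt.profile_scope(log_silent=1) as report:
    x_jt = jt.array(x)
    s = x_jt.sum()
    print(s)

# the user-visible result stays float32; the float64 accumulation is internal
assert str(s.dtype) == "float32"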
@@ -93,7 +93,7 @@ class TestReduceF64Op(unittest.TestCase):
     @unittest.skipIf(not jt.has_cuda, "Cuda not found")
     @jt.flag_scope(use_cuda=1, use_stat_allocator=1)
     def test_resnet(self):
-        with jt.profile_scope() as report:
+        with jt.profile_scope(log_silent=1) as report:
             self.setup_seed(1)
 
             # hyper-parameters
@@ -111,6 +111,7 @@ class TestReduceF64Op(unittest.TestCase):
             loss_list=[]
             acc_list=[]
             mnist_net = MnistNet()
+            mnist_net.train()
             global prev
             prev = time.time()
             SGD = nn.SGD(mnist_net.parameters(), self.learning_rate, self.momentum, self.weight_decay)
@@ -178,7 +179,7 @@ class TestReduceF64Op(unittest.TestCase):
                 if jt.in_mpi:
                     assert jt.core.number_of_lived_vars() < 8100, jt.core.number_of_lived_vars()
                 else:
-                    assert jt.core.number_of_lived_vars() < 7000, jt.core.number_of_lived_vars()
+                    assert jt.core.number_of_lived_vars() < 8000, jt.core.number_of_lived_vars()
                 if self.train_loader.epoch_id >= 2:
                     break
 
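A speculative note on the relaxed threshold above (and the matching one in TestResnetFp32 below): rerouting float32 add/mean reduces through float64 inserts extra unary cast ops, so presumably more Vars are alive at a time during training, and the limit moves from 7000 to 8000. A hypothetical way to observe the extra Vars with the same helper the test uses; the exact count is graph-dependent and is not something this commit asserts:

import numpy as np
import jittor as jt

x_jt = jt.array(np.random.rand(3, 32, 64, 64).astype(np.float32))

before = jt.core.number_of_lived_vars()
y = x_jt.sum()   # now builds cast-to-float64 -> reduce -> cast-to-float32 nodes
after = jt.core.number_of_lived_vars()

# expect a small positive difference (the reduce output plus float64 intermediates)
print(after - before)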
@@ -130,7 +130,7 @@ class TestResnetFp32(unittest.TestCase):
             if jt.in_mpi:
                 assert jt.core.number_of_lived_vars() < 8100, jt.core.number_of_lived_vars()
             else:
-                assert jt.core.number_of_lived_vars() < 7000, jt.core.number_of_lived_vars()
+                assert jt.core.number_of_lived_vars() < 8000, jt.core.number_of_lived_vars()
             if self.train_loader.epoch_id >= 2:
                 break
 