reduce with f64

cxjyxx_me 2022-05-05 05:32:42 -04:00
parent 37df805e57
commit ada16207be
3 changed files with 21 additions and 20 deletions

@@ -258,6 +258,13 @@ static auto make_unary = get_op_info("unary")
ReduceOp::ReduceOp(Var* x, NanoString op, NanoVector dims, bool keepdims)
: x(x) {
if (x->dtype() == ns_float32 && !(amp_reg & amp_keep_reduce) && (op==ns_add || op==ns_mean)) {
auto out = make_unary(x, ns_float64);
out = make_reduce(out, op, dims, keepdims);
out = make_unary(out, ns_float32);
forward(out);
return;
}
flags.set(NodeFlags::_cpu);
flags.set(NodeFlags::_cuda);
set_type(OpType::reduce);
@@ -277,13 +284,6 @@ ReduceOp::ReduceOp(Var* x, NanoString op, NanoVector dims, bool keepdims)
reduce_mask |= 1<<dim;
}
}
if (x->dtype() == ns_float32 && !(amp_reg & amp_keep_reduce) && (op==ns_add || op==ns_mean || op==ns_multiply)) {
auto out = make_unary(x, ns_float64);
out = make_reduce(out, op, dims, keepdims);
out = make_unary(out, ns_float32);
forward(out);
return;
}
// if (x->dtype() == ns_bool && ns == ns_add)
if (x->dtype() == ns_bool)
y = create_output(nullptr, ns_int32);
@@ -293,6 +293,13 @@ ReduceOp::ReduceOp(Var* x, NanoString op, NanoVector dims, bool keepdims)
ReduceOp::ReduceOp(Var* x, NanoString op, uint dims_mask, uint keepdims_mask)
: x(x) {
if (x->dtype() == ns_float32 && !(amp_reg & amp_keep_reduce) && (op==ns_add || op==ns_mean)) {
auto out1 = make_unary(x, ns_float64);
auto out2 = make_reduce2(out1, op, dims_mask, keepdims_mask);
auto out3 = make_unary(out2, ns_float32);
forward(out3);
return;
}
flags.set(NodeFlags::_cpu);
flags.set(NodeFlags::_cuda);
set_type(OpType::reduce);
@@ -302,13 +309,6 @@ ReduceOp::ReduceOp(Var* x, NanoString op, uint dims_mask, uint keepdims_mask)
ASSERT(ns.is_binary());
reduce_mask = dims_mask;
this->keepdims_mask = keepdims_mask;
if (x->dtype() == ns_float32 && !(amp_reg & amp_keep_reduce) && (op==ns_add || op==ns_mean || op==ns_multiply)) {
auto out = make_unary(x, ns_float64);
out = make_reduce2(x, op, dims_mask, keepdims_mask);
out = make_unary(out, ns_float32);
forward(out);
return;
}
y = create_output(nullptr, reduce_dtype_infer(ns, x->ns));
}
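
With this change, both ReduceOp constructors forward a float32 add/mean reduction through a float64 chain (cast up, reduce, cast back), so the output dtype stays float32 while the accumulation runs in double precision; the rewrite is skipped when amp_keep_reduce is set in amp_reg. A minimal numpy sketch of the same pattern, purely illustrative apart from the (3, 32, 64, 64) shape taken from gen_data in the tests below:

import numpy as np

x = np.random.rand(3, 32, 64, 64).astype(np.float32)

f32_sum = x.sum(dtype=np.float32)               # accumulate directly in float32
f64_sum = np.float32(x.sum(dtype=np.float64))   # accumulate in float64, then cast back,
                                                # mirroring make_unary -> make_reduce -> make_unary
reference = x.astype(np.float64).sum()
print(abs(f32_sum - reference), abs(f64_sum - reference))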

@@ -59,7 +59,7 @@ class TestReduceF64Op(unittest.TestCase):
jt.flags.use_cuda = True
x = gen_data((3,32,64,64))
with jt.profile_scope() as report:
with jt.profile_scope(log_silent=1) as report:
x_jt = jt.array(x)
print(x_jt.sum())
check_report(report)
@@ -68,7 +68,7 @@ class TestReduceF64Op(unittest.TestCase):
jt.flags.use_cuda = True
x = gen_data((3,32,64,64))
with jt.profile_scope() as report:
with jt.profile_scope(log_silent=1) as report:
x_jt = jt.array(x)
print(x_jt.mean())
check_report(report)
@@ -77,7 +77,7 @@ class TestReduceF64Op(unittest.TestCase):
jt.flags.use_cuda = True
x = gen_data((3,32,64,64))
with jt.profile_scope() as report:
with jt.profile_scope(log_silent=1) as report:
x_jt = jt.array(x)
bn = nn.BatchNorm(32)
bn.train()
@@ -93,7 +93,7 @@ class TestReduceF64Op(unittest.TestCase):
@unittest.skipIf(not jt.has_cuda, "Cuda not found")
@jt.flag_scope(use_cuda=1, use_stat_allocator=1)
def test_resnet(self):
with jt.profile_scope() as report:
with jt.profile_scope(log_silent=1) as report:
self.setup_seed(1)
# hyper-parameters
@@ -111,6 +111,7 @@ class TestReduceF64Op(unittest.TestCase):
loss_list=[]
acc_list=[]
mnist_net = MnistNet()
mnist_net.train()
global prev
prev = time.time()
SGD = nn.SGD(mnist_net.parameters(), self.learning_rate, self.momentum, self.weight_decay)
@@ -178,7 +179,7 @@ class TestReduceF64Op(unittest.TestCase):
if jt.in_mpi:
assert jt.core.number_of_lived_vars() < 8100, jt.core.number_of_lived_vars()
else:
assert jt.core.number_of_lived_vars() < 7000, jt.core.number_of_lived_vars()
assert jt.core.number_of_lived_vars() < 8000, jt.core.number_of_lived_vars()
if self.train_loader.epoch_id >= 2:
break
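
From the Python side the API is unchanged: the tests above print x_jt.sum() and x_jt.mean() inside a silenced profile scope and let the check_report helper inspect the profiler report. A small sketch of the user-visible behavior, assuming default amp flags (values and names are illustrative):

import numpy as np
import jittor as jt

x = np.random.rand(3, 32, 64, 64).astype("float32")
y = jt.array(x).sum()   # after this commit the reduction is accumulated in float64 internally
print(y.dtype)          # still float32: the result is cast back before it is exposed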

@@ -130,7 +130,7 @@ class TestResnetFp32(unittest.TestCase):
if jt.in_mpi:
assert jt.core.number_of_lived_vars() < 8100, jt.core.number_of_lived_vars()
else:
assert jt.core.number_of_lived_vars() < 7000, jt.core.number_of_lived_vars()
assert jt.core.number_of_lived_vars() < 8000, jt.core.number_of_lived_vars()
if self.train_loader.epoch_id >= 2:
break