mirror of https://github.com/Jittor/Jittor
121 lines
3.7 KiB
Python
121 lines
3.7 KiB
Python
# ***************************************************************
|
|
# Copyright (c) 2021 Jittor. All Rights Reserved.
|
|
# Maintainers: Dun Liang <randonlang@gmail.com>.
|
|
# This file is subject to the terms and conditions defined in
|
|
# file 'LICENSE.txt', which is part of this source code package.
|
|
# ***************************************************************
|
|
import unittest
|
|
import jittor as jt
|
|
import numpy as np
|
|
import os
|
|
|
|
from jittor.test.misc import superglue
|
|
from jittor.test.misc.superglue import SuperGlue
|
|
import time
|
|
|
|
@jt.flag_scope(use_cuda=1)
def main():
    """Benchmark SuperGlue forward passes and return the mean time per pass.

    Builds a SuperGlue model from an inline config (18 GNN layers alternating
    'self' / 'cross') and a random input batch. It runs a few warm-up forward
    passes (printing memory info after each), then times ``timed_iters``
    forward passes under ``jt.no_grad()``. The decorator runs it with
    ``use_cuda=1``.

    Environment variables read:
        split_size -- assigned to ``superglue.split_size`` on the imported
                      module (default "12").
        use_fp16   -- when non-zero, float32 input Vars and model parameters
                      are cast to float16 and ``jt.flags.amp_reg`` is set to 2
                      for the duration of this call (default "0").
        profiler   -- assigned to ``jt.flags.profiler_enable`` just before the
                      timed loop (default "0").

    Returns:
        float: average wall-clock seconds per timed forward pass.
    """
    # Attribute assignment on the imported module. The module name is never
    # rebound in this function, so every call configures the module itself.
    superglue.split_size = int(os.environ.get("split_size", "12"))

    batch = 30
    num = 2000
    dim = 128
    warmup_iters = 5
    timed_iters = 20

    with jt.no_grad():
        config = {
            'superglue': {
                'sinkhorn_iterations': 25,
                'match_threshold': 0.01,
                'keypoint_position_dim': 2,
                'descriptor_dim': dim,
                'use_dual_softmax': True,
                'GNN_layers': ['self', 'cross'] * 9,
            }
        }

        # Keep the model in a local name. Rebinding the global `superglue`
        # would shadow the imported module, and a later call would then set
        # split_size on the model instance instead of on the module.
        model = SuperGlue(config.get('superglue', {}))
        model.eval()

        data = {
            'keypoints0': jt.rand((batch, num, 2), dtype=jt.float),
            'keypoints1': jt.rand((batch, num, 2), dtype=jt.float),
            'shape0': jt.rand((batch, 2), dtype=jt.float),
            'shape1': jt.rand((batch, 2), dtype=jt.float),
            'descriptors0': jt.rand((batch, dim, num), dtype=jt.float),
            'descriptors1': jt.rand((batch, dim, num), dtype=jt.float),
            'scores0': jt.rand((batch, num), dtype=jt.float),
            'scores1': jt.rand((batch, num), dtype=jt.float),
            'all_matches': jt.randint(0, num, (batch, num, 2), dtype=jt.int),
            'return_match': False,
        }

        use_fp16 = int(os.environ.get("use_fp16", "0"))
        prev_amp_reg = jt.flags.amp_reg
        try:
            if use_fp16:
                jt.flags.amp_reg = 2
                # Only Var entries are cast; plain Python values such as
                # 'return_match' are left alone.
                for v in data.values():
                    if isinstance(v, jt.Var) and v.dtype == "float32":
                        v.assign(v.float16())
                for v in model.parameters():
                    if v.dtype == "float32":
                        v.assign(v.float16())
            jt.sync_all(True)

            # Warm-up passes: not timed. Memory info is printed after each.
            for i in range(warmup_iters):
                print(i)
                jt.gc()
                loss = model(data)['loss']
                loss.sync()
                jt.display_memory_info()

            jt.sync_all(True)
            time0 = time.time()
            jt.flags.profiler_enable = int(os.environ.get("profiler", "0"))

            for i in range(timed_iters):
                print(i)
                loss = model(data)['loss']
                loss.sync()

            # Wait for all outstanding work before stopping the clock.
            jt.sync_all(True)
            time1 = time.time()
        finally:
            # Restore so fp16 mode does not leak into later code in this process.
            jt.flags.amp_reg = prev_amp_reg

    avg_time = (time1 - time0) / timed_iters
    print("avg time:", avg_time)
    return avg_time
|
|
|
|
|
|
class TestSuperglue(unittest.TestCase):
    """Compare SuperGlue benchmark timings with and without fp16."""

    def test(self):
        """The fp16 run must take less than 55% of the fp32 per-pass time.

        ``main`` reads the ``use_fp16`` environment variable. This test sets
        it explicitly for each run and restores its previous value afterwards,
        even if a run raises.
        """
        if not jt.has_cuda:
            self.skipTest("SuperGlue benchmark requires CUDA")

        prev_use_fp16 = os.environ.get("use_fp16")
        try:
            # Force the baseline to fp32 regardless of the caller's environment.
            os.environ["use_fp16"] = "0"
            t1 = main()
            os.environ["use_fp16"] = "1"
            t2 = main()
        finally:
            if prev_use_fp16 is None:
                os.environ.pop("use_fp16", None)
            else:
                os.environ["use_fp16"] = prev_use_fp16

        self.assertLess(
            t2, t1 * 0.55,
            f"fp16 time {t2:.4f}s is not below 55% of fp32 time {t1:.4f}s",
        )
|
|
|
|
if __name__ == "__main__":
    # Run this file's test cases when it is executed directly as a script.
    unittest.main(module="__main__")