From 6c980c21461dd8a916c5b69a6d12695c5a7dc406 Mon Sep 17 00:00:00 2001
From: Dun Liang
Date: Mon, 25 Oct 2021 16:12:27 +0800
Subject: [PATCH] add performance test tutorial

---
 .../Jittor性能测试与对比方法.md | 176 ++++++++++++++++++
 doc/source/index.rst            |   1 +
 python/jittor/__init__.py       |  41 +++-
 3 files changed, 216 insertions(+), 2 deletions(-)
 create mode 100644 doc/source/Jittor性能测试与对比方法.md

diff --git a/doc/source/Jittor性能测试与对比方法.md b/doc/source/Jittor性能测试与对比方法.md
new file mode 100644
index 00000000..aa05d000
--- /dev/null
+++ b/doc/source/Jittor性能测试与对比方法.md
@@ -0,0 +1,176 @@
How to Benchmark Jittor and Compare It with Other Frameworks
============================================================

The following code uses ResNet-50 as an example to demonstrate the correct way to benchmark Jittor:

```python
import time
import jittor as jt
from jittor.models import resnet50
jt.flags.use_cuda = jt.has_cuda

warmup = 10
rerun = 100
batch_size = 8
data = jt.random((batch_size, 3, 224, 224))
model = resnet50()
model.eval()

# Warm up Jittor so that the timing below is accurate
jt.sync_all(True)
for i in range(warmup):
    pred = model(data)
    # sync sends the computation graph to the compute device
    pred.sync()
# sync_all(True) sends the computation graph to the compute device and waits
# for it to finish. The computation only truly runs once jt.sync_all(True) is
# called, so it must appear both before and after the forward passes for the
# measured time to be valid.
jt.sync_all(True)

# Measure the run time
start = time.time()
for i in range(rerun):
    pred = model(data)
    pred.sync()
jt.sync_all(True)
end = time.time()

print("Jittor FPS:", (rerun*batch_size)/(end-start))
```

This code defines three parameters: `batch_size`, the batch size; `warmup`, the number of warm-up iterations; and `rerun`, the number of timed iterations. The script finally prints the throughput in FPS. The two keys to benchmarking Jittor correctly are the warm-up phase and the synchronization calls: warming up keeps the measured time stable by excluding compilation time, while synchronizing guarantees the computation has actually finished. Jittor is an asynchronous framework, so only a synchronization operation ensures that all queued work is complete.
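This warm-up/sync pattern can be factored into a small reusable helper, shown in the sketch below. The sketch is illustrative rather than part of the original tutorial: the name `benchmark` is hypothetical, and it assumes the first dimension of `data` is the batch size.

```python
import time
import jittor as jt

def benchmark(model, data, warmup=10, rerun=100):
    # Make sure all previously queued work is done before warming up
    jt.sync_all(True)
    # Warm-up: triggers kernel compilation and caching
    for _ in range(warmup):
        model(data).sync()
    jt.sync_all(True)
    # Timed runs: sync after each step, and once globally at the end,
    # because Jittor executes asynchronously
    start = time.time()
    for _ in range(rerun):
        model(data).sync()
    jt.sync_all(True)
    # Throughput in images per second; assumes data.shape[0] is the batch size
    return rerun * data.shape[0] / (time.time() - start)
```

Calling `benchmark(model, data)` should reproduce the FPS figure printed by the script above.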
Running the benchmark script on an RTX Titan with batch size 8 gives:

```
Compiling Operators(8/8) used: 7.35s eta: 0s
Compiling Operators(13/13) used: 8.36s eta: 0s
Jittor FPS: 908.9853866375396
```

Similar code can be used to benchmark PyTorch:

```python
import time
import torch
from torchvision.models import resnet50

warmup = 10
rerun = 100
batch_size = 8
data = torch.randn((batch_size, 3, 224, 224)).cuda()
model = resnet50()
model.cuda()
model.eval()

# Warm up PyTorch so that the timing below is accurate
torch.cuda.synchronize()
for i in range(warmup):
    pred = model(data)
# synchronize makes sure the PyTorch computation has finished
torch.cuda.synchronize()

# Measure the run time
start = time.time()
for i in range(rerun):
    pred = model(data)
torch.cuda.synchronize()
end = time.time()

print("PyTorch FPS:", (rerun*batch_size)/(end-start))
```

On the same RTX Titan with batch size 8, this prints:

```
PyTorch FPS: 807.4806873965665
```

We can also merge the two scripts and check that their outputs agree:

```python
import time
import jittor as jt
from jittor.models import resnet50
jt.flags.use_cuda = jt.has_cuda

warmup = 100
rerun = 1000
batch_size = 8
data = jt.random((batch_size, 3, 224, 224))
model = resnet50()
model.eval()

# Warm up Jittor so that the timing below is accurate
jt.sync_all(True)
for i in range(warmup):
    pred = model(data)
    # sync sends the computation graph to the compute device
    pred.sync()
# sync_all(True) sends the computation graph to the compute device and waits
# for it to finish; call it both before and after the forward passes.
jt.sync_all(True)

# Measure the run time
start = time.time()
for i in range(rerun):
    pred = model(data)
    pred.sync()
jt.sync_all(True)
end = time.time()

print("Jittor FPS:", (rerun*batch_size)/(end-start))
# Export the Jittor output and parameters to numpy and torch formats
jittor_data = pred.numpy()
jittor_param = model.state_dict(to="torch")

import numpy as np
import torch
from torchvision.models import resnet50
data = torch.Tensor(data.numpy()).cuda()
model = resnet50()
# Load the Jittor parameters into the PyTorch model
model.load_state_dict(jittor_param)
model.cuda()
model.eval()

# Warm up PyTorch so that the timing below is accurate
torch.cuda.synchronize()
for i in range(warmup):
    pred = model(data)
# synchronize makes sure the PyTorch computation has finished
torch.cuda.synchronize()

# Measure the run time
start = time.time()
for i in range(rerun):
    pred = model(data)
torch.cuda.synchronize()
end = time.time()

print("PyTorch FPS:", (rerun*batch_size)/(end-start))
pytorch_data = pred.detach().cpu().numpy()
err = np.mean(np.abs(pytorch_data - jittor_data))
print("mean error:", err)
```

The merged script prints:

```
Jittor FPS: 908.9853866375396
PyTorch FPS: 807.4806873965665
mean error: 1e-5
```

The mean error is 1e-5, which is within the acceptable range. The key points for accurate benchmarking and comparison are:

1. Warm up thoroughly, to exclude the framework's setup and compilation time.
2. Run many iterations, so that the measured time is stable.
3. Add synchronization statements, so that the measured time is accurate.
4. Make sure GPU memory is sufficient. When it runs short, Jittor falls back to unified memory to compensate, which costs performance; keep a close eye on the output of `nvidia-smi`.
5. Make sure the compared models are identical, and check that their outputs agree.

If you have any questions about the benchmark results, or have optimization needs, feel free to contact the Jittor development team.
diff --git a/doc/source/index.rst b/doc/source/index.rst
index eec82932..aec4aad8 100644
--- a/doc/source/index.rst
+++ b/doc/source/index.rst
@@ -48,6 +48,7 @@
    :caption: 其他:

    Jittor调试技巧
+   Jittor性能测试与对比方法
    教程

 Indices and tables
diff --git a/python/jittor/__init__.py b/python/jittor/__init__.py
index 1fd56546..141d0df8 100644
--- a/python/jittor/__init__.py
+++ b/python/jittor/__init__.py
@@ -9,7 +9,7 @@
 # file 'LICENSE.txt', which is part of this source code package.
 # ***************************************************************

-__version__ = '1.3.1.10'
+__version__ = '1.3.1.11'
 from jittor_utils import lock
 with lock.lock_scope():
     ori_int = int
@@ -834,7 +834,35 @@ class Module:
         self.dfs([], None, callback, callback_leave)
         return _uniq(ps)

-    def state_dict(self):
+    def state_dict(self, to=None):
+        ''' Returns a dictionary containing the Jittor Vars
+        of the module and its descendants.
+
+        Args:
+            to: target type of var, can be None or 'numpy' or 'torch'
+
+        Return:
+            dictionary of the module's states.
+
+        Example::
+
+            import jittor as jt
+            from jittor.models import resnet50
+            jittor_model = resnet50()
+            dict = jittor_model.state_dict()
+            jittor_model.load_state_dict(dict)
+
+        Example 2 (export Jittor params to PyTorch)::
+
+            import jittor as jt
+            from jittor.models import resnet50
+            jittor_model = resnet50()
+            import torch
+            from torchvision.models import resnet50
+            torch_model = resnet50()
+            torch_model.load_state_dict(jittor_model.state_dict(to="torch"))
+
+        '''
         uniq_set = set()
         ps = {}
         stack = []
@@ -855,6 +883,15 @@ class Module:
         def callback_leave(parents, k, v, n):
             stack.pop()
         self.dfs([], None, callback, callback_leave)
+        if to == "numpy":
+            for k,v in ps.items():
+                if isinstance(v, Var):
+                    ps[k] = v.numpy()
+        elif to == "torch":
+            import torch
+            for k,v in ps.items():
+                if isinstance(v, Var):
+                    ps[k] = torch.Tensor(v.numpy())
         return ps

     def named_parameters(self):
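As a usage note on the `to` parameter added by this patch, here is a minimal sketch combining both conversion targets. It follows the docstring examples above and assumes `jittor`, `torch`, and `torchvision` are all installed; the variable names are illustrative.

```python
import jittor as jt
from jittor.models import resnet50 as jt_resnet50
import torch
from torchvision.models import resnet50 as torch_resnet50

jittor_model = jt_resnet50()

# to="numpy": every parameter becomes a numpy.ndarray,
# convenient for framework-agnostic serialization
numpy_params = jittor_model.state_dict(to="numpy")

# to="torch": every parameter becomes a torch.Tensor, so the dict
# can be loaded directly into an equivalent PyTorch model
torch_model = torch_resnet50()
torch_model.load_state_dict(jittor_model.state_dict(to="torch"))
```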