mirror of https://github.com/Jittor/Jittor
add performance test tutorial
This commit is contained in:
parent
856f201d9f
commit
6c980c2146
|
@ -0,0 +1,176 @@
|
||||||
|
Jittor性能测试与对比方法
|
||||||
|
=====================
|
||||||
|
|
||||||
|
下面代码以AlexNet为例,用于演示 Jittor 性能测试的正确方法:
|
||||||
|
|
||||||
|
```python
|
||||||
|
import time
|
||||||
|
import jittor as jt
|
||||||
|
from jittor.models import resnet50
|
||||||
|
jt.flags.use_cuda = jt.has_cuda
|
||||||
|
|
||||||
|
warmup = 10
|
||||||
|
rerun = 100
|
||||||
|
batch_size = 8
|
||||||
|
data = jt.random((batch_size, 3, 224, 224))
|
||||||
|
model = resnet50()
|
||||||
|
model.eval()
|
||||||
|
|
||||||
|
# 此段代码对jittor进行热身,确保时间测试准确
|
||||||
|
jt.sync_all(True)
|
||||||
|
for i in range(warmup):
|
||||||
|
pred = model(data)
|
||||||
|
# sync是把计算图发送到计算设备上
|
||||||
|
pred.sync()
|
||||||
|
# sync_all(true)是把计算图发射到计算设备上,并且同步。
|
||||||
|
# 只有运行了jt.sync_all(True)才会真正地运行,时间才是有效的,因此执行forward前后都要执行这句话
|
||||||
|
jt.sync_all(True)
|
||||||
|
|
||||||
|
# 开始测试运行时间
|
||||||
|
start = time.time()
|
||||||
|
for i in range(rerun):
|
||||||
|
pred = model(data)
|
||||||
|
pred.sync()
|
||||||
|
jt.sync_all(True)
|
||||||
|
end = time.time()
|
||||||
|
|
||||||
|
print("Jittor FPS:", (rerun*batch_size)/(end-start))
|
||||||
|
|
||||||
|
```
|
||||||
|
|
||||||
|
在这段代码中,我们定义了几个参数`batch_size`, `warmup`, `rerun`, batch_size代表批大小,warmup是用于热身的循环次数,而rerun是用于测速的循环次数,最终输出FPS,对Jittor进行正确测速的关键是 热身部分和同步部分,热身部分确保测试时间稳定,没有包含编译用的时间,而同步部分确保计算完成,因为jittor是一个异步框架,只有同步操作能保证计算完成。
|
||||||
|
|
||||||
|
以上代码的运行结果如下(RTX Titan,batch 8):
|
||||||
|
|
||||||
|
```
|
||||||
|
Compiling Operators(8/8) used: 7.35s eta: 0s
|
||||||
|
Compiling Operators(13/13) used: 8.36s eta: 0s
|
||||||
|
Jittor FPS: 908.9853866375396
|
||||||
|
```
|
||||||
|
|
||||||
|
我们还可以使用类似的代码测试 PyTorch的性能:
|
||||||
|
|
||||||
|
```python
|
||||||
|
import time
|
||||||
|
import torch
|
||||||
|
from torchvision.models import resnet50
|
||||||
|
|
||||||
|
warmup = 10
|
||||||
|
rerun = 100
|
||||||
|
batch_size = 8
|
||||||
|
data = torch.randn((batch_size, 3, 224, 224)).cuda()
|
||||||
|
model = resnet50()
|
||||||
|
model.cuda()
|
||||||
|
model.eval()
|
||||||
|
|
||||||
|
# 此段代码对pytorch进行热身,确保时间测试准确
|
||||||
|
torch.cuda.synchronize()
|
||||||
|
for i in range(warmup):
|
||||||
|
pred = model(data)
|
||||||
|
# synchronize用于确保PyTorch计算完成
|
||||||
|
torch.cuda.synchronize()
|
||||||
|
|
||||||
|
# 开始测试运行时间
|
||||||
|
start = time.time()
|
||||||
|
for i in range(rerun):
|
||||||
|
pred = model(data)
|
||||||
|
torch.cuda.synchronize()
|
||||||
|
end = time.time()
|
||||||
|
|
||||||
|
print("PyTorch FPS:", (rerun*batch_size)/(end-start))
|
||||||
|
```
|
||||||
|
|
||||||
|
|
||||||
|
以上代码的运行结果如下(RTX Titan,batch 8):
|
||||||
|
|
||||||
|
```
|
||||||
|
PyTorch FPS: 807.4806873965665
|
||||||
|
```
|
||||||
|
|
||||||
|
我们还可以对这两段代码合并,并对比结果的一致性:
|
||||||
|
|
||||||
|
```python
|
||||||
|
import time
|
||||||
|
import jittor as jt
|
||||||
|
from jittor.models import resnet50
|
||||||
|
jt.flags.use_cuda = jt.has_cuda
|
||||||
|
|
||||||
|
warmup = 100
|
||||||
|
rerun = 1000
|
||||||
|
batch_size = 8
|
||||||
|
data = jt.random((batch_size, 3, 224, 224))
|
||||||
|
model = resnet50()
|
||||||
|
model.eval()
|
||||||
|
|
||||||
|
# 此段代码对jittor进行热身,确保时间测试准确
|
||||||
|
jt.sync_all(True)
|
||||||
|
for i in range(warmup):
|
||||||
|
pred = model(data)
|
||||||
|
# sync是把计算图发送到计算设备上
|
||||||
|
pred.sync()
|
||||||
|
# sync_all(true)是把计算图发射到计算设备上,并且同步。
|
||||||
|
# 只有运行了jt.sync_all(True)才会真正地运行,时间才是有效的,因此执行forward前后都要执行这句话
|
||||||
|
jt.sync_all(True)
|
||||||
|
|
||||||
|
# 开始测试运行时间
|
||||||
|
start = time.time()
|
||||||
|
for i in range(rerun):
|
||||||
|
pred = model(data)
|
||||||
|
pred.sync()
|
||||||
|
jt.sync_all(True)
|
||||||
|
end = time.time()
|
||||||
|
|
||||||
|
print("Jittor FPS:", (rerun*batch_size)/(end-start))
|
||||||
|
# 将 jittor 数据和参数导出为 numpy 和 torch 格式
|
||||||
|
jittor_data = pred.numpy()
|
||||||
|
jittor_param = model.state_dict(to="torch")
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
import torch
|
||||||
|
from torchvision.models import resnet50
|
||||||
|
data = torch.Tensor(data.numpy()).cuda()
|
||||||
|
model = resnet50()
|
||||||
|
# 加载 jittor 参数
|
||||||
|
model.load_state_dict(jittor_param)
|
||||||
|
model.cuda()
|
||||||
|
model.eval()
|
||||||
|
|
||||||
|
# 此段代码对pytorch进行热身,确保时间测试准确
|
||||||
|
torch.cuda.synchronize()
|
||||||
|
for i in range(warmup):
|
||||||
|
pred = model(data)
|
||||||
|
# synchronize用于确保PyTorch计算完成
|
||||||
|
torch.cuda.synchronize()
|
||||||
|
|
||||||
|
# 开始测试运行时间
|
||||||
|
start = time.time()
|
||||||
|
for i in range(rerun):
|
||||||
|
pred = model(data)
|
||||||
|
torch.cuda.synchronize()
|
||||||
|
end = time.time()
|
||||||
|
|
||||||
|
print("PyTorch FPS:", (rerun*batch_size)/(end-start))
|
||||||
|
pytorch_data = pred.detach().cpu().numpy()
|
||||||
|
err = np.mean(np.abs(pytorch_data - jittor_data))
|
||||||
|
print("mean error:", err)
|
||||||
|
|
||||||
|
```
|
||||||
|
|
||||||
|
|
||||||
|
以上代码运行结果如下:
|
||||||
|
|
||||||
|
```
|
||||||
|
Jittor FPS: 908.9853866375396
|
||||||
|
PyTorch FPS: 807.4806873965665
|
||||||
|
mean error: 1e-5
|
||||||
|
```
|
||||||
|
|
||||||
|
误差输出为1e-5, 在可接受范围内。正确测速与对比的几大关键点为:
|
||||||
|
|
||||||
|
1. 充分热身,除去框架的准备时间。
|
||||||
|
2. 多次运行,确保测试时间稳定。
|
||||||
|
3. 加上同步语句,确保测试时间准确。
|
||||||
|
4. 保证显存充足,在显存不足时,jittor会调用统一内存来弥补,会产生性能损失,请密切关注`nvidia-smi`的输出结果。
|
||||||
|
5. 保证对比模型的一致性,检查输出结果的一致。
|
||||||
|
|
||||||
|
如果您对测试结果有疑问,或者有优化需求,欢迎随时联系Jittor开发团队。
|
|
@ -48,6 +48,7 @@
|
||||||
:caption: 其他:
|
:caption: 其他:
|
||||||
|
|
||||||
Jittor调试技巧
|
Jittor调试技巧
|
||||||
|
Jittor性能测试与对比方法
|
||||||
教程 <https://cg.cs.tsinghua.edu.cn/jittor/tutorial/>
|
教程 <https://cg.cs.tsinghua.edu.cn/jittor/tutorial/>
|
||||||
|
|
||||||
Indices and tables
|
Indices and tables
|
||||||
|
|
|
@ -9,7 +9,7 @@
|
||||||
# file 'LICENSE.txt', which is part of this source code package.
|
# file 'LICENSE.txt', which is part of this source code package.
|
||||||
# ***************************************************************
|
# ***************************************************************
|
||||||
|
|
||||||
__version__ = '1.3.1.10'
|
__version__ = '1.3.1.11'
|
||||||
from jittor_utils import lock
|
from jittor_utils import lock
|
||||||
with lock.lock_scope():
|
with lock.lock_scope():
|
||||||
ori_int = int
|
ori_int = int
|
||||||
|
@ -834,7 +834,35 @@ class Module:
|
||||||
self.dfs([], None, callback, callback_leave)
|
self.dfs([], None, callback, callback_leave)
|
||||||
return _uniq(ps)
|
return _uniq(ps)
|
||||||
|
|
||||||
def state_dict(self):
|
def state_dict(self, to=None):
|
||||||
|
''' Returns a dictionary containing
|
||||||
|
Jittor Var of the module and its descendants.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
to: target type of var, canbe None or 'numpy' or 'torch'
|
||||||
|
|
||||||
|
Return:
|
||||||
|
dictionary of module's states.
|
||||||
|
|
||||||
|
Example::
|
||||||
|
|
||||||
|
import jittor as jt
|
||||||
|
from jittor.models import resnet50
|
||||||
|
jittor_model = resnet50()
|
||||||
|
dict = jittor_model.state_dict()
|
||||||
|
jittor_model.load_state_dict(dict)
|
||||||
|
|
||||||
|
Example2(export Jittor params to PyTorch)::
|
||||||
|
|
||||||
|
import jittor as jt
|
||||||
|
from jittor.models import resnet50
|
||||||
|
jittor_model = resnet50()
|
||||||
|
import torch
|
||||||
|
from torchvision.models import resnet50
|
||||||
|
torch_model = resnet50()
|
||||||
|
torch_model.load_state_dict(jittor_model.state_dict(to="torch"))
|
||||||
|
|
||||||
|
'''
|
||||||
uniq_set = set()
|
uniq_set = set()
|
||||||
ps = {}
|
ps = {}
|
||||||
stack = []
|
stack = []
|
||||||
|
@ -855,6 +883,15 @@ class Module:
|
||||||
def callback_leave(parents, k, v, n):
|
def callback_leave(parents, k, v, n):
|
||||||
stack.pop()
|
stack.pop()
|
||||||
self.dfs([], None, callback, callback_leave)
|
self.dfs([], None, callback, callback_leave)
|
||||||
|
if to == "numpy":
|
||||||
|
for k,v in ps.items():
|
||||||
|
if isinstance(v, Var):
|
||||||
|
ps[k] = v.numpy()
|
||||||
|
elif to == "torch":
|
||||||
|
import torch
|
||||||
|
for k,v in ps.items():
|
||||||
|
if isinstance(v, Var):
|
||||||
|
ps[k] = torch.Tensor(v.numpy())
|
||||||
return ps
|
return ps
|
||||||
|
|
||||||
def named_parameters(self):
|
def named_parameters(self):
|
||||||
|
|
Loading…
Reference in New Issue