
Jittor Performance Testing and Comparison
=========================================

The following code uses ResNet50 as an example to show the correct way to benchmark Jittor:

```python
import time
import jittor as jt
from jittor.models import resnet50

jt.flags.use_cuda = jt.has_cuda

warmup = 10      # number of warm-up iterations
rerun = 100      # number of timed iterations
batch_size = 8
data = jt.random((batch_size, 3, 224, 224))
model = resnet50()
model.eval()

# Warm up Jittor so that the timing below is accurate
# (operator compilation happens during these first iterations)
jt.sync_all(True)
for i in range(warmup):
    pred = model(data)
    # sync() dispatches the computation graph to the device
    pred.sync()
# sync_all(True) dispatches the computation graph to the device and waits for it to finish.
# The computation is only guaranteed to have actually run after jt.sync_all(True),
# so the timing is only valid with it: call it both before and after the forward passes.
jt.sync_all(True)

# Start timing
start = time.time()
for i in range(rerun):
    pred = model(data)
    pred.sync()
jt.sync_all(True)
end = time.time()

print("Jittor FPS:", (rerun*batch_size)/(end-start))
```

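This code defines three parameters:

- `batch_size` is the batch size.
- `warmup` is the number of warm-up iterations.
- `rerun` is the number of timed iterations.

The script then prints the throughput in FPS (images per second). The keys to benchmarking Jittor correctly are the warm-up and the synchronization:

- The warm-up keeps the measured time stable and excludes operator compilation time.
- The synchronization guarantees that the computation has actually finished. Jittor is an asynchronous framework, so only a synchronizing call ensures the work is done before the timer stops.
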
Running the code above on a Titan RTX with batch size 8 gives:

```
Compiling Operators(8/8) used: 7.35s eta: 0s
Compiling Operators(13/13) used: 8.36s eta: 0s
Jittor FPS: 908.9853866375396
```

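The effect of the synchronization calls can be illustrated without a GPU. The sketch below is only an analogy, not Jittor code. A single-worker `ThreadPoolExecutor` stands in for the device queue, and the hypothetical `fake_kernel` function stands in for a kernel. Stopping the timer right after submitting the work, as a benchmark without synchronization would, measures almost nothing. Waiting for every result, the role `jt.sync_all(True)` plays in the script above, measures the real cost.

```python
import time
from concurrent.futures import ThreadPoolExecutor


def fake_kernel(duration_s):
    # Hypothetical stand-in for a device kernel: the worker thread is busy for duration_s.
    time.sleep(duration_s)


def time_without_sync(executor, n, duration_s):
    # Times only the submission of n tasks, analogous to timing without sync.
    start = time.perf_counter()
    futures = [executor.submit(fake_kernel, duration_s) for _ in range(n)]
    elapsed = time.perf_counter() - start
    for f in futures:  # drain the queue so later measurements start clean
        f.result()
    return elapsed


def time_with_sync(executor, n, duration_s):
    # Times submission plus waiting for every task, analogous to jt.sync_all(True).
    start = time.perf_counter()
    futures = [executor.submit(fake_kernel, duration_s) for _ in range(n)]
    for f in futures:
        f.result()
    return time.perf_counter() - start


if __name__ == "__main__":
    # One worker thread behaves like a single in-order device queue.
    with ThreadPoolExecutor(max_workers=1) as ex:
        unsynced = time_without_sync(ex, 20, 0.01)
        synced = time_with_sync(ex, 20, 0.01)
    print(f"without sync: {unsynced * 1000:.2f} ms")
    print(f"with sync:    {synced * 1000:.2f} ms")
```

The first number is typically far below the second. The second is at least 20 × 10 ms = 200 ms, because every task sleeps for 10 ms on a single worker.
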
PyTorch can be benchmarked with code similar to the Jittor script:

```python
import time
import torch
from torchvision.models import resnet50

warmup = 10      # number of warm-up iterations
rerun = 100      # number of timed iterations
batch_size = 8
data = torch.randn((batch_size, 3, 224, 224)).cuda()
model = resnet50()
model.cuda()
model.eval()

# Warm up PyTorch so that the timing below is accurate
torch.cuda.synchronize()
for i in range(warmup):
    pred = model(data)
# synchronize() waits until all queued CUDA work has finished
torch.cuda.synchronize()

# Start timing
start = time.time()
for i in range(rerun):
    pred = model(data)
torch.cuda.synchronize()
end = time.time()

print("PyTorch FPS:", (rerun*batch_size)/(end-start))
```

Running it on a Titan RTX with batch size 8 gives:

```
PyTorch FPS: 807.4806873965665
```

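To compare the two throughput figures directly, it can help to turn them into a ratio and a per-batch latency. The short calculation below reuses only the FPS values and batch size reported above. The helper names `speedup` and `per_batch_latency_ms` are illustrative.

```python
def per_batch_latency_ms(fps, batch_size):
    # FPS counts images per second, so one batch of batch_size images
    # takes batch_size / fps seconds (converted here to milliseconds).
    return batch_size / fps * 1000.0


def speedup(fps_a, fps_b):
    # Ratio of throughputs; > 1 means A processes more images per second than B.
    return fps_a / fps_b


jittor_fps = 908.9853866375396   # Jittor result reported above
pytorch_fps = 807.4806873965665  # PyTorch result reported above
batch_size = 8

print(f"speedup: {speedup(jittor_fps, pytorch_fps):.3f}x")
print(f"Jittor latency per batch:  {per_batch_latency_ms(jittor_fps, batch_size):.2f} ms")
print(f"PyTorch latency per batch: {per_batch_latency_ms(pytorch_fps, batch_size):.2f} ms")
```

With these numbers the ratio is about 1.126. The per-batch latencies are roughly 8.80 ms for Jittor and 9.91 ms for PyTorch.
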
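The Jittor and PyTorch scripts can also be merged, so that both frameworks run the same weights on the same input and their outputs can be checked for consistency:
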
```python
import time
import jittor as jt
from jittor.models import resnet50

jt.flags.use_cuda = jt.has_cuda

warmup = 100
rerun = 1000
batch_size = 8
data = jt.random((batch_size, 3, 224, 224))
model = resnet50()
model.eval()

# Warm up Jittor so that the timing below is accurate
jt.sync_all(True)
for i in range(warmup):
    pred = model(data)
    # sync() dispatches the computation graph to the device
    pred.sync()
# sync_all(True) dispatches the graph and waits for the device to finish;
# call it before and after the timed forward passes so the timing is valid
jt.sync_all(True)

# Start timing
start = time.time()
for i in range(rerun):
    pred = model(data)
    pred.sync()
jt.sync_all(True)
end = time.time()

print("Jittor FPS:", (rerun*batch_size)/(end-start))
# Export the Jittor output to numpy and the parameters to torch format
jittor_data = pred.numpy()
jittor_param = model.state_dict(to="torch")

import numpy as np
import torch
from torchvision.models import resnet50  # rebinds resnet50 to the torchvision model

# Feed PyTorch exactly the same input as Jittor
data = torch.Tensor(data.numpy()).cuda()
model = resnet50()
# Load the Jittor parameters
model.load_state_dict(jittor_param)
model.cuda()
model.eval()

# Warm up PyTorch so that the timing below is accurate
torch.cuda.synchronize()
for i in range(warmup):
    pred = model(data)
# synchronize() waits until all queued CUDA work has finished
torch.cuda.synchronize()

# Start timing
start = time.time()
for i in range(rerun):
    pred = model(data)
torch.cuda.synchronize()
end = time.time()

print("PyTorch FPS:", (rerun*batch_size)/(end-start))
# Compare the two frameworks' outputs
pytorch_data = pred.detach().cpu().numpy()
err = np.mean(np.abs(pytorch_data - jittor_data))
print("mean error:", err)
```

Running the combined script gives:

```
Jittor FPS: 908.9853866375396
PyTorch FPS: 807.4806873965665
mean error: 1e-5
```

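The combined script reports only the mean absolute error, which can hide a few large outliers. A slightly stricter check also looks at the maximum absolute error and applies `np.allclose`. The helper `compare_outputs` and its default tolerances below are illustrative assumptions, not values recommended by Jittor.

```python
import numpy as np


def compare_outputs(a, b, rtol=1e-4, atol=1e-4):
    # Hypothetical helper: mean/max absolute difference plus an allclose verdict.
    # rtol/atol defaults are illustrative assumptions; tune them to your model.
    a = np.asarray(a, dtype=np.float64)
    b = np.asarray(b, dtype=np.float64)
    if a.shape != b.shape:
        raise ValueError(f"shape mismatch: {a.shape} vs {b.shape}")
    diff = np.abs(a - b)
    return {
        "mean_abs_err": float(diff.mean()),
        "max_abs_err": float(diff.max()),
        "allclose": bool(np.allclose(a, b, rtol=rtol, atol=atol)),
    }


if __name__ == "__main__":
    # Synthetic stand-ins for two framework outputs that differ by ~1e-6.
    rng = np.random.default_rng(0)
    ref = rng.standard_normal((8, 1000)).astype(np.float32)
    print(compare_outputs(ref, ref + np.float32(1e-6)))
```

With the variables from the combined script, the call would be `compare_outputs(jittor_data, pytorch_data)`.
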
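The mean error is 1e-5, which is within an acceptable range. The key points for correct benchmarking and comparison are:

1. Warm up sufficiently, so that the framework's one-time preparation (such as operator compilation) is excluded.
2. Run many iterations, so that the measured time is stable.
3. Add synchronization calls, so that the measured time is accurate.
4. Make sure there is enough GPU memory. When GPU memory runs out, Jittor falls back to unified memory, which costs performance, so keep a close eye on the output of `nvidia-smi`.
5. Make sure the compared models are identical, and check that their outputs agree.
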
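Points 1 to 3 can be folded into one framework-agnostic timing helper. The sketch below is a hypothetical helper, not part of the Jittor or PyTorch API. `step` runs one forward pass and `sync` blocks until the device is idle, so the same function could wrap either script above.

```python
import time


def benchmark(step, sync, warmup=10, rerun=100, batch_size=8):
    """Hypothetical helper: images per second for `step` using warmup/rerun/sync.

    step: callable running one forward pass.
    sync: callable that blocks until all queued device work has finished.
    """
    sync()                      # start timing from an idle device
    for _ in range(warmup):     # point 1: exclude compilation / first-run cost
        step()
    sync()

    start = time.perf_counter()
    for _ in range(rerun):      # point 2: average over many iterations
        step()
    sync()                      # point 3: wait for the last iteration to finish
    elapsed = time.perf_counter() - start
    return rerun * batch_size / elapsed


if __name__ == "__main__":
    # Dummy workload: each "forward pass" sleeps ~2 ms and there is no device to sync.
    fps = benchmark(lambda: time.sleep(0.002), lambda: None, warmup=3, rerun=50, batch_size=8)
    print(f"dummy FPS: {fps:.1f}")
```

How the helper could wrap each framework:

- **Jittor:** `step` could be `lambda: model(data).sync()`, with `lambda: jt.sync_all(True)` as `sync`.
- **PyTorch:** `step` could be `lambda: model(data)`, with `torch.cuda.synchronize` as `sync`.

The sketch uses `time.perf_counter` rather than `time.time` because it is a monotonic, high-resolution clock intended for measuring short intervals.
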
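If you have questions about the test results, or need help with optimization, feel free to contact the Jittor development team at any time.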