Added test for backward pass

parent a3de9cb930
commit f545323c44
@@ -98,12 +98,17 @@ def random_affines_4x4(dim):
     return affines.reshape(*dim, 4, 4)
 
 
-def random_attention_inputs(batch_size, n_seq, n, no_heads, c_hidden, inf=1e9, dtype=torch.float32):
-    q = torch.rand(batch_size, n_seq, n, c_hidden, dtype=dtype).cuda()
-    kv = torch.rand(batch_size, n_seq, n, c_hidden, dtype=dtype).cuda()
+def random_attention_inputs(batch_size, n_seq, n, no_heads, c_hidden, inf=1e9,
+                            dtype=torch.float32, requires_grad=False):
+    q = torch.rand(batch_size, n_seq, n, c_hidden, dtype=dtype, requires_grad=requires_grad).cuda()
+    kv = torch.rand(batch_size, n_seq, n, c_hidden, dtype=dtype, requires_grad=requires_grad).cuda()
 
-    mask = torch.randint(0, 2, (batch_size, n_seq, 1, 1, n), dtype=dtype).cuda()
-    biases = [inf * (mask - 1), torch.rand(batch_size, 1, no_heads, n, n)]
-    biases = [b.to(dtype=dtype).cuda() for b in biases]
+    mask = torch.randint(0, 2, (batch_size, n_seq, 1, 1, n), dtype=dtype, requires_grad=requires_grad).cuda()
+    z_bias = torch.rand(batch_size, 1, no_heads, n, n, dtype=dtype, requires_grad=requires_grad).cuda()
+    mask_bias = inf * (mask - 1)
+    if requires_grad:
+        mask_bias = mask_bias.detach().clone().requires_grad_()
+
+    biases = [mask_bias, z_bias]
 
     return q, kv, mask, biases
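Note: the `detach().clone().requires_grad_()` step exists because `mask_bias` is derived from `mask`, which makes it a non-leaf tensor, and autograd only populates `.grad` on leaf tensors (or on tensors that explicitly call `retain_grad()`). Re-basing it as a fresh leaf lets the backward test read its gradient directly. A minimal standalone sketch of the distinction, with toy shapes that are not the test's real ones:

    import torch

    leaf = torch.rand(3, requires_grad=True)
    derived = 2.0 * (leaf - 1)                           # non-leaf: has grad_fn; .grad stays None
    rebased = derived.detach().clone().requires_grad_()  # fresh leaf, cut from the old graph

    rebased.sum().backward()
    assert rebased.grad is not None   # leaf: gradient is populated
    assert leaf.grad is None          # detach() severed the link back to leaf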
@@ -17,15 +17,16 @@ Unit tests to compare components of OpenFold run with the DeepSpeed memory-efficient
 attention kernel, DS4Sci_EvoformerAttention vs. a stock PyTorch attention implementation.
 """
 
-import torch
 import unittest
 import numpy as np
 import pickle
 import torch
+from torch.nn import functional as F
 
 from openfold.data import data_transforms
 from openfold.model.primitives import (
-    Attention
+    lecun_normal_init_,
+    Attention,
 )
 from openfold.utils.tensor_utils import tensor_tree_map
 
@@ -39,15 +40,15 @@ class TestDeepSpeedKernel(unittest.TestCase):
     def compare_attention_types(self, use_flash=False):
         """Compare attention with and without using DeepSpeed Evoformer kernel."""
         batch_size = consts.batch_size
-        n_seq = consts.n_seq
-        n = 2 ** 12
+        n_seq = 18
+        n_res = 20
         c_hidden = 32
         no_heads = 4
         eps = 2e-2
 
         q, kv, mask, biases = random_attention_inputs(batch_size=batch_size,
                                                       n_seq=n_seq,
-                                                      n=n,
+                                                      n=n_res,
                                                       no_heads=no_heads,
                                                       c_hidden=c_hidden)
 
@@ -61,7 +62,7 @@ class TestDeepSpeedKernel(unittest.TestCase):
 
         if use_flash:
             biases = [biases[0]]
-            flash_mask = mask.reshape(batch_size * n_seq, n)
+            flash_mask = mask.reshape(batch_size * n_seq, n_res)
            real_out = a(q, kv, use_flash=True, flash_mask=flash_mask).cpu()
         else:
             real_out = a(q, kv, biases=biases).cpu()
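Shape note: the 5-D mask `(batch, n_seq, 1, 1, n_res)` carries singleton dims only so the additive bias broadcasts against the attention logits; the flash path instead flattens it to one row of padding flags per (batch, sequence) pair, which is presumably the layout the flash kernel's mask interface expects. A shape-only sketch (the batch size of 2 is made up):

    import torch

    batch_size, n_seq, n_res = 2, 18, 20
    mask = torch.randint(0, 2, (batch_size, n_seq, 1, 1, n_res), dtype=torch.float32)

    flash_mask = mask.reshape(batch_size * n_seq, n_res)
    print(flash_mask.shape)  # torch.Size([36, 20])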
@@ -71,15 +72,79 @@ class TestDeepSpeedKernel(unittest.TestCase):
         err = torch.max(torch.abs(ds_out - real_out))
         self.assertTrue(err < eps, f'Error: {err}')
 
-    def test_ds_kernel_vs_attention(self):
+    def test_ds_kernel_vs_attention_forward(self):
         """Compare regular attention vs. DeepSpeed Evoformer kernel."""
         self.compare_attention_types(use_flash=False)
 
     @compare_utils.skip_unless_flash_attn_installed()
-    def test_ds_kernel_vs_flash_attention(self):
+    def test_ds_kernel_vs_flash_attn_forward(self):
         """Compare Flash Attention vs. DeepSpeed Evoformer kernel."""
         self.compare_attention_types(use_flash=True)
 
+    def test_ds_kernel_vs_attention_backward(self):
+        """Compare backward pass for regular attention vs. DeepSpeed Evoformer kernel."""
+        batch_size = consts.batch_size
+        n_seq = 18
+        n_res = 20
+        c_hidden = 32
+        no_heads = 4
+        eps = consts.eps
+
+        q, kv, mask, biases = random_attention_inputs(batch_size=batch_size,
+                                                      n_seq=n_seq,
+                                                      n=n_res,
+                                                      no_heads=no_heads,
+                                                      c_hidden=c_hidden,
+                                                      requires_grad=True)
+
+        attn = Attention(
+            c_hidden, c_hidden, c_hidden, c_hidden, no_heads
+        ).cuda()
+
+        with torch.no_grad():
+            lecun_normal_init_(attn.linear_g.weight)
+            lecun_normal_init_(attn.linear_o.weight)
+
+        def clone(t):
+            t = t.clone()
+            if t.requires_grad:
+                t.retain_grad()
+            return t
+
+        def init_attn():
+            a_clone = Attention(
+                c_hidden, c_hidden, c_hidden, c_hidden, no_heads
+            ).cuda()
+
+            a_clone.load_state_dict(attn.state_dict())
+            return a_clone
+
+        q_repro = clone(q)
+        kv_repro = clone(kv)
+        biases_repro = [clone(b) for b in biases]
+
+        a = init_attn()
+        out_repro = a(q_repro, kv_repro, biases=biases_repro, use_deepspeed_evo_attention=True)
+        loss_repro = torch.mean(out_repro)
+        loss_repro.backward()
+
+        q_gt = clone(q)
+        kv_gt = clone(kv)
+        biases_gt = [clone(b) for b in biases]
+
+        a = init_attn()
+        out_gt = a(q_gt, kv_gt, biases=biases_gt)
+
+        loss_gt = torch.mean(out_gt)
+        loss_gt.backward()
+
+        pairs = zip([q_repro, kv_repro, biases_repro[0], biases_repro[1]],
+                    [q_gt, kv_gt, biases_gt[0], biases_gt[1]])
+        for i, item in enumerate(pairs):
+            t_repro, t_gt = item
+            err = torch.max(torch.abs(t_repro.grad.cpu() - t_gt.grad.cpu()))
+            self.assertTrue(err < eps, f'Error item #{i}: {err}')
+
     def compare_evoformer(self, dtype):
         """
         Compare Evoformer output with and without using DeepSpeed Evoformer attention kernel.
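Note: the `clone` helper above is load-bearing. `t.clone()` on a tensor that requires grad is itself a differentiable op, so the clone is a non-leaf and would end up with `.grad is None` after `backward()` unless `retain_grad()` opts it back in. A minimal sketch of the same compare-gradients pattern, where the hypothetical `f_kernel`/`f_reference` stand in for the two attention paths:

    import torch

    def f_reference(x):
        return torch.softmax(x, dim=-1).mean()

    f_kernel = f_reference  # stand-in; imagine an optimized kernel implementation here

    x = torch.rand(4, 8, requires_grad=True)

    x_a = x.clone()
    x_a.retain_grad()         # clone of a requires-grad tensor is a non-leaf; opt in to .grad
    f_kernel(x_a).backward()

    x_b = x.clone()
    x_b.retain_grad()
    f_reference(x_b).backward()

    err = torch.max(torch.abs(x_a.grad - x_b.grad))
    assert err == 0, f'gradient mismatch: {err}'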
@@ -88,7 +153,9 @@ class TestDeepSpeedKernel(unittest.TestCase):
         """
         n_res = 20
         n_seq = 18
-        eps = 0.5
+        c_m_shape = (consts.c_m,)
+        c_z_shape = (consts.c_z,)
+        eps = 2e-2
 
         activations = {
             "msa": torch.rand(n_seq, n_res, consts.c_m, device='cuda', dtype=dtype),
@@ -113,8 +180,10 @@ class TestDeepSpeedKernel(unittest.TestCase):
             inplace_safe=False,
         )
 
-        out_repro_msa = out_repro_msa.cpu()
-        out_repro_pair = out_repro_pair.cpu()
+        # In practice, layer norms applied later in the network make any
+        # kernel rounding errors negligible
+        out_repro_msa = F.layer_norm(out_repro_msa, c_m_shape).cpu()
+        out_repro_pair = F.layer_norm(out_repro_pair, c_z_shape).cpu()
 
         out_repro_msa_ds, out_repro_pair_ds = model.evoformer.blocks[0](
             activations["msa"],
@@ -126,8 +195,8 @@ class TestDeepSpeedKernel(unittest.TestCase):
             _mask_trans=False,
             inplace_safe=False,
         )
-        out_repro_msa_ds = out_repro_msa_ds.cpu()
-        out_repro_pair_ds = out_repro_pair_ds.cpu()
+        out_repro_msa_ds = F.layer_norm(out_repro_msa_ds, c_m_shape).cpu()
+        out_repro_pair_ds = F.layer_norm(out_repro_pair_ds, c_z_shape).cpu()
 
         err = torch.mean(torch.abs(out_repro_msa - out_repro_msa_ds))
         self.assertTrue(err < eps, f'MSA Error: {err}')
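Note: normalizing before comparing is what lets eps tighten from 0.5 to 2e-2. LayerNorm subtracts the mean over the normalized dimension and divides by its standard deviation, so a roughly uniform drift introduced by kernel rounding cancels out exactly. A toy illustration of the effect (the shapes and the 1e-3 drift are made up):

    import torch
    from torch.nn import functional as F

    x = torch.rand(18, 20, 128)
    y = x + 1e-3   # simulate a small uniform drift between two kernels

    raw_err = torch.max(torch.abs(x - y))   # ~1e-3
    ln_err = torch.max(torch.abs(F.layer_norm(x, (128,)) - F.layer_norm(y, (128,))))
    print(raw_err, ln_err)  # ln_err is ~0: the uniform shift is normalized away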
@@ -188,7 +257,7 @@ class TestDeepSpeedKernel(unittest.TestCase):
     def test_compare_model(self):
         """
         Run full model with and without using DeepSpeed Evoformer attention kernel
-        and compare output coordinates
+        and compare output coordinates.
         """
         eps = 0.5
         with open("tests/test_data/sample_feats.pickle", "rb") as fp: