Added test for backward pass

Christina Floristean 2023-10-06 13:28:05 -04:00
parent a3de9cb930
commit f545323c44
2 changed files with 94 additions and 20 deletions


@@ -98,12 +98,17 @@ def random_affines_4x4(dim):
     return affines.reshape(*dim, 4, 4)
 
-def random_attention_inputs(batch_size, n_seq, n, no_heads, c_hidden, inf=1e9, dtype=torch.float32):
-    q = torch.rand(batch_size, n_seq, n, c_hidden, dtype=dtype).cuda()
-    kv = torch.rand(batch_size, n_seq, n, c_hidden, dtype=dtype).cuda()
+def random_attention_inputs(batch_size, n_seq, n, no_heads, c_hidden, inf=1e9,
+                            dtype=torch.float32, requires_grad=False):
+    q = torch.rand(batch_size, n_seq, n, c_hidden, dtype=dtype, requires_grad=requires_grad).cuda()
+    kv = torch.rand(batch_size, n_seq, n, c_hidden, dtype=dtype, requires_grad=requires_grad).cuda()
-    mask = torch.randint(0, 2, (batch_size, n_seq, 1, 1, n), dtype=dtype).cuda()
-    biases = [inf * (mask - 1), torch.rand(batch_size, 1, no_heads, n, n)]
-    biases = [b.to(dtype=dtype).cuda() for b in biases]
+    mask = torch.randint(0, 2, (batch_size, n_seq, 1, 1, n), dtype=dtype, requires_grad=requires_grad).cuda()
+    z_bias = torch.rand(batch_size, 1, no_heads, n, n, dtype=dtype, requires_grad=requires_grad).cuda()
+    mask_bias = inf * (mask - 1)
+    if requires_grad:
+        mask_bias = mask_bias.detach().clone().requires_grad_()
+    biases = [mask_bias, z_bias]
     return q, kv, mask, biases
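
A note on the mask_bias handling above: inf * (mask - 1) is the output of an
autograd op and therefore a non-leaf tensor, and backward() only populates
.grad on leaf tensors by default, so the test detaches it into a fresh leaf.
A minimal standalone sketch of the pattern (sizes are illustrative, not from
the diff):

    import torch

    mask = torch.randint(0, 2, (4,), dtype=torch.float32, requires_grad=True)
    bias = 1e9 * (mask - 1)                        # non-leaf: backward() leaves bias.grad unset
    leaf = bias.detach().clone().requires_grad_()  # fresh leaf that does receive .grad
    leaf.sum().backward()
    print(leaf.grad)                               # tensor([1., 1., 1., 1.])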


@@ -17,15 +17,16 @@ Unit tests to compare components of OpenFold run with the DeepSpeed memory-effic
 attention kernel, DS4Sci_EvoformerAttention vs. a stock PyTorch attention implementation.
 """
-import torch
 import unittest
 import numpy as np
 import pickle
+import torch
+from torch.nn import functional as F
 from openfold.data import data_transforms
 from openfold.model.primitives import (
     lecun_normal_init_,
-    Attention,
+    Attention
 )
 from openfold.utils.tensor_utils import tensor_tree_map
@@ -39,15 +40,15 @@ class TestDeepSpeedKernel(unittest.TestCase):
     def compare_attention_types(self, use_flash=False):
         """Compare attention with and without using DeepSpeed Evoformer kernel."""
         batch_size = consts.batch_size
-        n_seq = consts.n_seq
-        n = 2 ** 12
+        n_seq = 18
+        n_res = 20
         c_hidden = 32
         no_heads = 4
         eps = 2e-2
 
         q, kv, mask, biases = random_attention_inputs(batch_size=batch_size,
                                                       n_seq=n_seq,
-                                                      n=n,
+                                                      n=n_res,
                                                       no_heads=no_heads,
                                                       c_hidden=c_hidden)
@@ -61,7 +62,7 @@ class TestDeepSpeedKernel(unittest.TestCase):
         if use_flash:
             biases = [biases[0]]
-            flash_mask = mask.reshape(batch_size * n_seq, n)
+            flash_mask = mask.reshape(batch_size * n_seq, n_res)
             real_out = a(q, kv, use_flash=True, flash_mask=flash_mask).cpu()
         else:
             real_out = a(q, kv, biases=biases).cpu()
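
For context on the flash_mask reshape: the mask has shape
(batch_size, n_seq, 1, 1, n_res), so the two singleton dims collapse and each
row of the reshaped tensor holds the per-residue flags for one sequence. A toy
sketch (sizes are illustrative):

    import torch

    batch_size, n_seq, n_res = 2, 18, 20
    mask = torch.randint(0, 2, (batch_size, n_seq, 1, 1, n_res), dtype=torch.float32)
    flash_mask = mask.reshape(batch_size * n_seq, n_res)  # one row of flags per sequence
    print(flash_mask.shape)  # torch.Size([36, 20])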
@@ -71,15 +72,79 @@ class TestDeepSpeedKernel(unittest.TestCase):
         err = torch.max(torch.abs(ds_out - real_out))
         self.assertTrue(err < eps, f'Error: {err}')
 
-    def test_ds_kernel_vs_attention(self):
+    def test_ds_kernel_vs_attention_forward(self):
         """Compare regular attention vs. DeepSpeed Evoformer kernel."""
         self.compare_attention_types(use_flash=False)
 
     @compare_utils.skip_unless_flash_attn_installed()
-    def test_ds_kernel_vs_flash_attention(self):
+    def test_ds_kernel_vs_flash_attn_forward(self):
         """Compare Flash Attention vs. DeepSpeed Evoformer kernel."""
         self.compare_attention_types(use_flash=True)
 
+    def test_ds_kernel_vs_attention_backward(self):
+        """Compare backward pass for regular attention vs. DeepSpeed Evoformer kernel."""
+        batch_size = consts.batch_size
+        n_seq = 18
+        n_res = 20
+        c_hidden = 32
+        no_heads = 4
+        eps = consts.eps
+
+        q, kv, mask, biases = random_attention_inputs(batch_size=batch_size,
+                                                      n_seq=n_seq,
+                                                      n=n_res,
+                                                      no_heads=no_heads,
+                                                      c_hidden=c_hidden,
+                                                      requires_grad=True)
+
+        attn = Attention(
+            c_hidden, c_hidden, c_hidden, c_hidden, no_heads
+        ).cuda()
+        with torch.no_grad():
+            lecun_normal_init_(attn.linear_g.weight)
+            lecun_normal_init_(attn.linear_o.weight)
+
+        def clone(t):
+            t = t.clone()
+            if t.requires_grad:
+                t.retain_grad()
+            return t
+
+        def init_attn():
+            a_clone = Attention(
+                c_hidden, c_hidden, c_hidden, c_hidden, no_heads
+            ).cuda()
+            a_clone.load_state_dict(attn.state_dict())
+            return a_clone
+
+        q_repro = clone(q)
+        kv_repro = clone(kv)
+        biases_repro = [clone(b) for b in biases]
+        a = init_attn()
+        out_repro = a(q_repro, kv_repro, biases=biases_repro, use_deepspeed_evo_attention=True)
+        loss_repro = torch.mean(out_repro)
+        loss_repro.backward()
+
+        q_gt = clone(q)
+        kv_gt = clone(kv)
+        biases_gt = [clone(b) for b in biases]
+        a = init_attn()
+        out_gt = a(q_gt, kv_gt, biases=biases_gt)
+        loss_gt = torch.mean(out_gt)
+        loss_gt.backward()
+
+        pairs = zip([q_repro, kv_repro, biases_repro[0], biases_repro[1]],
+                    [q_gt, kv_gt, biases_gt[0], biases_gt[1]])
+        for i, item in enumerate(pairs):
+            t_repro, t_gt = item
+            err = torch.max(torch.abs(t_repro.grad.cpu() - t_gt.grad.cpu()))
+            self.assertTrue(err < eps, f'Error item #{i}: {err}')
+
     def compare_evoformer(self, dtype):
         """
         Compare Evoformer output with and without using DeepSpeed Evoformer attention kernel.
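
One subtlety in the backward test above: the inputs are created with
requires_grad=True and then moved with .cuda(), and clone() is itself a
differentiable op, so every tensor whose gradient the test compares is a
non-leaf; the retain_grad() call in the clone() helper is what makes .grad
observable on them after backward(). A minimal sketch of that mechanism:

    import torch

    x = torch.rand(3, requires_grad=True)
    y = x.clone()      # differentiable copy, but a non-leaf tensor
    y.retain_grad()    # keep .grad on this non-leaf after backward()
    y.sum().backward()
    print(y.grad)      # tensor([1., 1., 1.])
    print(x.grad)      # tensor([1., 1., 1.]) -- the gradient also reaches the original leaf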
@@ -88,7 +153,9 @@ class TestDeepSpeedKernel(unittest.TestCase):
         """
         n_res = 20
         n_seq = 18
-        eps = 0.5
+        c_m_shape = (consts.c_m,)
+        c_z_shape = (consts.c_z,)
+        eps = 2e-2
 
         activations = {
             "msa": torch.rand(n_seq, n_res, consts.c_m, device='cuda', dtype=dtype),
@@ -113,8 +180,10 @@ class TestDeepSpeedKernel(unittest.TestCase):
             inplace_safe=False,
         )
-        out_repro_msa = out_repro_msa.cpu()
-        out_repro_pair = out_repro_pair.cpu()
+        # In practice, layer norms applied later in the network make any
+        # kernel rounding errors negligible
+        out_repro_msa = F.layer_norm(out_repro_msa, c_m_shape).cpu()
+        out_repro_pair = F.layer_norm(out_repro_pair, c_z_shape).cpu()
 
         out_repro_msa_ds, out_repro_pair_ds = model.evoformer.blocks[0](
             activations["msa"],
@@ -126,8 +195,8 @@ class TestDeepSpeedKernel(unittest.TestCase):
             _mask_trans=False,
             inplace_safe=False,
         )
-        out_repro_msa_ds = out_repro_msa_ds.cpu()
-        out_repro_pair_ds = out_repro_pair_ds.cpu()
+        out_repro_msa_ds = F.layer_norm(out_repro_msa_ds, c_m_shape).cpu()
+        out_repro_pair_ds = F.layer_norm(out_repro_pair_ds, c_z_shape).cpu()
 
         err = torch.mean(torch.abs(out_repro_msa - out_repro_msa_ds))
         self.assertTrue(err < eps, f'MSA Error: {err}')
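
One intuition for the layer norm applied before comparison: layer
normalization is exactly invariant to a per-row scale and shift, so any
component of the kernel's rounding error that acts like a uniform drift is
normalized away. A standalone illustration (shapes are illustrative):

    import torch
    from torch.nn import functional as F

    x = torch.rand(4, 8)
    y = 1.001 * x + 0.01  # same signal with a small scale/shift perturbation
    print(torch.max(torch.abs(x - y)))  # raw difference on the order of 1e-2
    diff = F.layer_norm(x, (8,)) - F.layer_norm(y, (8,))
    print(torch.max(torch.abs(diff)))   # orders of magnitude smaller after normalization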
@@ -188,7 +257,7 @@ class TestDeepSpeedKernel(unittest.TestCase):
     def test_compare_model(self):
         """
         Run full model with and without using DeepSpeed Evoformer attention kernel
-        and compare output coordinates
+        and compare output coordinates.
         """
         eps = 0.5
         with open("tests/test_data/sample_feats.pickle", "rb") as fp: