import math
+ from functools import partial

import torch
import torch.nn.functional as F
@@ -858,14 +859,14 @@ def test_flash_attn_multigpu():
@pytest.mark.skipif(not is_sm80, reason='Triton version is only tested on A100')
@pytest.mark.parametrize('dtype', ([torch.float16] if is_sm75 else [torch.float16, torch.bfloat16]))
- # @pytest.mark.parametrize('dtype', [torch.float16])
+ # @pytest.mark.parametrize('dtype', [torch.bfloat16])
@pytest.mark.parametrize('causal', [False, True])
# @pytest.mark.parametrize('causal', [False])
@pytest.mark.parametrize('d', [40, 48, 64, 128, 80, 88, 96])
- # @pytest.mark.parametrize('d', [40])
+ # @pytest.mark.parametrize('d', [48])
# @pytest.mark.parametrize('seqlen', [97, 128, 200, 256, 257, 384, 512, 768, 1024, 1025, 2048])
- @pytest.mark.parametrize('seqlen_q,seqlen_k', [(113, 203), (128, 217), (113, 211), (108, 256), (256, 512), (512, 256), (1024, 1024), (2048, 2048)])
- # @pytest.mark.parametrize('seqlen_q,seqlen_k', [(1024, 1024)])
+ @pytest.mark.parametrize('seqlen_q,seqlen_k', [(113, 203), (128, 217), (113, 211), (108, 256), (256, 512), (512, 256), (1024, 1024), (1023, 1024), (1024, 1023), (2048, 2048)])
+ # @pytest.mark.parametrize('seqlen_q,seqlen_k', [(1023, 1023)])
def test_flash_attn_triton(seqlen_q, seqlen_k, d, causal, dtype):
    if seqlen_q >= 2048 and torch.cuda.get_device_properties('cuda').total_memory <= 16 * 2 ** 30:
        pytest.skip()  # Reference implementation OOM
@@ -916,13 +917,13 @@ def test_flash_attn_triton(seqlen_q, seqlen_k, d, causal, dtype):
@pytest.mark.skipif(not is_sm80, reason='Triton version is only tested on A100')
@pytest.mark.parametrize('dtype', ([torch.float16] if is_sm75 else [torch.float16, torch.bfloat16]))
- # @pytest.mark.parametrize('dtype', [torch.float16])
+ # @pytest.mark.parametrize('dtype', [torch.bfloat16])
@pytest.mark.parametrize('causal', [False, True])
# @pytest.mark.parametrize('causal', [True])
- @pytest.mark.parametrize('d', [40, 48, 64, 128, 80, 88, 96])
- # @pytest.mark.parametrize('d', [64])
+ # @pytest.mark.parametrize('d', [40, 48, 64, 128, 80, 88, 96])
+ @pytest.mark.parametrize('d', [64, 128])
# @pytest.mark.parametrize('seqlen', [97, 128, 200, 256, 257, 384, 512, 768, 1024, 1025, 2048])
- @pytest.mark.parametrize('seqlen_q,seqlen_k', [(113, 203), (128, 217), (113, 211), (108, 256), (256, 512), (512, 256), (1024, 1024), (1023, 1024), (2048, 2048)])
+ @pytest.mark.parametrize('seqlen_q,seqlen_k', [(113, 203), (128, 217), (91, 211), (108, 256), (256, 512), (512, 256), (1024, 1024), (1023, 1024), (1024, 1023), (2048, 2048)])
# @pytest.mark.parametrize('seqlen_q,seqlen_k', [(1023, 1024)])
def test_flash_attn_triton_race_condition(seqlen_q, seqlen_k, d, causal, dtype):
    if seqlen_q >= 2048 and torch.cuda.get_device_properties('cuda').total_memory <= 16 * 2 ** 30:
@@ -941,21 +942,24 @@ def test_flash_attn_triton_race_condition(seqlen_q, seqlen_k, d, causal, dtype):
    g = torch.randn_like(output_0)
    dq_0, dk_0, dv_0 = torch.autograd.grad(output_0, (q, k, v), g)

-     # Disable the SEQUENCE_PARALLEL option for the bwd to make sure it's deterministic
+     # The SEQUENCE_PARALLEL option for the bwd makes dq non-deterministic
+     deterministic_dq = False
+     equal_fn = (torch.equal if deterministic_dq
+                 else partial(torch.allclose, atol=1e-3 if dtype == torch.bfloat16 else 1e-5))

    for i in range(10000):
        output = flash_attn_func(q, k, v, causal)
        output_equal = torch.equal(output, output_0)
        if not output_equal:  # Printing / computing diff sometimes makes the race condition disappear
            print(f'Output max diff: {(output - output_0).abs().max().item()}')
        assert torch.equal(output, output_0)
        dq, dk, dv = torch.autograd.grad(output, (q, k, v), g)
-         dq_equal = torch.equal(dq, dq_0)
+         dq_equal = equal_fn(dq, dq_0)
        dk_equal = torch.equal(dk, dk_0)
        dv_equal = torch.equal(dv, dv_0)
        if not (dq_equal and dk_equal and dv_equal):
            print(f'dQ max diff: {(dq - dq_0).abs().max().item()}')
            print(f'dK max diff: {(dk - dk_0).abs().max().item()}')
            print(f'dV max diff: {(dv - dv_0).abs().max().item()}')
-         assert equal_fn(dq, dq_0) if False else None  # placeholder removed below
-         assert torch.equal(dq, dq_0)
+         assert equal_fn(dq, dq_0)
        assert torch.equal(dk, dk_0)
        assert torch.equal(dv, dv_0)
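
The functional change in this commit is the dq check in the race-condition test: with the SEQUENCE_PARALLEL backward, dq can be accumulated in a varying order, so bitwise equality across runs is replaced by torch.allclose with a dtype-dependent absolute tolerance. Below is a minimal, self-contained sketch of that comparison pattern; the helper name (make_equal_fn), the tensor shapes, and the injected noise are illustrative assumptions, not code from the repository.

from functools import partial

import torch


def make_equal_fn(deterministic_dq, dtype):
    # Exact, bitwise comparison when the backward is deterministic; otherwise a
    # tolerance-based check, since floating-point accumulation order changes the
    # low-order bits of dq from run to run.
    if deterministic_dq:
        return torch.equal
    # bfloat16 carries fewer mantissa bits than float16, hence the looser atol,
    # mirroring `atol=1e-3 if dtype == torch.bfloat16 else 1e-5` in the test.
    return partial(torch.allclose, atol=1e-3 if dtype == torch.bfloat16 else 1e-5)


# Illustrative usage: two "runs" that differ only by rounding-level noise.
dtype = torch.bfloat16
equal_fn = make_equal_fn(deterministic_dq=False, dtype=dtype)
dq_run0 = torch.randn(4, 128, 8, 64, dtype=dtype)
dq_run1 = dq_run0 + 1e-5 * torch.randn_like(dq_run0)
print(equal_fn(dq_run0, dq_run1))     # True: the difference is within the bf16 tolerance
print(torch.equal(dq_run0, dq_run1))  # typically False: the runs are not bitwise identical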