Skip to content

Commit aacc10f

Browse files
committed
Fix race condition in Triton bwd for non-po2 headdims
1 parent 1fb12af commit aacc10f

File tree

2 files changed

+23
-27
lines changed

2 files changed

+23
-27
lines changed

flash_attn/flash_attn_triton.py

Lines changed: 8 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@
77
- Implement cross-attention (not just self-attention).
88
- Support arbitrary seqlens (not just multiples of 128), for both forward and backward.
99
- [WIP] Support all head dimensions up to 128 (not just 16, 32, 64, 128), for both the forward pass
10-
and backward pass. For the backward pass, head dims that are not 16, 32, 64, 128 will require
10+
and backward pass. For the backward pass, head dims that are not 64, 128 will require
1111
more testing since there seems to be some race conditions due to the Triton compiler.
1212
- Speed up the forward pass a bit, and only store the LSE instead of m and l.
1313
- Make the backward for d=128 much faster by reducing register spilling.
@@ -17,9 +17,9 @@
1717
Differences between this Triton version and the CUDA version:
1818
- Triton version doesn't support dropout.
1919
- Triton forward is generally faster than CUDA forward.
20-
- Triton backward is faster than CUDA backward when batch * nheads is small, and might be slightly
21-
slower in other cases.
22-
- Triton version does yet not support different sequence lengths in a batch (i.e., RaggedTensor/NestedTensor).
20+
- Triton backward is faster than CUDA backward when batch * nheads is small, and when headdim=64. It is slightly
21+
slower when headdim=128 and batch * nheads is large.
22+
- Triton version doesn't yet support different sequence lengths in a batch (i.e., RaggedTensor/NestedTensor).
2323
"""
2424

2525
import math
@@ -282,7 +282,7 @@ def _bwd_kernel_one_col_block(
282282
qk = tl.where(offs_m_curr[:, None] >= (offs_n[None, :]), qk, float("-inf"))
283283
# There seems to be a race condition when headdim=48/96, and dq, dk, dv are wrong.
284284
# Also wrong for headdim=64.
285-
if not EVEN_M:
285+
if not (EVEN_M & EVEN_HEADDIM):
286286
tl.debug_barrier()
287287
lse_i = tl.load(LSE + offs_m_curr)
288288
p = tl.exp(qk * softmax_scale - lse_i[:, None])
@@ -316,6 +316,9 @@ def _bwd_kernel_one_col_block(
316316
if not EVEN_M:
317317
tl.debug_barrier()
318318
dp = tl.dot(do, v, trans_b=True)
319+
# There's a race condition for headdim=48
320+
if not EVEN_HEADDIM:
321+
tl.debug_barrier()
319322
# compute ds = p * (dp - delta[:, None])
320323
# Putting the subtraction after the dp matmul (instead of before) is slightly faster
321324
Di = tl.load(D + offs_m_curr)
@@ -390,10 +393,6 @@ def _bwd_kernel_one_col_block(
390393

391394

392395
def init_to_zero(name):
393-
# def fn(nargs):
394-
# with torch.no_grad():
395-
# nargs[name].zero_()
396-
# return fn
397396
return lambda nargs: nargs[name].zero_()
398397

399398
@triton.autotune(
@@ -406,15 +405,8 @@ def init_to_zero(name):
406405
# triton.Config({"BLOCK_M": 128, "BLOCK_N": 64, "SEQUENCE_PARALLEL": True}, num_warps=8, num_stages=1, pre_hook=init_to_zero('DQ')),
407406
# triton.Config({"BLOCK_M": 64, "BLOCK_N": 64, "SEQUENCE_PARALLEL": False}, num_warps=4, num_stages=1, pre_hook=init_to_zero('DQ')),
408407
# triton.Config({"BLOCK_M": 64, "BLOCK_N": 64, "SEQUENCE_PARALLEL": True}, num_warps=4, num_stages=1, pre_hook=init_to_zero('DQ')),
409-
# triton.Config({"BLOCK_M": 128, "BLOCK_N": 128, "SEQUENCE_PARALLEL": False}, num_warps=8, num_stages=1),
410-
# triton.Config({"BLOCK_M": 128, "BLOCK_N": 128, "SEQUENCE_PARALLEL": True}, num_warps=8, num_stages=1),
411-
# triton.Config({"BLOCK_M": 128, "BLOCK_N": 64, "SEQUENCE_PARALLEL": False}, num_warps=4, num_stages=1),
412-
# triton.Config({"BLOCK_M": 128, "BLOCK_N": 64, "SEQUENCE_PARALLEL": True}, num_warps=4, num_stages=1),
413-
# triton.Config({"BLOCK_M": 64, "BLOCK_N": 64, "SEQUENCE_PARALLEL": False}, num_warps=4, num_stages=1),
414-
# triton.Config({"BLOCK_M": 64, "BLOCK_N": 64, "SEQUENCE_PARALLEL": True}, num_warps=4, num_stages=1),
415408
],
416409
key=['CACHE_KEY_SEQLEN_Q', 'CACHE_KEY_SEQLEN_K', 'IS_CAUSAL', 'BLOCK_HEADDIM'],
417-
# reset_to_zero=['DQ']
418410
)
419411
@triton.heuristics(
420412
{

tests/test_flash_attn.py

Lines changed: 15 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
import math
2+
from functools import partial
23

34
import torch
45
import torch.nn.functional as F
@@ -858,14 +859,14 @@ def test_flash_attn_multigpu():
858859

859860
@pytest.mark.skipif(not is_sm80, reason='Triton version is only tested on A100')
860861
@pytest.mark.parametrize('dtype', ([torch.float16] if is_sm75 else [torch.float16, torch.bfloat16]))
861-
# @pytest.mark.parametrize('dtype', [torch.float16])
862+
# @pytest.mark.parametrize('dtype', [torch.bfloat16])
862863
@pytest.mark.parametrize('causal', [False, True])
863864
# @pytest.mark.parametrize('causal', [False])
864865
@pytest.mark.parametrize('d', [40, 48, 64, 128, 80, 88, 96])
865-
# @pytest.mark.parametrize('d', [40])
866+
# @pytest.mark.parametrize('d', [48])
866867
# @pytest.mark.parametrize('seqlen', [97, 128, 200, 256, 257, 384, 512, 768, 1024, 1025, 2048])
867-
@pytest.mark.parametrize('seqlen_q,seqlen_k', [(113, 203), (128, 217), (113, 211), (108, 256), (256, 512), (512, 256), (1024, 1024), (2048, 2048)])
868-
# @pytest.mark.parametrize('seqlen_q,seqlen_k', [(1024, 1024)])
868+
@pytest.mark.parametrize('seqlen_q,seqlen_k', [(113, 203), (128, 217), (113, 211), (108, 256), (256, 512), (512, 256), (1024, 1024), (1023, 1024), (1024, 1023), (2048, 2048)])
869+
# @pytest.mark.parametrize('seqlen_q,seqlen_k', [(1023, 1023)])
869870
def test_flash_attn_triton(seqlen_q, seqlen_k, d, causal, dtype):
870871
if seqlen_q >= 2048 and torch.cuda.get_device_properties('cuda').total_memory <= 16 * 2**30:
871872
pytest.skip() # Reference implementation OOM
@@ -916,13 +917,13 @@ def test_flash_attn_triton(seqlen_q, seqlen_k, d, causal, dtype):
916917

917918
@pytest.mark.skipif(not is_sm80, reason='Triton version is only tested on A100')
918919
@pytest.mark.parametrize('dtype', ([torch.float16] if is_sm75 else [torch.float16, torch.bfloat16]))
919-
# @pytest.mark.parametrize('dtype', [torch.float16])
920+
# @pytest.mark.parametrize('dtype', [torch.bfloat16])
920921
@pytest.mark.parametrize('causal', [False, True])
921922
# @pytest.mark.parametrize('causal', [True])
922-
@pytest.mark.parametrize('d', [40, 48, 64, 128, 80, 88, 96])
923-
# @pytest.mark.parametrize('d', [64])
923+
# @pytest.mark.parametrize('d', [40, 48, 64, 128, 80, 88, 96])
924+
@pytest.mark.parametrize('d', [64, 128])
924925
# @pytest.mark.parametrize('seqlen', [97, 128, 200, 256, 257, 384, 512, 768, 1024, 1025, 2048])
925-
@pytest.mark.parametrize('seqlen_q,seqlen_k', [(113, 203), (128, 217), (113, 211), (108, 256), (256, 512), (512, 256), (1024, 1024), (1023, 1024), (2048, 2048)])
926+
@pytest.mark.parametrize('seqlen_q,seqlen_k', [(113, 203), (128, 217), (91, 211), (108, 256), (256, 512), (512, 256), (1024, 1024), (1023, 1024), (1024, 1023), (2048, 2048)])
926927
# @pytest.mark.parametrize('seqlen_q,seqlen_k', [(1023, 1024)])
927928
def test_flash_attn_triton_race_condition(seqlen_q, seqlen_k, d, causal, dtype):
928929
if seqlen_q >= 2048 and torch.cuda.get_device_properties('cuda').total_memory <= 16 * 2**30:
@@ -941,21 +942,24 @@ def test_flash_attn_triton_race_condition(seqlen_q, seqlen_k, d, causal, dtype):
941942
g = torch.randn_like(output_0)
942943
dq_0, dk_0, dv_0 = torch.autograd.grad(output_0, (q, k, v), g)
943944

944-
# Disable the SEQUENCE_PARALLEL option for the bwd to make sure it's deterministic
945+
# The SEQUENCE_PARALLEL option for the bwd to makes dq non-deterministic
946+
deterministic_dq = False
947+
equal_fn = (torch.equal if deterministic_dq
948+
else partial(torch.allclose, atol=1e-3 if dtype == torch.bfloat16 else 1e-5))
945949
for i in range(10000):
946950
output = flash_attn_func(q, k, v, causal)
947951
output_equal = torch.equal(output, output_0)
948952
if not output_equal: # Printing / computing diff sometimes makes the race condition disappear
949953
print(f'Output max diff: {(output - output_0).abs().max().item()}')
950954
assert torch.equal(output, output_0)
951955
dq, dk, dv = torch.autograd.grad(output, (q, k, v), g)
952-
dq_equal = torch.equal(dq, dq_0)
956+
dq_equal = equal_fn(dq, dq_0)
953957
dk_equal = torch.equal(dk, dk_0)
954958
dv_equal = torch.equal(dv, dv_0)
955959
if not (dq_equal and dk_equal and dv_equal):
956960
print(f'dQ max diff: {(dq - dq_0).abs().max().item()}')
957961
print(f'dK max diff: {(dk - dk_0).abs().max().item()}')
958962
print(f'dV max diff: {(dv - dv_0).abs().max().item()}')
959-
assert torch.equal(dq, dq_0)
963+
assert equal_fn(dq, dq_0)
960964
assert torch.equal(dk, dk_0)
961965
assert torch.equal(dv, dv_0)

0 commit comments

Comments
 (0)