
Commit 4c20539

Merge branch 'HazyResearch:main' into main
2 parents: 08adf1f + 6b4a482

File tree

274 files changed: +23736, −854 lines changed


.gitignore

Lines changed: 21 additions & 0 deletions
@@ -0,0 +1,21 @@
+# Byte-compiled / optimized / DLL files
+__pycache__/
+*.py[cod]
+
+# C extensions
+*.so
+
+# Distribution / packaging
+bin/
+build/
+develop-eggs/
+dist/
+eggs/
+lib/
+lib64/
+parts/
+sdist/
+var/
+*.egg-info/
+.installed.cfg
+*.egg

MANIFEST.in

Lines changed: 9 additions & 0 deletions
@@ -0,0 +1,9 @@
+recursive-include csrc *.cu
+recursive-include csrc *.h
+recursive-include csrc *.cuh
+recursive-include csrc *.cpp
+
+recursive-include flash_attn *.cu
+recursive-include flash_attn *.h
+recursive-include flash_attn *.cuh
+recursive-include flash_attn *.cpp

Makefile

Lines changed: 9 additions & 0 deletions
@@ -0,0 +1,9 @@
+
+clean_dist:
+	rm -rf dist/*
+
+create_dist: clean_dist
+	python setup.py sdist
+
+upload_package: create_dist
+	twine upload dist/*

README.md

Lines changed: 33 additions & 8 deletions
@@ -8,7 +8,27 @@ Paper: https://arxiv.org/abs/2205.14135
 IEEE Spectrum [article](https://spectrum.ieee.org/mlperf-rankings-2022) about our submission to the MLPerf 2.0 benchmark using FlashAttention.
 ![FlashAttention](assets/flashattn_banner.jpg)
 
-#### Triton implementation of FlashAttention
+## Usage
+
+We've been very happy to see FlashAttention being widely adopted in such a short
+time after its release. This [page](https://github.com/HazyResearch/flash-attention/blob/main/usage.md)
+contains a partial list of places where FlashAttention is being used.
+
+## Full model code and training script
+
+We have released the full GPT model
+[implementation](https://github.com/HazyResearch/flash-attention/blob/main/flash_attn/models/gpt.py).
+We also provide optimized implementations of other layers (e.g., MLP, LayerNorm,
+cross-entropy loss, rotary embedding). Overall this speeds up training by 3-5x
+compared to the baseline implementation from Huggingface, reaching up to 189
+TFLOPs/sec per A100, equivalent to 60.6\% model FLOPs utilization (we don't need
+any activation checkpointing).
+
+We also include a training
+[script](https://github.com/HazyResearch/flash-attention/tree/main/training) to
+train GPT2 on Openwebtext and GPT3 on The Pile.
+
+## Triton implementation of FlashAttention
 
 Phil Tillet (OpenAI) has an experimental implementation of FlashAttention in Triton:
 https://github.com/openai/triton/blob/master/python/tutorials/06-fused-attention.py
@@ -18,9 +38,14 @@ and experiment with. The notations in the Triton implementation are also closer
 to what's used in our paper.
 
 
-## Alpha release (0.1).
+## Beta release (0.2).
+
+To install (requiring CUDA 11, NVCC, and a Turing or Ampere GPU):
+```sh
+pip install flash-attn
+```
 
-To compile (requiring CUDA 11, NVCC, and an Turing or Ampere GPU):
+Alternatively you can compile from source:
 ```
 python setup.py install
 ```
@@ -38,15 +63,15 @@ FlashAttention currently supports:
 3. Head dimensions that are multiples of 8, up to 128 (e.g., 8, 16, 24, ..., 128). Head dim > 64 backward requires A100.
 
 Our tentative roadmap:
-1. [Jun 2022] Make package pip-installable.
+1. ~~[Jun 2022] Make package pip-installable~~[Done, thanks to lucidrains].
 2. ~~[Jun 2022] Support SM86 GPUs (e.g., RTX 3080, 3090)~~[Done].
 3. [Jun 2022] Refactor to use Cutlass.
 4. ~~[Jun 2022] Support SM75 GPUs (e.g. T4)~~[Done].
 5. ~~[Jun 2022] Support bf16~~[Done].
 6. ~~[Jul 2022] Implement cross-attention~~[Done].
 7. ~~[Jul 2022] Support head dimension 128~~[Done].
 8. [Jul 2022] Support SM70 GPUs (V100).
-9. [Aug 2022] Fuse rotary embedding.
+9. ~~[Aug 2022] Fuse rotary embedding~~[Done].
 10. [Aug 2022] Support attention bias (e.g. ALiBi, relative positional encoding).
 
 ## Speedup and Memory Savings
@@ -148,10 +173,10 @@ and for his thoughtful answers to our questions about CUDA.
 ## Citation
 If you use this codebase, or otherwise found our work valuable, please cite:
 ```
-@article{dao2022flashattention,
-  title={FlashAttention: Fast and Memory-Efficient Exact Attention with IO-Awareness},
+@inproceedings{dao2022flashattention,
+  title={Flash{A}ttention: Fast and Memory-Efficient Exact Attention with {IO}-Awareness},
   author={Dao, Tri and Fu, Daniel Y. and Ermon, Stefano and Rudra, Atri and R{\'e}, Christopher},
-  journal={arXiv preprint arXiv:2205.14135},
+  booktitle={Advances in Neural Information Processing Systems},
+  year={2022}
 }
 ```
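As a quick orientation for the new `pip install flash-attn` path added above, here is a minimal usage sketch. The function name and argument layout follow `flash_attn.flash_attn_interface.flash_attn_unpadded_qkvpacked_func` as documented in the repo's usage notes around this release; treat the exact signature as an assumption, not a guarantee.

```python
import torch
from flash_attn.flash_attn_interface import flash_attn_unpadded_qkvpacked_func

# Assumed interface (check flash_attn/flash_attn_interface.py for the exact
# signature): packed QKV of shape (total_tokens, 3, nheads, headdim) in fp16,
# plus cumulative per-sequence lengths, on a Turing/Ampere GPU.
batch, seqlen, nheads, headdim = 2, 1024, 16, 64
qkv = torch.randn(batch * seqlen, 3, nheads, headdim,
                  dtype=torch.float16, device='cuda')
cu_seqlens = torch.arange(0, (batch + 1) * seqlen, step=seqlen,
                          dtype=torch.int32, device='cuda')

out = flash_attn_unpadded_qkvpacked_func(
    qkv, cu_seqlens, seqlen, dropout_p=0.0, causal=True)
print(out.shape)  # expected: (batch * seqlen, nheads, headdim)
```

Head dimensions must be multiples of 8 and at most 128, matching the head-dimension dispatch added in csrc/flash_attn/fmha_api.cpp below.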

assets/gpt2_training_curve.jpg

168 KB

assets/gpt2_training_efficiency.jpg

367 KB

assets/gpt3_training_curve.jpg

183 KB

assets/gpt3_training_efficiency.jpg

382 KB

csrc/flash_attn/fmha_api.cpp

Lines changed: 31 additions & 18 deletions
@@ -176,6 +176,16 @@ void set_params_dgrad(FMHA_dgrad_params &params,
     params.dsoftmax_sum = dsoftmax_sum_d;
 }
 
+void run_fmha_fwd(Launch_params<FMHA_fprop_params> &launch_params) {
+    if (launch_params.params.d <= 32) {
+        run_fmha_fwd_hdim32(launch_params);
+    } else if (launch_params.params.d <= 64) {
+        run_fmha_fwd_hdim64(launch_params);
+    } else if (launch_params.params.d <= 128) {
+        run_fmha_fwd_hdim128(launch_params);
+    }
+}
+
 std::vector<at::Tensor>
 mha_fwd(const at::Tensor &q, // total_q x num_heads x head_size, total_q := \sum_{i=0}^{b} s_i
         const at::Tensor &k, // total_k x num_heads x head_size, total_k := \sum_{i=0}^{b} s_i
@@ -299,21 +309,29 @@ mha_fwd(const at::Tensor &q, // total_q x num_heads x head_size, total_q
     // state
     // We use a custom RNG that increases the offset by batch_size * nheads * 32.
     int64_t counter_offset = launch_params.params.b * launch_params.params.h * 32;
-    at::PhiloxCudaState rng_engine_inputs;
 
     if( is_dropout ) {
         // See Note [Acquire lock when using random generators]
         std::lock_guard<std::mutex> lock(gen->mutex_);
         launch_params.params.philox_args = gen->philox_cuda_state(counter_offset);
     }
 
-    run_fmha_fp16_sm80(launch_params);
+    run_fmha_fwd(launch_params);
 
     std::vector<at::Tensor> result = {softmax_lse};
     if (return_softmax) {result.push_back(s);}
     return result;
 }
 
+void run_fmha_bwd(FMHA_dgrad_params &params, cudaStream_t stream, const bool configure) {
+    if (params.d <= 32) {
+        run_fmha_bwd_hdim32(params, stream, configure);
+    } else if (params.d <= 64) {
+        run_fmha_bwd_hdim64(params, stream, configure);
+    } else if (params.d <= 128) {
+        run_fmha_bwd_hdim128(params, stream, configure);
+    }
+}
 
 std::vector<at::Tensor>
 mha_bwd(const at::Tensor &dout, // total_q x num_heads, x head_size
@@ -341,7 +359,7 @@ mha_bwd(const at::Tensor &dout, // total_q x num_heads, x head_size
     bool is_sm80 = dprops->major == 8 && dprops->minor == 0;
     bool is_sm8x = dprops->major == 8 && dprops->minor >= 0;
     TORCH_CHECK(is_sm8x || is_sm75);
-    auto launch = &run_fmha_dgrad_fp16_sm80;
+    auto launch = &run_fmha_bwd;
 
     bool is_dropout = p_dropout > 0.0;
     auto stream = at::cuda::getCurrentCUDAStream().stream();
@@ -454,17 +472,13 @@ mha_bwd(const at::Tensor &dout, // total_q x num_heads, x head_size
 
     launch(params, stream, /*configure=*/true);
 
-    at::Tensor dk_accum, dv_accum;
     if (params.num_splits > 1) {
-        // dk_accum = torch::zeros({total_k, num_heads, head_size}, opts.dtype(at::kFloat));
-        // dv_accum = torch::zeros({total_k, num_heads, head_size}, opts.dtype(at::kFloat));
-        // params.dk_accum_ptr = dk_accum.data_ptr();
-        // params.dv_accum_ptr = dv_accum.data_ptr();
-        dk.zero_();
-        dv.zero_();
-    } else {
-        // params.dk_accum_ptr = nullptr;
-        // params.dv_accum_ptr = nullptr;
+        if (!dq_tmp.defined()) {
+            dq_tmp = torch::zeros({total_q, num_heads, head_size}, opts.dtype(at::kFloat));
+            params.o_tmp_ptr = dq_tmp.data_ptr(); // o_tmp stores dq_tmp in the backward pass
+        } else {
+            dq_tmp.zero_();
+        }
     }
 
     auto gen = at::get_generator_or_default<at::CUDAGeneratorImpl>(
@@ -481,10 +495,10 @@ mha_bwd(const at::Tensor &dout, // total_q x num_heads, x head_size
 
     launch(params, stream, /*configure=*/false);
 
-    // if (params.num_splits > 1) {
-    //     dk.copy_(dk_accum);
-    //     dv.copy_(dv_accum);
-    // }
+    if (params.num_splits > 1) {
+        dq.copy_(dq_tmp);
+    }
+
     return { dq, dk, dv, softmax_d };
 }
 
@@ -597,7 +611,6 @@ mha_fwd_block(const at::Tensor &q, // total_q x num_heads x head_size, t
     // number of times random will be generated per thread, to offset philox counter in thc random
     // state
     int64_t counter_offset = launch_params.elts_per_thread;
-    at::PhiloxCudaState rng_engine_inputs;
 
     if( is_dropout ) {
         // See Note [Acquire lock when using random generators]
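The backward-pass changes above reuse the fp32 `dq_tmp` buffer as an accumulator when the kernel is split (`params.num_splits > 1`) and copy it into `dq` only once after the final launch. Below is a small PyTorch sketch of that accumulate-in-fp32, cast-once pattern; it is illustrative only, and the tensor names and shapes here are made up rather than taken from the extension's API.

```python
import torch

# Illustrative sketch of the accumulate-in-fp32, down-cast-once pattern used
# when the backward pass is split into several chunks (num_splits > 1).
total_q, num_heads, head_size = 1024, 16, 64
num_splits = 4

dq = torch.zeros(total_q, num_heads, head_size, dtype=torch.float16)
dq_tmp = torch.zeros(total_q, num_heads, head_size, dtype=torch.float32)

for _ in range(num_splits):
    # Each split contributes a partial gradient; summing many small terms
    # directly in fp16 would lose precision as the accumulator grows.
    partial = torch.randn(total_q, num_heads, head_size) * 1e-3
    dq_tmp += partial

dq.copy_(dq_tmp)  # single cast back to fp16, mirroring dq.copy_(dq_tmp) in mha_bwd
```

Accumulating the partial gradients in fp32 and casting once at the end avoids the rounding error of repeated fp16 additions, which is why `dq_tmp` is allocated with `opts.dtype(at::kFloat)` in the diff above.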

csrc/flash_attn/src/.DS_Store

-6 KB
Binary file not shown.
