diff --git a/server/Makefile-flash-att-v2 b/server/Makefile-flash-att-v2
index bbff00909..ff2d6b913 100644
--- a/server/Makefile-flash-att-v2
+++ b/server/Makefile-flash-att-v2
@@ -1,4 +1,4 @@
-flash_att_v2_commit_cuda := v2.5.8
+flash_att_v2_commit_cuda := v2.7.3
 flash_att_v2_commit_rocm := 2554f490101742ccdc56620a938f847f61754be6
 
 
diff --git a/server/lorax_server/models/flash_causal_lm.py b/server/lorax_server/models/flash_causal_lm.py
index 92937511d..0ed096d29 100644
--- a/server/lorax_server/models/flash_causal_lm.py
+++ b/server/lorax_server/models/flash_causal_lm.py
@@ -44,6 +44,7 @@
 from lorax_server.utils.state import (
     BLOCK_SIZE,
     FLASH_INFER,
+    FLASH_DECODING,
     get_max_prefill_tokens,
     get_speculative_tokens,
     get_supports_chunking,
@@ -1235,7 +1236,7 @@ def init_kv_cache(
         element_size = torch.tensor([], dtype=dtype).element_size()
         x = BLOCK_SIZE // element_size
 
-        if FLASH_INFER:
+        if FLASH_INFER or FLASH_DECODING:
             self.kv_cache = [
                 (
                     torch.empty(
diff --git a/server/lorax_server/utils/paged_attention.py b/server/lorax_server/utils/paged_attention.py
index 7d7b4c82c..41c75387e 100644
--- a/server/lorax_server/utils/paged_attention.py
+++ b/server/lorax_server/utils/paged_attention.py
@@ -4,7 +4,7 @@
 
 from lorax_server.utils.attention.common import Seqlen
 from lorax_server.utils.import_utils import SYSTEM
-from lorax_server.utils.state import FLASH_INFER
+from lorax_server.utils.state import FLASH_INFER, FLASH_DECODING
 
 _PARTITION_SIZE = 512
 
@@ -36,7 +36,7 @@ def reshape_and_cache(
     v_scale: float = 1.0,
     fp8_kv: bool = False,
 ):
-    if FLASH_INFER:
+    if FLASH_INFER or FLASH_DECODING:
         if fp8_kv:
             key = static_per_tensor_quantize(key, k_scale).view(torch.uint8)
             value = static_per_tensor_quantize(value, v_scale).view(torch.uint8)
@@ -102,6 +102,38 @@ def attention(
 
     out = torch.empty_like(query)
 
+    if FLASH_DECODING:
+        max_q = 1
+        max_k = max_s
+        import flash_attn_2_cuda
+
+        # TODO: fix me when flash-attention contains the fix.
+        # The number of splits is not correctly handled by the current path:
+        # https://github.com/Dao-AILab/flash-attention/blob/320fb59487658f033f56711efd3d61b7c7a6f8f3/csrc/flash_attn/flash_api.cpp#L577
+        # This fails because we're using causal, therefore window_right is set to 0 and the split logic is never applied.
+        return flash_attn_2_cuda.varlen_fwd(
+            query,
+            key_cache,
+            value_cache,
+            out,
+            seqlen.cu_seqlen_q,
+            seqlen.cu_seqlen_k,
+            None,
+            block_tables,
+            None,
+            max_q,
+            max_k,
+            0.0,  # dropout
+            softmax_scale,
+            False,  # zero_tensors
+            True,  # causal
+            -1,  # window_left
+            -1,  # window_right
+            False,  # return_softmax
+            None,  # generator
+        )[0]
+
     if SYSTEM == "xpu":
         query = query.contiguous()
         ipex.llm.modules.PagedAttention.single_query_cached_kv_attention(
diff --git a/server/lorax_server/utils/state.py b/server/lorax_server/utils/state.py
index d17716545..1cced81f5 100644
--- a/server/lorax_server/utils/state.py
+++ b/server/lorax_server/utils/state.py
@@ -11,6 +11,7 @@
 LORAX_PROFILER_DIR = os.environ.get("LORAX_PROFILER_DIR", None)
 
 PREFIX_CACHING = bool(int(os.environ.get("PREFIX_CACHING", "0")))
+FLASH_DECODING = bool(int(os.environ.get("FLASH_DECODING", "0")))
 CHUNKED_PREFILL = bool(int(os.environ.get("CHUNKED_PREFILL", "0")))
 LORAX_SPECULATION_MAX_BATCH_SIZE = int(os.environ.get("LORAX_SPECULATION_MAX_BATCH_SIZE", 32))
 
@@ -18,6 +19,8 @@
 FLASH_INFER = bool(int(os.environ.get("FLASH_INFER", "0"))) or PREFIX_CACHING
 
 if FLASH_INFER:
     logger.info("Backend = flashinfer")
+elif FLASH_DECODING:
+    logger.info("Backend = flashdecoding")
 else:
     logger.info("Backend = fa2")
@@ -34,6 +37,8 @@
 
 BLOCK_SIZE: int
 if FLASH_INFER:
     BLOCK_SIZE = 1
+elif FLASH_DECODING:
+    BLOCK_SIZE = 256
 else:
     BLOCK_SIZE = 16
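For context, the backend selection this patch introduces in server/lorax_server/utils/state.py can be summarized with the minimal sketch below. It is a simplification, not the actual module: the environment variables PREFIX_CACHING, FLASH_INFER, and FLASH_DECODING and the block sizes 1/256/16 come from the diff above, while the backend/block_size variables and the print statement are illustrative only. The point it shows is the precedence order: flashinfer (or prefix caching) wins over flashdecoding, which wins over the fa2 default, and each backend implies its own KV-cache block size.

import os

# Simplified reproduction of the post-patch precedence in state.py:
# flashinfer (or prefix caching) > flashdecoding > fa2.
PREFIX_CACHING = bool(int(os.environ.get("PREFIX_CACHING", "0")))
FLASH_INFER = bool(int(os.environ.get("FLASH_INFER", "0"))) or PREFIX_CACHING
FLASH_DECODING = bool(int(os.environ.get("FLASH_DECODING", "0")))

if FLASH_INFER:
    backend, block_size = "flashinfer", 1       # BLOCK_SIZE = 1
elif FLASH_DECODING:
    backend, block_size = "flashdecoding", 256  # BLOCK_SIZE = 256
else:
    backend, block_size = "fa2", 16             # BLOCK_SIZE = 16

print(f"Backend = {backend}, BLOCK_SIZE = {block_size}")

Run with FLASH_DECODING=1 and FLASH_INFER/PREFIX_CACHING unset, the sketch prints "Backend = flashdecoding, BLOCK_SIZE = 256", matching the logger output and BLOCK_SIZE values added in state.py.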