diff --git a/benchmarks/benchmark_one_concurrent_req.py b/benchmarks/benchmark_one_concurrent_req.py
new file mode 100644
index 000000000000..ecfeb249bfe8
--- /dev/null
+++ b/benchmarks/benchmark_one_concurrent_req.py
@@ -0,0 +1,387 @@
+# SPDX-License-Identifier: Apache-2.0
+import argparse
+import asyncio
+import logging
+import random
+import time
+from dataclasses import dataclass
+from typing import Optional
+
+import aiohttp  # Import aiohttp
+import numpy as np
+from tqdm import tqdm
+
+from backend_request_func import RequestFuncInput, RequestFuncOutput
+from benchmark_dataset import RandomDataset, SampleRequest
+
+try:
+    from vllm.transformers_utils.tokenizer import get_tokenizer
+except ImportError:
+    from backend_request_func import get_tokenizer
+
+logger = logging.getLogger(__name__)
+
+
+@dataclass
+class BenchmarkMetrics:
+    completed: int
+    total_input: int
+    total_output: int
+    mean_ttft_ms: float
+    median_ttft_ms: float
+    std_ttft_ms: float
+    percentiles_ttft_ms: list[tuple[float, float]]
+    mean_itl_ms: float
+    median_itl_ms: float
+    std_itl_ms: float
+    percentiles_itl_ms: list[tuple[float, float]]
+    mean_e2el_ms: float
+    median_e2el_ms: float
+    std_e2el_ms: float
+    percentiles_e2el_ms: list[tuple[float, float]]
+
+
+async def reset_cache(reset_url: str):
+    """Sends a POST request to reset the prefix cache."""
+    logger.debug("Resetting prefix cache at %s", reset_url)
+    try:
+        async with (
+            aiohttp.ClientSession() as session,
+            session.post(reset_url) as response,
+        ):
+            response.raise_for_status()  # Raise an exception for bad status codes
+            logger.debug("Prefix cache reset successful: %s", response.status)
+    except aiohttp.ClientConnectorError as e:
+        logger.error("Failed to connect to cache reset endpoint %s: %s", reset_url, e)
+    except aiohttp.ClientResponseError as e:
+        logger.error(
+            "Cache reset request failed with status %s: %s", e.status, e.message
+        )
+    except Exception as e:
+        logger.error("An unexpected error occurred during cache reset: %s", e)
+
+
+async def sequential_benchmark(
+    backend: str,
+    api_url: str,
+    model_id: str,
+    tokenizer,
+    input_requests: list[SampleRequest],
+    request_func,
+    selected_percentiles: list[float],
+    cache_reset_url: Optional[str] = None,
+):
+    """
+    Benchmark that processes requests sequentially, waiting for each to complete
+    before starting the next one. Resets prefix cache between requests.
+    """
+    outputs = []
+
+    pbar = tqdm(total=len(input_requests))
+
+    # Small request to force a forward pass.
+    # Used for resetting the prefix cache.
+    dummy_req_input = RequestFuncInput(
+        model=model_id,
+        prompt="0",
+        api_url=api_url,
+        prompt_len=1,
+        output_len=1,
+    )
+
+    print("Starting initial single prompt test run...")
+    test_output = await request_func(request_func_input=dummy_req_input)
+    if not test_output.success:
+        raise ValueError(
+            "Initial test run failed - Please check your configuration. "
+            f"Error: {test_output.error}"
+        )
+    else:
+        print("Initial test run completed. Starting sequential benchmark...")
+
+    benchmark_start_time = time.perf_counter()
+
+    # Process requests sequentially
+    for request in input_requests:
+        prompt, prompt_len, output_len = (
+            request.prompt,
+            request.prompt_len,
+            request.expected_output_len,
+        )
+
+        logger.info("Sending request with len %s", request.prompt_len)
+        logger.debug('Request str: "%s"', request.prompt[:50])
+        request_start_time = time.perf_counter()
+
+        request_func_input = RequestFuncInput(
+            model=model_id,
+            prompt=prompt,
+            api_url=api_url,
+            prompt_len=prompt_len,
+            output_len=output_len,
+        )
+
+        output = await request_func(request_func_input=request_func_input)
+
+        request_end_time = time.perf_counter()
+        # Add timing information
+        if output.success and not hasattr(output, "latency"):
+            output.latency = request_end_time - request_start_time
+        logger.info("Finished request with latency %.4f s", output.latency)
+
+        outputs.append(output)
+        pbar.update(1)
+
+        # Reset the prefix cache after each request, if configured.
+        if cache_reset_url:
+            await request_func(request_func_input=dummy_req_input)
+            await reset_cache(cache_reset_url)
+
+    pbar.close()
+
+    benchmark_duration = time.perf_counter() - benchmark_start_time
+
+    # Calculate metrics
+    metrics = calculate_metrics(
+        input_requests=input_requests,
+        outputs=outputs,
+        dur_s=benchmark_duration,
+        tokenizer=tokenizer,
+        selected_percentiles=selected_percentiles,
+    )
+
+    print_results(metrics, benchmark_duration)
+
+    result = {
+        "duration": benchmark_duration,
+        "completed": metrics.completed,
+        "total_input_tokens": metrics.total_input,
+        "total_output_tokens": metrics.total_output,
+        "input_lens": [request.prompt_len for request in input_requests],
+        "output_lens": [
+            output.output_tokens if output.success else 0 for output in outputs
+        ],
+        "ttfts": [output.ttft for output in outputs if output.success],
+        "itls": [output.itl for output in outputs if output.success],
+        "generated_texts": [
+            output.generated_text for output in outputs if output.success
+        ],
+        "errors": [output.error for output in outputs if not output.success],
+    }
+
+    # Add summary statistics
+    for stat_name in ["ttft", "itl", "e2el"]:
+        for metric_name in ["mean", "median", "std"]:
+            result[f"{metric_name}_{stat_name}_ms"] = getattr(
+                metrics, f"{metric_name}_{stat_name}_ms"
+            )
+
+        for p, value in getattr(metrics, f"percentiles_{stat_name}_ms"):
+            p_word = str(int(p)) if int(p) == p else str(p)
+            result[f"p{p_word}_{stat_name}_ms"] = value
+
+    return result
+
+
+def calculate_metrics(
+    input_requests: list[SampleRequest],
+    outputs: list[RequestFuncOutput],
+    dur_s: float,
+    tokenizer,
+    selected_percentiles: list[float],
+) -> BenchmarkMetrics:
+    """Calculate benchmark metrics from results."""
+    total_input = 0
+    completed = 0
+    total_output = 0
+    ttfts = []
+    itls = []
+    e2els = []
+
+    for i, output in enumerate(outputs):
+        if output.success:
+            output_len = output.output_tokens
+
+            if not output_len:
+                # Use tokenizer to count output tokens if not provided
+                output_len = len(
+                    tokenizer(output.generated_text, add_special_tokens=False).input_ids
+                )
+
+            total_output += output_len
+            total_input += input_requests[i].prompt_len
+
+            if hasattr(output, "ttft") and output.ttft is not None:
+                ttfts.append(output.ttft)
+
+            if hasattr(output, "itl") and output.itl:
+                # Ensure itl is a list of floats
+                if isinstance(output.itl, list):
+                    itls.extend(output.itl)
+                else:
+                    logger.warning(
+                        "Expected list for ITL but got %s. 
Appending as is.", + type(output.itl), + ) + itls.append(output.itl) + + if hasattr(output, "latency") and output.latency is not None: + e2els.append(output.latency) + + completed += 1 + + return BenchmarkMetrics( + completed=completed, + total_input=total_input, + total_output=total_output, + mean_ttft_ms=np.mean(ttfts or [0]) * 1000, + median_ttft_ms=np.median(ttfts or [0]) * 1000, + std_ttft_ms=np.std(ttfts or [0]) * 1000, + percentiles_ttft_ms=[ + (p, np.percentile(ttfts or [0], p) * 1000) for p in selected_percentiles + ], + mean_itl_ms=np.mean(itls or [0]) * 1000, + median_itl_ms=np.median(itls or [0]) * 1000, + std_itl_ms=np.std(itls or [0]) * 1000, + percentiles_itl_ms=[ + (p, np.percentile(itls or [0], p) * 1000) for p in selected_percentiles + ], + mean_e2el_ms=np.mean(e2els or [0]) * 1000, + median_e2el_ms=np.median(e2els or [0]) * 1000, + std_e2el_ms=np.std(e2els or [0]) * 1000, + percentiles_e2el_ms=[ + (p, np.percentile(e2els or [0], p) * 1000) for p in selected_percentiles + ], + ) + + +def print_results(metrics: BenchmarkMetrics, benchmark_duration: float): + """Print benchmark results in a formatted way.""" + print("{s:{c}^{n}}".format(s=" Sequential Benchmark Result ", n=60, c="=")) + print("{:<40} {:<10}".format("Successful requests:", metrics.completed)) + print("{:<40} {:<10.2f}".format("Benchmark duration (s):", benchmark_duration)) + print("{:<40} {:<10}".format("Total input tokens:", metrics.total_input)) + print("{:<40} {:<10}".format("Total generated tokens:", metrics.total_output)) + + def print_metric_stats(metric_name, header): + print("{s:{c}^{n}}".format(s=header, n=60, c="-")) + print( + "{:<40} {:<10.2f}".format( + f"Mean {metric_name} (ms):", + getattr(metrics, f"mean_{metric_name.lower()}_ms"), + ) + ) + print( + "{:<40} {:<10.2f}".format( + f"Median {metric_name} (ms):", + getattr(metrics, f"median_{metric_name.lower()}_ms"), + ) + ) + + for p, value in getattr(metrics, f"percentiles_{metric_name.lower()}_ms"): + p_word = str(int(p)) if int(p) == p else str(p) + print("{:<40} {:<10.2f}".format(f"P{p_word} {metric_name} (ms):", value)) + + print_metric_stats("TTFT", "Time to First Token") + print_metric_stats("ITL", "Inter-token Latency") + print_metric_stats("E2EL", "End-to-end Latency") + print("=" * 60) + + +async def main_async(args): + # Import needed functions based on your setup + from backend_request_func import ASYNC_REQUEST_FUNCS + + backend = args.backend + model_id = args.model + tokenizer_id = args.tokenizer if args.tokenizer is not None else args.model + + # Set up API URL + if args.base_url is not None: + api_url = f"{args.base_url}{args.endpoint}" + else: + api_url = f"http://{args.host}:{args.port}{args.endpoint}" + + # Set up Cache Reset URL + cache_reset_url = f"http://{args.host}:{args.port}/reset_prefix_cache" + logger.info("Prefix cache reset configured at: %s", cache_reset_url) + + # Get tokenizer + tokenizer = get_tokenizer(tokenizer_id, trust_remote_code=args.trust_remote_code) + + # Get request function + if backend in ASYNC_REQUEST_FUNCS: + request_func = ASYNC_REQUEST_FUNCS[backend] + else: + raise ValueError(f"Unknown backend: {backend}") + + input_requests = RandomDataset().sample( + tokenizer=tokenizer, + num_requests=args.num_requests, + prefix_len=0, + input_len=args.input_len, + output_len=args.output_len, + range_ratio=0.0, + ) + + # Run benchmark + result = await sequential_benchmark( + backend=backend, + api_url=api_url, + model_id=model_id, + tokenizer=tokenizer, + input_requests=input_requests, + 
request_func=request_func, + selected_percentiles=[50, 90, 95, 99], + cache_reset_url=cache_reset_url, + ) + + return result + + +def main(args): + print(args) + random.seed(args.seed) + np.random.seed(args.seed) + + asyncio.run(main_async(args)) + + +if __name__ == "__main__": + parser = argparse.ArgumentParser(description="Sequential benchmark for LLM serving") + parser.add_argument( + "--backend", type=str, default="vllm", help="Backend to use for requests" + ) + parser.add_argument( + "--base-url", + type=str, + default=None, + help="Server base URL (overrides --host and --port)", + ) + parser.add_argument("--host", type=str, default="127.0.0.1") + parser.add_argument("--port", type=int, default=8000) + parser.add_argument( + "--endpoint", type=str, default="/v1/completions", help="API endpoint" + ) + parser.add_argument("--model", type=str, required=True, help="Name of the model") + parser.add_argument( + "--tokenizer", type=str, help="Name of the tokenizer (defaults to model name)" + ) + parser.add_argument( + "--num-requests", type=int, default=100, help="Number of requests to process" + ) + parser.add_argument( + "--input-len", type=int, default=128, help="Input len for generated prompts" + ) + parser.add_argument( + "--output-len", type=int, default=None, help="Override output len for requests" + ) + parser.add_argument("--seed", type=int, default=42) + parser.add_argument( + "--trust-remote-code", + action="store_true", + help="Trust remote code from HuggingFace", + ) + + args = parser.parse_args() + main(args) diff --git a/tests/v1/kv_connector/nixl_integration/run_accuracy_test.sh b/tests/v1/kv_connector/nixl_integration/run_accuracy_test.sh index e90b72a7cf24..2b07e64c3c91 100755 --- a/tests/v1/kv_connector/nixl_integration/run_accuracy_test.sh +++ b/tests/v1/kv_connector/nixl_integration/run_accuracy_test.sh @@ -4,6 +4,7 @@ set -xe # Models to run MODELS=( "Qwen/Qwen3-0.6B" + "deepseek-ai/deepseek-vl2-tiny" ) # Number of prefill and decode instances to create diff --git a/tests/v1/kv_connector/nixl_integration/toy_proxy_server.py b/tests/v1/kv_connector/nixl_integration/toy_proxy_server.py index 13071f581375..9eaffcffe08a 100644 --- a/tests/v1/kv_connector/nixl_integration/toy_proxy_server.py +++ b/tests/v1/kv_connector/nixl_integration/toy_proxy_server.py @@ -1,6 +1,7 @@ # SPDX-License-Identifier: Apache-2.0 import argparse +import asyncio import itertools import os import uuid @@ -8,7 +9,7 @@ import httpx from fastapi import FastAPI, Request -from fastapi.responses import StreamingResponse +from fastapi.responses import JSONResponse, StreamingResponse from vllm.logger import init_logger @@ -195,6 +196,70 @@ async def stream_service_response(client_info: dict, endpoint: str, yield chunk +async def _forward_reset_cache(client_session: httpx.AsyncClient, host: str, + port: int) -> dict: + target_url = f"http://{host}:{port}/reset_prefix_cache" + + try: + response: httpx.Response = await client_session.post(target_url, + timeout=5.0) + + return { + "status_code": response.status_code, + "error_type": None, + "error_message": None, + } + except Exception as e: + logger.error("Exception occurred sending POST to %s: %s - %s", + target_url, e.__class__.__name__, str(e)) + return { + "status_code": None, + "error_type": e.__class__.__name__, + "error_message": str(e), + } + + +@app.post("/reset_prefix_cache") +async def reset_prefix_cache_on_all_servers(request: Request): + """ + Forwards a reset_prefix_cache request to all prefill and decode servers. 
+ """ + tasks = [] + + def add_reset_tasks_for_servers(server_list): + for server_info in server_list: + tasks.append( + _forward_reset_cache(server_info['client'], + server_info['host'], server_info['port'])) + + add_reset_tasks_for_servers(request.app.state.prefill_clients) + add_reset_tasks_for_servers(request.app.state.decode_clients) + + if not tasks: + return JSONResponse(content={ + "message": + "No prefill or decode servers configured to reset." + }, + status_code=200) + + all_results = await asyncio.gather(*tasks) + + num_prefill_servers = len(request.app.state.prefill_clients) + prefill_server_results = all_results[:num_prefill_servers] + decode_server_results = all_results[num_prefill_servers:] + + response_data = { + "message": + "Simple POST /reset_prefix_cache command forwarded to P/D workers.", + "prefill_servers_status": prefill_server_results, + "decode_servers_status": decode_server_results + } + all_downstream_ok = all( + result.get("error_type") is None for result in all_results) + status_code = 200 if all_downstream_ok else 207 # 207 Multi-Status + return JSONResponse(content=response_data, status_code=status_code) + + @app.post("/v1/completions") async def handle_completions(request: Request): try: diff --git a/tests/v1/kv_connector/unit/test_multi_connector.py b/tests/v1/kv_connector/unit/test_multi_connector.py new file mode 100644 index 000000000000..64da0d79bf33 --- /dev/null +++ b/tests/v1/kv_connector/unit/test_multi_connector.py @@ -0,0 +1,241 @@ +# SPDX-License-Identifier: Apache-2.0 +import filecmp +import shutil +import tempfile +from collections import defaultdict +from pathlib import Path + +from vllm import LLM, SamplingParams +from vllm.config import KVTransferConfig, VllmConfig +from vllm.distributed.kv_transfer.kv_connector.factory import ( + KVConnectorFactory) +from vllm.distributed.kv_transfer.kv_connector.v1.shared_storage_connector import ( # noqa + SharedStorageConnector) + +MODEL_NAME = "meta-llama/Llama-3.2-1B-Instruct" + +PROMPT_CONTEXT = "Hi " * 100 +PROMPTS = [ + PROMPT_CONTEXT + "Hello, my name is", + PROMPT_CONTEXT + "The capital of France is", +] + +SAMPLING_PARAMS = SamplingParams(temperature=0, max_tokens=20) + + +class TestSharedStorageConnector(SharedStorageConnector): + + def __init__(self, config: VllmConfig, role): + self.name = config.kv_transfer_config.kv_connector_extra_config["name"] + self._connector = SharedStorageConnector(config, role) + self.call_record: dict[str, int] = defaultdict(int) + # Use a unique temp file per connector + self._event_file = tempfile.gettempdir( + ) + f"/connector_{self.name}_events.log" + # Start with an empty file + with open(self._event_file, "w") as _: + pass + + def __getattribute__(self, name): + if name in ("_connector", "call_record", "name", "_event_file", + "__class__", "__dict__", "__getattribute__", + "__init__"): # avoid recursion + return object.__getattribute__(self, name) + if not hasattr(self._connector, name): + return object.__getattribute__(self, name) + attr = getattr(self._connector, name) + + # Intercept calls to the connector interface and write an event + # for each one to a file, which can be read back in the main test proc. 
+ if callable(attr): + + def wrapper(*args, **kwargs): + self.call_record[name] += 1 + # Log the event as a line to the file + try: + with open(self._event_file, "a") as f: + f.write(name + "\n") + except Exception as e: + print(f"[ERROR] Could not log event {name} " + f"for {self.name}: {e}") + return attr(*args, **kwargs) + + return wrapper + return attr + + +KVConnectorFactory.register_connector("TestSharedStorageConnector", + TestSharedStorageConnector.__module__, + TestSharedStorageConnector.__name__) + + +# Helper function to compare directories recursively +def _compare_directories(dir1: Path, dir2: Path) -> bool: + """Compares two directories recursively for identical content.""" + dcmp = filecmp.dircmp(dir1, dir2) + if dcmp.left_only or dcmp.right_only or dcmp.diff_files: + print(f"Differences found between {dir1} and {dir2}:") + print(f" Left only: {dcmp.left_only}") + print(f" Right only: {dcmp.right_only}") + print(f" Different files: {dcmp.diff_files}") + return False + for sub_dir in dcmp.common_dirs: + if not _compare_directories(dir1 / sub_dir, dir2 / sub_dir): + return False + return True + + +def test_multi_shared_storage_connector_consistency(): + """ + Tests that MultiConnector with two SharedStorageConnectors saves + identical KV cache data to separate storage locations. + """ + storage_1_path = Path("storage_1/") + storage_2_path = Path("storage_2/") + shutil.rmtree(storage_1_path, ignore_errors=True) + shutil.rmtree(storage_2_path, ignore_errors=True) + storage_1_path.mkdir() + storage_2_path.mkdir() + + # Configure MultiConnector with two SharedStorageConnectors + kv_transfer_config = KVTransferConfig( + kv_connector="MultiConnector", + kv_role="kv_both", + kv_connector_extra_config={ + "connectors": [{ + "kv_connector": "TestSharedStorageConnector", + "kv_role": "kv_both", + "kv_connector_extra_config": { + "shared_storage_path": str(storage_1_path), + "name": "storage1", + } + }, { + "kv_connector": "TestSharedStorageConnector", + "kv_role": "kv_both", + "kv_connector_extra_config": { + "shared_storage_path": str(storage_2_path), + "name": "storage2", + } + }] + }, + ) + + llm = LLM( + model=MODEL_NAME, + enforce_eager=True, + gpu_memory_utilization=0.5, + kv_transfer_config=kv_transfer_config, + ) + # Run generation - this should trigger saving KV cache + _ = llm.generate(PROMPTS, SAMPLING_PARAMS) + + # --- Verification --- + + # Check that both storage directories were populated + local_subdirs = list(storage_1_path.iterdir()) + external_subdirs = list(storage_2_path.iterdir()) + + assert len( + local_subdirs + ) > 0, f"Local storage path {storage_1_path} is empty after generation." 
+ assert len(external_subdirs) > 0, ( + f"External storage path {storage_2_path} is empty after generation.") + assert len(local_subdirs) == len(external_subdirs), ( + f"Mismatch in number of cache entries: " + f"Local={len(local_subdirs)}, External={len(external_subdirs)}") + + # The subdirectories should correspond to the prompt hashes + # Since prompts are the same, the hash directories should be the same name + local_subdir_names = sorted([d.name for d in local_subdirs]) + external_subdir_names = sorted([d.name for d in external_subdirs]) + assert local_subdir_names == external_subdir_names, ( + "Cache directory names do not match between local and external storage" + ) + + # Compare the contents of each corresponding cache directory + for subdir_name in local_subdir_names: + print(f"Comparing contents of cache directory: {subdir_name}") + assert _compare_directories(storage_1_path / subdir_name, + storage_2_path / subdir_name), \ + (f"Contents differ for cache directory '{subdir_name}' between " + f"{storage_1_path} and {storage_2_path}") + + events = get_connector_events() + # get_num_new_matched_tokens will be called on each connector in turn. + # neither of them have hits so update_state_after_alloc won't be called. + assert events["storage1"][:3] == [ + 'get_num_new_matched_tokens', 'build_connector_meta', + 'bind_connector_metadata' + ] + assert events["storage2"][:3] == [ + 'get_num_new_matched_tokens', 'build_connector_meta', + 'bind_connector_metadata' + ] + + # Reset prefix cache or else we'll just get the tokens back from there. + llm.reset_prefix_cache() + + # Run generation again - this should trigger loading from the first + # connector. + _ = llm.generate(PROMPTS, SAMPLING_PARAMS) + + events = get_connector_events() + # get_num_new_matched_tokens will return new tokens from the first + # connector so update_state_after_alloc will be called once blocks + # are allocated for the first connector. + # get_num_new_matched_tokens *won't* be called on the second connector + # in this case. + assert events["storage1"][:4] == [ + 'get_num_new_matched_tokens', 'update_state_after_alloc', + 'build_connector_meta', 'bind_connector_metadata' + ] + assert events["storage2"][:2] == [ + 'build_connector_meta', 'bind_connector_metadata' + ] + + # Delete storage1 connector state + shutil.rmtree(storage_1_path) + + # Reset prefix cache or else we'll just get the tokens back from there. + llm.reset_prefix_cache() + + # Run generation again - this should trigger loading from the first + # connector. + _ = llm.generate(PROMPTS, SAMPLING_PARAMS) + + events = get_connector_events() + # get_num_new_matched_tokens will be called for the first connector but it + # won't have a hit so update_state_after_alloc won't be called. + # get_num_new_matched_tokens will also be called on the second connector, + # but it should have a hit so update_state_after_alloc will be called. + assert events["storage1"][:3] == [ + 'get_num_new_matched_tokens', 'build_connector_meta', + 'bind_connector_metadata' + ] + assert events["storage2"][:4] == [ + 'get_num_new_matched_tokens', 'update_state_after_alloc', + 'build_connector_meta', 'bind_connector_metadata' + ] + + # Clean up + shutil.rmtree(storage_1_path) + shutil.rmtree(storage_2_path) + + +def get_connector_events() -> dict[str, list[str]]: + # Read in connector events and reset the files. 
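+    # Each test connector appends one line per intercepted call, and the file
+    # is truncated after reading, so each call to this helper returns only the
+    # events recorded since the previous call.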
+ import glob + event_files = glob.glob(tempfile.gettempdir() + "/connector_*_events.log") + connector_events = {} + for fname in event_files: + name = fname.split("connector_")[1].split("_events.log")[0] + try: + with open(fname, "r+") as f: + connector_events[name] = [ + line.strip() for line in f if line.strip() + ] + f.truncate(0) + except Exception as e: + print(f"[ERROR] Could not read connector events for {name}: {e}") + + return connector_events diff --git a/vllm/distributed/kv_transfer/kv_connector/factory.py b/vllm/distributed/kv_transfer/kv_connector/factory.py index 6766d5a24542..f998f5dd7b15 100644 --- a/vllm/distributed/kv_transfer/kv_connector/factory.py +++ b/vllm/distributed/kv_transfer/kv_connector/factory.py @@ -110,3 +110,8 @@ def create_connector_v1( "NixlConnector", "vllm.distributed.kv_transfer.kv_connector.v1.nixl_connector", "NixlConnector") + +KVConnectorFactory.register_connector( + "MultiConnector", + "vllm.distributed.kv_transfer.kv_connector.v1.multi_connector", + "MultiConnector") diff --git a/vllm/distributed/kv_transfer/kv_connector/v1/base.py b/vllm/distributed/kv_transfer/kv_connector/v1/base.py index 03c99f20e775..ef4460a592bd 100644 --- a/vllm/distributed/kv_transfer/kv_connector/v1/base.py +++ b/vllm/distributed/kv_transfer/kv_connector/v1/base.py @@ -22,7 +22,6 @@ import enum from abc import ABC, abstractmethod -from dataclasses import dataclass from typing import TYPE_CHECKING, Any, Optional import torch @@ -48,7 +47,6 @@ class KVConnectorRole(enum.Enum): WORKER = 1 -@dataclass class KVConnectorMetadata: """ Abstract Metadata used to communicate between the @@ -185,7 +183,8 @@ def get_finished( finished generating tokens. Returns: - ids of requests that have finished asynchronous (recving, sending). + ids of requests that have finished asynchronous transfer, + tuple of (sending/saving ids, recving/loading ids). The finished saves/sends req ids must belong to a set provided in a call to this method (this call or a prior one). """ diff --git a/vllm/distributed/kv_transfer/kv_connector/v1/lmcache_connector.py b/vllm/distributed/kv_transfer/kv_connector/v1/lmcache_connector.py index 2cb68dc1ff67..eff7435c0a0a 100644 --- a/vllm/distributed/kv_transfer/kv_connector/v1/lmcache_connector.py +++ b/vllm/distributed/kv_transfer/kv_connector/v1/lmcache_connector.py @@ -1,5 +1,5 @@ # SPDX-License-Identifier: Apache-2.0 -from typing import TYPE_CHECKING +from typing import TYPE_CHECKING, Any, Optional import torch from lmcache.integration.vllm.vllm_v1_adapter import LMCacheConnectorV1Impl @@ -25,6 +25,9 @@ def __init__(self, vllm_config: "VllmConfig", role: KVConnectorRole): super().__init__(vllm_config=vllm_config, role=role) self._lmcache_engine = LMCacheConnectorV1Impl(vllm_config, role, self) + self.async_save_supported = hasattr(self._lmcache_engine, + "get_finished") + # ============================== # Worker-side methods # ============================== @@ -86,6 +89,14 @@ def wait_for_save(self): """ self._lmcache_engine.wait_for_save() + def get_finished( + self, finished_req_ids: set[str] + ) -> tuple[Optional[set[str]], Optional[set[str]]]: + if not self.async_save_supported: + return None, None + + return self._lmcache_engine.get_finished(finished_req_ids) + # ============================== # Scheduler-side methods # ============================== @@ -104,8 +115,10 @@ def get_num_new_matched_tokens( computed tokens for this request Returns: - the number of tokens that can be loaded from the - external KV cache beyond what is already computed. 
+ * the number of tokens that can be loaded from the + external KV cache beyond what is already computed. + * true if external KV cache tokens will be loaded + asynchronously (between scheduler steps). """ return self._lmcache_engine.get_num_new_matched_tokens( request, num_computed_tokens), False @@ -131,3 +144,23 @@ def build_connector_meta( scheduler_output (SchedulerOutput): the scheduler output object. """ return self._lmcache_engine.build_connector_meta(scheduler_output) + + def request_finished( + self, + request: "Request", + block_ids: list[int], + ) -> tuple[bool, Optional[dict[str, Any]]]: + """ + Called when a request has finished, before its blocks are freed. + + Returns: + * True if the request is being saved/sent asynchronously and blocks + should not be freed until the request_id is returned from + get_finished(). + * Optional KVTransferParams to be included in the request outputs + returned by the engine. + """ + if not self.async_save_supported: + return False, None + + return self._lmcache_engine.request_finished(request, block_ids) diff --git a/vllm/distributed/kv_transfer/kv_connector/v1/multi_connector.py b/vllm/distributed/kv_transfer/kv_connector/v1/multi_connector.py new file mode 100644 index 000000000000..cc61a6e99cc2 --- /dev/null +++ b/vllm/distributed/kv_transfer/kv_connector/v1/multi_connector.py @@ -0,0 +1,181 @@ +# SPDX-License-Identifier: Apache-2.0 +import copy +from dataclasses import dataclass +from typing import TYPE_CHECKING, Any, Optional + +import torch + +from vllm.config import KVTransferConfig, VllmConfig +from vllm.distributed.kv_transfer.kv_connector.factory import ( + KVConnectorFactory) +from vllm.distributed.kv_transfer.kv_connector.v1.base import ( + KVConnectorBase_V1, KVConnectorMetadata, KVConnectorRole) +from vllm.logger import init_logger +from vllm.v1.core.sched.output import SchedulerOutput + +if TYPE_CHECKING: + from vllm.attention.backends.abstract import AttentionMetadata + from vllm.forward_context import ForwardContext + from vllm.v1.core.kv_cache_manager import KVCacheBlocks + from vllm.v1.request import Request + +logger = init_logger(__name__) + + +@dataclass +class MultiKVConnectorMetadata(KVConnectorMetadata): + metadata: tuple[KVConnectorMetadata, ...] + extra_async_saves: Optional[dict[str, int]] = None + + +class MultiConnector(KVConnectorBase_V1): + + def __init__(self, vllm_config: "VllmConfig", role: KVConnectorRole): + super().__init__(vllm_config=vllm_config, role=role) + self._connectors = [] + ktcs = vllm_config.kv_transfer_config.kv_connector_extra_config.get( + "connectors") + assert ktcs is not None + for ktc in ktcs: + temp_config = copy.copy(vllm_config) + temp_config.kv_transfer_config = KVTransferConfig(**ktc) + self._connectors.append( + KVConnectorFactory.create_connector_v1(temp_config, role)) + + # A mapping from request id to the connector that is assigned to it. + self._requests_to_connector: dict[str, KVConnectorBase_V1] = {} + + # Keeps track of *additional* remaining async saves (beyond 1) to be + # finished per request. Not needed for async loads since we only allow + # a single connector to load. + # Propagated from scheduler to worker side via the connector metadata. 
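+        # For example, if three connectors async-save the same request, the
+        # stored count is 2 and get_finished() reports the request as finished
+        # only once all three saves have completed.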
+ self._extra_async_saves: dict[str, int] = {} + + def register_kv_caches(self, kv_caches: dict[str, torch.Tensor]): + for c in self._connectors: + c.register_kv_caches(kv_caches) + + # We must override the base class method here because we need to bind + # the metadata to each connector in the order of the connectors in the + # MultiKVConnectorMetadata. + def bind_connector_metadata( + self, connector_metadata: KVConnectorMetadata) -> None: + assert isinstance(connector_metadata, MultiKVConnectorMetadata) + if connector_metadata.extra_async_saves: + self._extra_async_saves.update( + connector_metadata.extra_async_saves) + for c, cm in zip(self._connectors, connector_metadata.metadata): + c.bind_connector_metadata(cm) + + def clear_connector_metadata(self) -> None: + for c in self._connectors: + c.clear_connector_metadata() + + # ============================== + # Worker-side methods + # ============================== + def start_load_kv(self, forward_context: "ForwardContext", + **kwargs) -> None: + for c in self._connectors: + c.start_load_kv(forward_context, **kwargs) + + def wait_for_layer_load(self, layer_name: str) -> None: + for c in self._connectors: + c.wait_for_layer_load(layer_name) + + def save_kv_layer(self, layer_name: str, kv_layer: torch.Tensor, + attn_metadata: "AttentionMetadata", **kwargs) -> None: + for c in self._connectors: + c.save_kv_layer(layer_name, kv_layer, attn_metadata, **kwargs) + + def wait_for_save(self): + for c in self._connectors: + c.wait_for_save() + + def get_finished( + self, finished_req_ids: set[str] + ) -> tuple[Optional[set[str]], Optional[set[str]]]: + finished_sending: set[str] = set() + finished_recving: set[str] = set() + for c in self._connectors: + sending, recving = c.get_finished(finished_req_ids) + if not recving and not sending: + continue + # Aggregate finished recving request ids. + finished_recving.update(recving or ()) + # Aggregate finished sending request ids - only include + # once we've drained the "extra" count (for cases where + # more than one connector is async-saving the same request). + for req_id in sending or (): + extra_pending = self._extra_async_saves.get(req_id) + if extra_pending is None: + finished_sending.add(req_id) + continue + assert extra_pending > 0 + if extra_pending == 1: + del self._extra_async_saves[req_id] + else: + self._extra_async_saves[req_id] = extra_pending - 1 + + return finished_sending or None, finished_recving or None + + # ============================== + # Scheduler-side methods + # ============================== + def get_num_new_matched_tokens( + self, + request: "Request", + num_computed_tokens: int, + ) -> tuple[int, bool]: + for c in self._connectors: + toks, load_async = c.get_num_new_matched_tokens( + request, num_computed_tokens) + # The first connector that has new matched tokens will be assigned + # to this request. + if toks > 0: + self._requests_to_connector[request.request_id] = c + return toks, load_async + return 0, False + + def update_state_after_alloc(self, request: "Request", + blocks: "KVCacheBlocks", + num_external_tokens: int): + # If the request is not assigned to any connector, we do nothing. + if request.request_id not in self._requests_to_connector: + return + # We assume that the request is assigned to only one connector. 
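+        # pop() rather than a plain lookup so the mapping is cleared once the
+        # allocation for this request has been handled.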
+ c = self._requests_to_connector.pop(request.request_id) + c.update_state_after_alloc(request, blocks, num_external_tokens) + + def build_connector_meta( + self, + scheduler_output: SchedulerOutput) -> MultiKVConnectorMetadata: + metadata = MultiKVConnectorMetadata(metadata=tuple( + c.build_connector_meta(scheduler_output) + for c in self._connectors)) + if self._extra_async_saves: + metadata.extra_async_saves = self._extra_async_saves + self._extra_async_saves = {} + return metadata + + def request_finished( + self, + request: "Request", + blocks: "KVCacheBlocks", + ) -> tuple[bool, Optional[dict[str, Any]]]: + async_saves = 0 + kv_txfer_params = None + for c in self._connectors: + async_save, txfer_params = c.request_finished(request, blocks) + if async_save: + async_saves += 1 + if txfer_params is not None: + if kv_txfer_params is not None: + #TODO we can probably change this to merge the dicts here, + # checking for key clashes. + raise RuntimeError( + "Only one connector can produce KV transfer params") + kv_txfer_params = txfer_params + if async_saves > 1: + self._extra_async_saves[request.request_id] = async_saves - 1 + return async_saves > 0, kv_txfer_params diff --git a/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py b/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py index abd1ea2bea82..e7f1f2a3a6b7 100644 --- a/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py +++ b/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py @@ -208,7 +208,17 @@ def get_num_new_matched_tokens( rounded_num_prompt_tokens = round_down( len(request.prompt_token_ids), self.block_size) count = max(rounded_num_prompt_tokens - num_computed_tokens, 0) - return count, count > 0 + if count > 0: + return count, True + + # NOTE: if count is 0 here, we have less than block_size + # tokens to pull after subtracting the local prefix cache hit. + # The remote only sends fully computed blocks, so there is + # nothing to transfer but we still need to notify the + # prefill worker so that the remote blocks are freed. + if all(p in params for p in ("remote_engine_id", "remote_host", + "remote_port")): + self._reqs_need_recv[request.request_id] = (request, []) # No remote prefill for this request. return 0, False @@ -224,10 +234,6 @@ def update_state_after_alloc(self, request: "Request", num_external_tokens, params) if params is not None and params.get("do_remote_prefill"): - # NOTE(rob): if prompt < block_size, no remote blocks - # since the remote only sends fully computed blocks, so - # skip recving for this request. num_external_tokens - # should be 0 if there are no remote blocks. if params.get("remote_block_ids"): if all(p in params for p in ("remote_engine_id", "remote_host", "remote_port")): @@ -684,7 +690,8 @@ def _read_blocks( # just notify P worker that we have the blocks we need. num_local_blocks = len(local_block_ids) if num_local_blocks == 0: - self.nixl_wrapper.send_notif(dst_engine_id, + agent_name = self._remote_agents[dst_engine_id] + self.nixl_wrapper.send_notif(agent_name, notif_msg=request_id.encode("utf-8")) return diff --git a/vllm/v1/core/sched/scheduler.py b/vllm/v1/core/sched/scheduler.py index f338e4ba1440..a9d85e534115 100644 --- a/vllm/v1/core/sched/scheduler.py +++ b/vllm/v1/core/sched/scheduler.py @@ -345,32 +345,38 @@ def schedule(self) -> SchedulerOutput: skipped_waiting_requests.appendleft(request) continue + num_external_computed_tokens = 0 + load_kv_async = False + # Get already-cached tokens. 
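+            # num_prealloc_computed_tokens > 0 means KV blocks were already
+            # allocated in a prior step while being loaded asynchronously
+            # (P/D), so the prefix-cache and connector lookups are skipped.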
if num_prealloc_computed_tokens == 0: new_computed_blocks, num_native_computed_tokens = \ self.kv_cache_manager.get_computed_blocks( request) + + # Get externally-cached tokens if using a KVConnector. + if self.connector is not None: + num_external_computed_tokens, load_kv_async = ( + self.connector.get_num_new_matched_tokens( + request, num_native_computed_tokens)) + + # Total computed tokens (local + external). + num_computed_tokens = (num_native_computed_tokens + + num_external_computed_tokens) else: # P/D: skip checking prefix cache if loaded from remote kvs. new_computed_blocks = KVCacheBlocks.create_empty() num_native_computed_tokens = 0 - # Get externally-cached tokens if using a KVConnector. - num_external_computed_tokens, load_kv_async = ( - (0, False) if self.connector is None else - self.connector.get_num_new_matched_tokens( - request, num_native_computed_tokens)) - - # Total computed tokens (local + external). - num_computed_tokens = (num_native_computed_tokens + - num_external_computed_tokens + - num_prealloc_computed_tokens) + # Total computed tokens (allocated in prior step). + num_computed_tokens = num_prealloc_computed_tokens encoder_inputs_to_schedule = None new_encoder_budget = encoder_budget # P/D: loading remote KV, do not allocate for new work. if load_kv_async: + assert num_external_computed_tokens > 0 num_new_tokens = 0 # Number of tokens to be scheduled. else: @@ -405,13 +411,21 @@ def schedule(self) -> SchedulerOutput: delay_cache_blocks=load_kv_async, ) if new_blocks is None: + # P/D: if the request is recved on this step, + # then we need to free the kv cache blocks + if num_prealloc_computed_tokens > 0: + assert request.num_computed_tokens != 0 + self.kv_cache_manager.free(request) + request.num_computed_tokens = 0 + # The request cannot be scheduled. break # KVConnector: update internal state after allocation. # This information is used to determine if a load is # needed for this request. - if self.connector is not None: + if num_external_computed_tokens: + assert self.connector is not None self.connector.update_state_after_alloc( request, new_computed_blocks + new_blocks,