
Commit 72eceff

[Bugfix] grammar_bitmask IndexError caused by outdated apply_grammar_bitmask method (#2022)
### What this PR does / why we need it?

Fix #2033. Sync vllm-project/vllm#14702 to solve the `grammar_bitmask` IndexError caused by the outdated `apply_grammar_bitmask` method.

### Does this PR introduce _any_ user-facing change?

No.

### How was this patch tested?

Tested by upstream vLLM.

- vLLM version: v0.10.0
- vLLM main: vllm-project/vllm@6e599ee

Signed-off-by: ApsarasX <[email protected]>
1 parent 75e28d0 commit 72eceff
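
For context, here is a minimal standalone sketch of the failure mode this PR addresses (shapes and indices are invented for illustration): newer vLLM schedulers send `grammar_bitmask` compacted to one row per structured-output request, so the old reorder, which allocated the sorted mask with `np.zeros_like(grammar_bitmask)` and wrote rows at batch indices, can index past the end of the compacted array.

```python
import numpy as np

# Compacted bitmask: one row per structured-output request (here, just one).
# The column width stands in for ceil(vocab_size / 32) and is arbitrary here.
grammar_bitmask = np.zeros((1, 4), dtype=np.int32)

# Suppose the batch holds three requests and the structured-output one sits
# at batch index 2.
batch_index = 2

# Old behavior: allocate the reordered mask with the *compacted* shape, then
# write rows at batch indices, which can exceed the compacted row count.
sorted_bitmask = np.zeros_like(grammar_bitmask)
try:
    sorted_bitmask[batch_index] = grammar_bitmask[0]
except IndexError as exc:
    print(f"old reorder fails: {exc}")
```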

File tree

1 file changed (+41 / -29 lines)
vllm_ascend/worker/model_runner_v1.py

```diff
@@ -1348,40 +1348,52 @@ def apply_grammar_bitmask(
         scheduler_output: "SchedulerOutput",
         logits: torch.Tensor,
     ) -> torch.Tensor:
-        # Serialization of np.ndarray is much more efficient than a tensor,
-        # so we receive it in that format.
         grammar_bitmask = scheduler_output.grammar_bitmask
 
-        # We receive the structured output bitmask from the scheduler, but the
-        # indices of the requests in the batch may not match the indices of
-        # the bitmask since the scheduler doesn't know how the gpu runner is
-        # ordering the requests in the batch. We need to sort the bitmask to
-        # match the order of the requests used here.
+        # We receive the structured output bitmask from the scheduler,
+        # compacted to contain bitmasks only for structured output requests.
+        # The order of the requests in the bitmask is not guaranteed to be the
+        # same as the order of the requests in the gpu runner's batch. We need
+        # to sort the bitmask to match the order of the requests used here.
+
+        # Get the batch indices of the structured output requests.
+        # Keep track of the number of speculative tokens scheduled for every
+        # request in the batch, as the logit indices are offset by this amount.
         struct_out_req_batch_indices: dict[str, int] = {}
-        indices_match = True
-        for req_id in self.input_batch.req_ids:
-            mask_index = scheduler_output.structured_output_request_ids.get(
-                req_id)
-            if mask_index is None:
-                # not a structured output request
-                continue
-            batch_index = self.input_batch.req_id_to_index[req_id]
-            if batch_index != mask_index:
-                indices_match = False
-            struct_out_req_batch_indices[req_id] = batch_index
-
-        if not indices_match:
-            # Sort the bitmask to match the order of the requests
-            sorted_bitmask = np.zeros_like(grammar_bitmask)
-            for req_id, batch_index in struct_out_req_batch_indices.items():
-                orig_index = scheduler_output.structured_output_request_ids[
-                    req_id]
-                sorted_bitmask[batch_index] = grammar_bitmask[orig_index]
-            grammar_bitmask = sorted_bitmask
+        cumulative_offset = 0
+        seq = sorted(self.input_batch.req_id_to_index.items(),
+                     key=lambda x: x[1])
+        for req_id, batch_index in seq:
+            logit_index = batch_index + cumulative_offset
+            cumulative_offset += len(
+                scheduler_output.scheduled_spec_decode_tokens.get(req_id, []))
+            if req_id in scheduler_output.structured_output_request_ids:
+                struct_out_req_batch_indices[req_id] = logit_index
+
+        out_indices = []
+
+        # Reorder the bitmask to match the order of the requests in the batch.
+        sorted_bitmask = np.zeros_like(grammar_bitmask,
+                                       shape=(logits.shape[0],
+                                              grammar_bitmask.shape[1]))
+        cumulative_index = 0
+        seq = sorted(scheduler_output.structured_output_request_ids.items(),
+                     key=lambda x: x[1])
+        for req_id, _ in seq:
+            logit_index = struct_out_req_batch_indices[req_id]
+            num_spec_tokens = len(
+                scheduler_output.scheduled_spec_decode_tokens.get(req_id, []))
+            for i in range(1 + num_spec_tokens):
+                sorted_bitmask[logit_index + i] = \
+                    grammar_bitmask[cumulative_index + i]
+                out_indices.append(logit_index + i)
+            cumulative_index += 1 + num_spec_tokens
+        grammar_bitmask = sorted_bitmask
 
+        # Serialization of np.ndarray is much more efficient than a tensor,
+        # so we receive it in that format.
         grammar_bitmask = torch.from_numpy(grammar_bitmask)
 
-        # TODO: compatibility with spec decode.
         # NOTE:
         # 1. XGrammar bitmask applying only supports CPU and GPU.
         # 2. The logits and bitmask should be on the same device.
@@ -1391,7 +1403,7 @@ def apply_grammar_bitmask(
         xgr.apply_token_bitmask_inplace(
             logits,
             grammar_bitmask,
-            indices=list(struct_out_req_batch_indices.values()),
+            indices=out_indices,
         )
         return logits.to(self.device).to(logits_dtype)
```
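
The reordering above can be exercised in isolation. Below is a standalone sketch of the new two-pass logic (request IDs, spec-token values, and shapes are invented, and the `scheduler_output` / `input_batch` attributes are mimicked with plain dicts): a batch of two requests where the first has two scheduled speculative tokens, so the structured-output request at batch index 1 owns logit row 3.

```python
import numpy as np

req_id_to_index = {"req0": 0, "req1": 1}           # gpu runner's batch order
scheduled_spec_decode_tokens = {"req0": [11, 12]}  # req0 has 2 spec tokens
structured_output_request_ids = {"req1": 0}        # row 0 of compacted mask
grammar_bitmask = np.array([[0b1010]], dtype=np.int32)
num_logit_rows = 4  # rows: req0, req0 spec 1, req0 spec 2, req1

# Pass 1: map each structured request to its logit index, offset by the
# speculative tokens of the requests that precede it in the batch.
struct_out_req_batch_indices: dict[str, int] = {}
cumulative_offset = 0
for req_id, batch_index in sorted(req_id_to_index.items(), key=lambda x: x[1]):
    logit_index = batch_index + cumulative_offset
    cumulative_offset += len(scheduled_spec_decode_tokens.get(req_id, []))
    if req_id in structured_output_request_ids:
        struct_out_req_batch_indices[req_id] = logit_index

# Pass 2: scatter the compacted rows into a mask sized to the logits, and
# record every logit row a bitmask applies to (including spec positions).
out_indices = []
sorted_bitmask = np.zeros((num_logit_rows, grammar_bitmask.shape[1]),
                          dtype=grammar_bitmask.dtype)
cumulative_index = 0
for req_id, _ in sorted(structured_output_request_ids.items(),
                        key=lambda x: x[1]):
    logit_index = struct_out_req_batch_indices[req_id]
    num_spec = len(scheduled_spec_decode_tokens.get(req_id, []))
    for i in range(1 + num_spec):
        sorted_bitmask[logit_index + i] = grammar_bitmask[cumulative_index + i]
        out_indices.append(logit_index + i)
    cumulative_index += 1 + num_spec

print(struct_out_req_batch_indices)  # {'req1': 3}
print(out_indices)                   # [3]
```

Note that `out_indices`, not the stale batch indices, is what the fixed code passes to `xgr.apply_token_bitmask_inplace`, so rows for speculative-token positions are masked as well.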
