@@ -1348,40 +1348,52 @@ def apply_grammar_bitmask(
1348
1348
scheduler_output : "SchedulerOutput" ,
1349
1349
logits : torch .Tensor ,
1350
1350
) -> torch .Tensor :
1351
- # Serialization of np.ndarray is much more efficient than a tensor,
1352
- # so we receive it in that format.
1353
1351
grammar_bitmask = scheduler_output .grammar_bitmask
1354
1352
1355
- # We receive the structured output bitmask from the scheduler, but the
1356
- # indices of the requests in the batch may not match the indices of
1357
- # the bitmask since the scheduler doesn't know how the gpu runner is
1358
- # ordering the requests in the batch. We need to sort the bitmask to
1359
- # match the order of the requests used here.
1353
+ # We receive the structured output bitmask from the scheduler,
1354
+ # compacted to contain bitmasks only for structured output requests.
1355
+ # The order of the requests in the bitmask is not guaranteed to be the
1356
+ # same as the order of the requests in the gpu runner's batch. We need
1357
+ # to sort the bitmask to match the order of the requests used here.
1358
+
1359
+ # Get the batch indices of the structured output requests.
1360
+ # Keep track of the number of speculative tokens scheduled for every
1361
+ # request in the batch, as the logit indices are offset by this amount.
1360
1362
struct_out_req_batch_indices : dict [str , int ] = {}
1361
- indices_match = True
1362
- for req_id in self .input_batch .req_ids :
1363
- mask_index = scheduler_output .structured_output_request_ids .get (
1364
- req_id )
1365
- if mask_index is None :
1366
- # not a structured output request
1367
- continue
1368
- batch_index = self .input_batch .req_id_to_index [req_id ]
1369
- if batch_index != mask_index :
1370
- indices_match = False
1371
- struct_out_req_batch_indices [req_id ] = batch_index
1372
-
1373
- if not indices_match :
1374
- # Sort the bitmask to match the order of the requests
1375
- sorted_bitmask = np .zeros_like (grammar_bitmask )
1376
- for req_id , batch_index in struct_out_req_batch_indices .items ():
1377
- orig_index = scheduler_output .structured_output_request_ids [
1378
- req_id ]
1379
- sorted_bitmask [batch_index ] = grammar_bitmask [orig_index ]
1380
- grammar_bitmask = sorted_bitmask
1363
+ cumulative_offset = 0
1364
+ seq = sorted (self .input_batch .req_id_to_index .items (),
1365
+ key = lambda x : x [1 ])
1366
+ for req_id , batch_index in seq :
1367
+ logit_index = batch_index + cumulative_offset
1368
+ cumulative_offset += len (
1369
+ scheduler_output .scheduled_spec_decode_tokens .get (req_id , []))
1370
+ if req_id in scheduler_output .structured_output_request_ids :
1371
+ struct_out_req_batch_indices [req_id ] = logit_index
1372
+
1373
+ out_indices = []
1374
+
1375
+ # Reorder the bitmask to match the order of the requests in the batch.
1376
+ sorted_bitmask = np .zeros_like (grammar_bitmask ,
1377
+ shape = (logits .shape [0 ],
1378
+ grammar_bitmask .shape [1 ]))
1379
+ cumulative_index = 0
1380
+ seq = sorted (scheduler_output .structured_output_request_ids .items (),
1381
+ key = lambda x : x [1 ])
1382
+ for req_id , _ in seq :
1383
+ logit_index = struct_out_req_batch_indices [req_id ]
1384
+ num_spec_tokens = len (
1385
+ scheduler_output .scheduled_spec_decode_tokens .get (req_id , []))
1386
+ for i in range (1 + num_spec_tokens ):
1387
+ sorted_bitmask [logit_index + i ] = \
1388
+ grammar_bitmask [cumulative_index + i ]
1389
+ out_indices .append (logit_index + i )
1390
+ cumulative_index += 1 + num_spec_tokens
1391
+ grammar_bitmask = sorted_bitmask
1381
1392
1393
+ # Serialization of np.ndarray is much more efficient than a tensor,
1394
+ # so we receive it in that format.
1382
1395
grammar_bitmask = torch .from_numpy (grammar_bitmask )
1383
1396
1384
- # TODO: compatibility with spec decode.
1385
1397
# NOTE:
1386
1398
# 1. XGrammar bitmask applying only supports CPU and GPU.
1387
1399
# 2. The logits and bitmask should be on the same device.
@@ -1391,7 +1403,7 @@ def apply_grammar_bitmask(
1391
1403
xgr .apply_token_bitmask_inplace (
1392
1404
logits ,
1393
1405
grammar_bitmask ,
1394
- indices = list ( struct_out_req_batch_indices . values ()) ,
1406
+ indices = out_indices ,
1395
1407
)
1396
1408
return logits .to (self .device ).to (logits_dtype )
1397
1409
0 commit comments