Update decode stream api #1780

Open · wants to merge 13 commits into main
2 changes: 1 addition & 1 deletion bindings/python/py_src/tokenizers/decoders/__init__.pyi
@@ -4,7 +4,7 @@ class DecodeStream:
Class needed for streaming decode

"""
def __init__(self, skip_special_tokens):
def __init__(self, skip_special_tokens=True):
pass

class Decoder:
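The stub change above only touches the default value. A minimal sketch of what that means for callers, assuming a tokenizer loaded from the Hub (the model id is illustrative):

```python
from tokenizers import Tokenizer
from tokenizers.decoders import DecodeStream

tokenizer = Tokenizer.from_pretrained("gpt2")  # illustrative; any tokenizer works

# skip_special_tokens now defaults to True, so the bare constructor is enough...
stream = DecodeStream()
# ...and the previous explicit form keeps working.
raw_stream = DecodeStream(skip_special_tokens=False)
```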
220 changes: 168 additions & 52 deletions bindings/python/src/decoders.rs
@@ -1,5 +1,3 @@
use std::sync::{Arc, RwLock};

use crate::pre_tokenizers::from_string;
use crate::tokenizer::PyTokenizer;
use crate::utils::PyPattern;
@@ -8,6 +6,8 @@ use pyo3::prelude::*;
use pyo3::types::*;
use serde::de::Error;
use serde::{Deserialize, Deserializer, Serialize, Serializer};
use std::collections::HashMap;
use std::sync::{Arc, RwLock};
use tk::decoders::bpe::BPEDecoder;
use tk::decoders::byte_fallback::ByteFallback;
use tk::decoders::byte_level::ByteLevel;
@@ -603,78 +603,194 @@ impl Decoder for PyDecoderWrapper {
}
}

/// Decoders Module
#[pymodule]
pub fn decoders(m: &Bound<'_, PyModule>) -> PyResult<()> {
m.add_class::<PyDecoder>()?;
m.add_class::<PyByteLevelDec>()?;
m.add_class::<PyReplaceDec>()?;
m.add_class::<PyWordPieceDec>()?;
m.add_class::<PyByteFallbackDec>()?;
m.add_class::<PyFuseDec>()?;
m.add_class::<PyStrip>()?;
m.add_class::<PyMetaspaceDec>()?;
m.add_class::<PyBPEDecoder>()?;
m.add_class::<PyCTCDecoder>()?;
m.add_class::<PySequenceDecoder>()?;
m.add_class::<PyDecodeStream>()?;
Ok(())
}

/// Class needed for streaming decode
///
#[pyclass(module = "tokenizers.decoders", name = "DecodeStream")]
#[derive(Clone)]
#[derive(Clone, Serialize, Deserialize)]
#[serde(tag = "type")]
pub struct PyDecodeStream {
    /// Regular decode option that is kept throughout.
    skip_special_tokens: bool,
    /// Prefill token ids kept per tracked sequence.
    prefills: Vec<Vec<u32>>,
    /// One streaming-decode state per tracked sequence (see `PyDecodeState`).
    /// Each state keeps the temporary buffer of token_ids needed to produce
    /// valid string chunks, plus the previously produced prefix that is
    /// trimmed off of the next valid chunk.
    states: Vec<PyDecodeState>,
    /// Maps a caller-provided prefill hash to its index in `states` /
    /// `prefill_hashes`.
    hash_to_index: std::collections::HashMap<String, usize>,
    /// The prefill hashes, in insertion order (parallel to `states`).
    prefill_hashes: Vec<String>,
}

#[derive(Clone, Serialize, Deserialize)]
#[serde(tag = "state")]
struct PyDecodeState {
    /// The token ids accumulated so far for this sequence
    ids: Vec<u32>,
/// The previously returned chunk that needs to be discarded from the
/// decoding of the current ids to produce the next chunk
prefix: String,
/// The index within the ids corresponding to the prefix so we can drain
/// correctly
prefix_index: usize,
}

/// Accepted inputs for `DecodeStream.step`: either a map of prefill hash to
/// token ids (one entry per tracked sequence) or a single list of token ids.
#[derive(FromPyObject)]
enum StepInput {
Map(HashMap<String, Vec<u32>>),
Single(Vec<u32>),
}

/// Output of `DecodeStream.step`, mirroring the shape of the input.
#[derive(IntoPyObject)]
enum StepOutput {
Map(HashMap<String, Option<String>>),
Single(Option<String>),
}

#[pymethods]
impl PyDecodeStream {
#[new]
#[pyo3(signature = (skip_special_tokens), text_signature = "(self, skip_special_tokens)")]
fn new(skip_special_tokens: bool) -> Self {
    #[pyo3(signature = (skip_special_tokens=true), text_signature = "(self, skip_special_tokens=True)")]
    fn new(skip_special_tokens: bool) -> Self {
PyDecodeStream {
skip_special_tokens,
ids: vec![],
prefix: "".to_string(),
prefix_index: 0,
prefills: Vec::new(),
states: Vec::new(),
hash_to_index: HashMap::new(),
prefill_hashes: Vec::new(),
}
}

#[pyo3(signature = (tokenizer, id), text_signature = "(self, tokenizer, id)")]
fn step(&mut self, tokenizer: &PyTokenizer, id: u32) -> PyResult<Option<String>> {
ToPyResult(tk::tokenizer::step_decode_stream(
&tokenizer.tokenizer,
id,
self.skip_special_tokens,
&mut self.ids,
&mut self.prefix,
&mut self.prefix_index,
))
.into()
#[pyo3(signature = (tokenizer, inputs), text_signature = "(self, tokenizer, inputs)")]
fn step(&mut self, tokenizer: &PyTokenizer, inputs: StepInput) -> PyResult<StepOutput> {
match inputs {
StepInput::Map(map) => {
let mut output = HashMap::new();
for (hash, tokens) in map {
let (prefix_ids, last_token) = match tokens.split_last() {
Some((last, prefix)) => (prefix.to_vec(), *last),
None => {
return Err(PyErr::new::<pyo3::exceptions::PyValueError, _>(
"Empty token sequence",
))
}
};
let idx = if let Some(&idx) = self.hash_to_index.get(&hash) {
idx
} else {
let state = PyDecodeState {
ids: prefix_ids.clone(),
prefix: String::new(),
prefix_index: prefix_ids.len(),
};
self.states.push(state);
self.prefill_hashes.push(hash.clone());
let new_idx = self.states.len() - 1;
self.hash_to_index.insert(hash.clone(), new_idx);
new_idx
};

let state = &mut self.states[idx];
let res = tk::tokenizer::step_decode_stream(
&tokenizer.tokenizer,
last_token,
self.skip_special_tokens,
&mut state.ids,
&mut state.prefix,
&mut state.prefix_index,
)
.map_err(|e| {
PyErr::new::<pyo3::exceptions::PyRuntimeError, _>(format!(
"decode error: {}",
e
))
})?;
output.insert(hash, res);
}
Ok(StepOutput::Map(output))
}

StepInput::Single(tokens) => {
let (prefix_ids, last_token) = match tokens.split_last() {
Some((last, prefix)) => (prefix.to_vec(), *last),
None => {
return Err(PyErr::new::<pyo3::exceptions::PyValueError, _>(
"Empty token sequence",
))
}
};
if self.states.is_empty() {
let state = PyDecodeState {
ids: prefix_ids.clone(),
prefix: String::new(),
prefix_index: 0,
};
self.states.push(state);
self.prefill_hashes.push("".into());
self.hash_to_index.insert("".into(), 0);
}
let state = &mut self.states[0];
let res = tk::tokenizer::step_decode_stream(
&tokenizer.tokenizer,
last_token,
self.skip_special_tokens,
&mut state.ids,
&mut state.prefix,
&mut state.prefix_index,
)
.map_err(|e| {
PyErr::new::<pyo3::exceptions::PyRuntimeError, _>(format!(
"decode error: {}",
e
))
})?;

Ok(StepOutput::Single(res))
}
}
}

    #[pyo3(signature = (hashes=None), text_signature = "(self, hashes=None)")]
    fn finish(&mut self, hashes: Option<Vec<String>>) {
        if let Some(hashes) = hashes {
            // Build a set of hashes to remove
            use std::collections::HashSet;
            let remove_set: HashSet<String> = hashes.into_iter().collect();
            // Drop the matching states, then rebuild the index map so the
            // surviving indices stay in sync with `states` / `prefill_hashes`.
            let old_hashes = std::mem::take(&mut self.prefill_hashes);
            let old_states = std::mem::take(&mut self.states);
            self.hash_to_index.clear();
            for (hash, state) in old_hashes.into_iter().zip(old_states) {
                if !remove_set.contains(&hash) {
                    self.hash_to_index.insert(hash.clone(), self.states.len());
                    self.prefill_hashes.push(hash);
                    self.states.push(state);
                }
            }
        } else {
            // If no hashes are provided, clear all states
            self.states.clear();
            self.prefill_hashes.clear();
            self.hash_to_index.clear();
        }
    }

#[getter]
fn prefill_hashes(&self) -> Vec<String> {
self.prefill_hashes.clone()
}
}

/// Decoders Module
#[pymodule]
pub fn decoders(m: &Bound<'_, PyModule>) -> PyResult<()> {
m.add_class::<PyDecoder>()?;
m.add_class::<PyByteLevelDec>()?;
m.add_class::<PyReplaceDec>()?;
m.add_class::<PyWordPieceDec>()?;
m.add_class::<PyByteFallbackDec>()?;
m.add_class::<PyFuseDec>()?;
m.add_class::<PyStrip>()?;
m.add_class::<PyMetaspaceDec>()?;
m.add_class::<PyBPEDecoder>()?;
m.add_class::<PyCTCDecoder>()?;
m.add_class::<PySequenceDecoder>()?;
m.add_class::<PyDecodeStream>()?;
Ok(())
}

#[cfg(test)]
mod test {
use std::sync::{Arc, RwLock};
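Taken together, the binding above turns `DecodeStream.step` into a dual-mode API: a plain list of token ids keeps the single-stream behaviour, while a dict keyed by a caller-chosen "prefill hash" tracks one decode state per in-flight sequence, and `finish` releases those states. A hedged sketch of how that could be driven from Python; the model id, keys and token ids below are made up for illustration:

```python
from tokenizers import Tokenizer
from tokenizers.decoders import DecodeStream

tokenizer = Tokenizer.from_pretrained("gpt2")  # illustrative choice
stream = DecodeStream(skip_special_tokens=True)

# Dict form: one entry per in-flight sequence. The first time a key is seen,
# every id before the last one seeds that sequence's state; only the last id
# is decoded. On later steps, passing just the newest id is enough.
chunks = stream.step(tokenizer, {"req-a": [318, 262, 1049], "req-b": [340, 373]})
for key, text in chunks.items():
    if text is not None:  # None means no printable chunk yet (e.g. incomplete UTF-8)
        print(key, repr(text))

# Stepping a single sequence later on.
chunks = stream.step(tokenizer, {"req-a": [628]})

# List form: single-sequence shortcut that returns a plain Optional[str].
piece = DecodeStream().step(tokenizer, [318])

# Drop the state of finished sequences, or clear everything with no argument.
stream.finish(["req-a"])
stream.finish()
```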
53 changes: 44 additions & 9 deletions bindings/python/tests/bindings/test_tokenizer.py
@@ -12,6 +12,15 @@
from tokenizers.decoders import ByteFallback, DecodeStream, Metaspace as DecoderMetaspace


def prefill_hash(ids):
h = 0xCBF29CE484222325
prime = 0x100000001B3
for i in ids:
h ^= i
h = (h * prime) & 0xFFFFFFFFFFFFFFFF
return f"{h:016x}"


from ..utils import bert_files, data_dir, multiprocessing_with_parallelism, roberta_files


@@ -366,10 +375,11 @@ def test_decode(self):

# Can decode stream
stream = DecodeStream(skip_special_tokens=False)
assert stream.step(tokenizer, 0) == "my"
assert stream.step(tokenizer, 1) == " name"
assert stream.step(tokenizer, 2) == " is"
assert stream.step(tokenizer, 3) == " john"
h = prefill_hash([])
assert stream.step(tokenizer, {h: [0]})[h] == "my"
assert stream.step(tokenizer, {h: [1]})[h] == " name"
assert stream.step(tokenizer, {h: [2]})[h] == " is"
assert stream.step(tokenizer, {h: [3]})[h] == " john"

def test_decode_stream(self):
vocab = [
@@ -381,19 +391,44 @@ def test_decode_stream(self):
tokenizer = Tokenizer(Unigram(vocab, 0, byte_fallback=True))
tokenizer.decoder = ByteFallback()
stream = DecodeStream(skip_special_tokens=False)
assert stream.step(tokenizer, 1) == " "
assert stream.step(tokenizer, 2) == None
assert stream.step(tokenizer, 3) == "é"
h = prefill_hash([])
assert stream.step(tokenizer, {h: [1]})[h] == " "
assert stream.step(tokenizer, {h: [2]})[h] is None
assert stream.step(tokenizer, {h: [3]})[h] == "é"

vocab = [
("<unk>", 0.0),
("▁This", -0.1),
]
tokenizer = Tokenizer(Unigram(vocab, 0, byte_fallback=False))
tokenizer.decoder = DecoderMetaspace()
stream = DecodeStream(skip_special_tokens=False)
h = prefill_hash([])
assert stream.step(tokenizer, {h: [1]})[h] == "This"
assert stream.step(tokenizer, {h: [1]})[h] == " This"

def test_decode_stream_prefills(self):
vocab = [
("<unk>", 0.0),
("▁This", -0.1),
("▁is", -0.2),
("▁a", -0.3),
("▁test", -0.4),
("▁sentence", -0.5),
("<0x20>", -0.6),
("<0xC3>", -0.6),
("<0xA9>", -0.6),
]
tokenizer = Tokenizer(Unigram(vocab, 0, byte_fallback=False))
tokenizer.decoder = DecoderMetaspace()
stream = DecodeStream(skip_special_tokens=False)
assert stream.step(tokenizer, 1) == "This"
assert stream.step(tokenizer, 1) == " This"
assert stream.step(tokenizer, [1]) == "This"
assert stream.step(tokenizer, [1]) == " This"
assert stream.step(tokenizer, {"": [1]}) == {"": " This"}
assert stream.step(tokenizer, {"": [1, 0, 3, 4, 5, 6]}) == {"": "<0x20>"}
stream.finish()
print(stream)
assert stream.step(tokenizer, {"": [6, 7, 8]}) == {"": "ë"}

def test_get_vocab(self):
tokenizer = Tokenizer(BPE())
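The `prefill_hash` helper added at the top of this test file is an FNV-1a-style fold over whole token ids (offset basis 0xCBF29CE484222325, prime 0x100000001B3). A short sketch of how a caller might lean on it to key streams; the ids are illustrative and the commented lines assume a `tokenizer` is in scope:

```python
from tokenizers.decoders import DecodeStream

# With no prefill ids the loop never runs, so the key is just the FNV offset basis.
assert prefill_hash([]) == "cbf29ce484222325"

# The key is deterministic, so requests that share a prefill share a decode state.
key = prefill_hash([5, 9, 2])          # illustrative prefill ids
assert key == prefill_hash([5, 9, 2])

stream = DecodeStream()
# chunk = stream.step(tokenizer, {key: [5, 9, 2, 7]})[key]  # first step seeds the state
# chunk = stream.step(tokenizer, {key: [11]})[key]          # later steps pass only the newest id
stream.finish([key])
```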
7 changes: 4 additions & 3 deletions tokenizers/src/tokenizer/mod.rs
@@ -1012,6 +1012,7 @@ where
/// assert_eq!(decode_stream.step(0).unwrap(), Some("This".to_string()));
/// assert_eq!(decode_stream.step(0).unwrap(), Some(" This".to_string()));
/// ```
#[derive(Clone, Serialize)]
pub struct DecodeStream<'tok, M, N, PT, PP, D> {
/// A reference to the tokenizer
tokenizer: &'tok TokenizerImpl<M, N, PT, PP, D>,
@@ -1064,9 +1065,9 @@
/// See [`DecodeStream`]
pub fn step(&mut self, id: u32) -> Result<Option<String>> {
step_decode_stream(
self.tokenizer,
id,
self.skip_special_tokens,
self.tokenizer,
id,
self.skip_special_tokens,
&mut self.ids,
&mut self.prefix,
&mut self.prefix_index,