diff --git a/bindings/python/py_src/tokenizers/decoders/__init__.pyi b/bindings/python/py_src/tokenizers/decoders/__init__.pyi
index 488dc7708..230f15498 100644
--- a/bindings/python/py_src/tokenizers/decoders/__init__.pyi
+++ b/bindings/python/py_src/tokenizers/decoders/__init__.pyi
@@ -4,7 +4,7 @@ class DecodeStream:
     Class needed for streaming decode
     """

-    def __init__(self, skip_special_tokens):
+    def __init__(self, skip_special_tokens=True):
         pass

 class Decoder:
diff --git a/bindings/python/src/decoders.rs b/bindings/python/src/decoders.rs
index d85289a25..0d90a2376 100644
--- a/bindings/python/src/decoders.rs
+++ b/bindings/python/src/decoders.rs
@@ -1,5 +1,3 @@
-use std::sync::{Arc, RwLock};
-
 use crate::pre_tokenizers::from_string;
 use crate::tokenizer::PyTokenizer;
 use crate::utils::PyPattern;
@@ -8,6 +6,8 @@ use pyo3::prelude::*;
 use pyo3::types::*;
 use serde::de::Error;
 use serde::{Deserialize, Deserializer, Serialize, Serializer};
+use std::collections::HashMap;
+use std::sync::{Arc, RwLock};
 use tk::decoders::bpe::BPEDecoder;
 use tk::decoders::byte_fallback::ByteFallback;
 use tk::decoders::byte_level::ByteLevel;
@@ -603,78 +603,194 @@ impl Decoder for PyDecoderWrapper {
     }
 }

-/// Decoders Module
-#[pymodule]
-pub fn decoders(m: &Bound<'_, PyModule>) -> PyResult<()> {
-    m.add_class::<PyDecoder>()?;
-    m.add_class::<PyByteLevelDec>()?;
-    m.add_class::<PyReplaceDec>()?;
-    m.add_class::<PyWordPieceDec>()?;
-    m.add_class::<PyByteFallbackDec>()?;
-    m.add_class::<PyFuseDec>()?;
-    m.add_class::<PyStrip>()?;
-    m.add_class::<PyMetaspaceDec>()?;
-    m.add_class::<PyBPEDecoder>()?;
-    m.add_class::<PyCTCDecoder>()?;
-    m.add_class::<PySequenceDecoder>()?;
-    m.add_class::<PyDecodeStream>()?;
-    Ok(())
-}
-
 /// Class needed for streaming decode
 ///
 #[pyclass(module = "tokenizers.decoders", name = "DecodeStream")]
-#[derive(Clone)]
+#[derive(Clone, Serialize, Deserialize)]
+#[serde(tag = "type")]
 pub struct PyDecodeStream {
     /// Regular decode option that is kept throughout.
     skip_special_tokens: bool,
-    /// A temporary buffer of the necessary token_ids needed
-    /// to produce valid string chunks.
-    /// This typically contains 3 parts:
-    ///  - read
-    ///  - prefix
-    ///  - rest
-    ///
-    /// Read is the bit necessary to surround the prefix
-    /// so decoding the whole ids produces a valid prefix.
-    /// Prefix is the previously produced string, kept around to trim off of
-    /// the next valid chunk
+    prefills: Vec<Vec<u32>>,
+    states: Vec<PyDecodeState>,
+    hash_to_index: std::collections::HashMap<String, usize>,
+    prefill_hashes: Vec<String>,
+}
+
+#[derive(Clone, Serialize, Deserialize)]
+#[serde(tag = "state")]
+struct PyDecodeState {
     ids: Vec<u32>,
-    /// The previously returned chunk that needs to be discarded from the
-    /// decoding of the current ids to produce the next chunk
     prefix: String,
-    /// The index within the ids corresponding to the prefix so we can drain
-    /// correctly
     prefix_index: usize,
 }

+// Define enum for inputs
+#[derive(FromPyObject)]
+enum StepInput {
+    Map(HashMap<String, Vec<u32>>),
+    Single(Vec<u32>),
+}
+
+// Define enum for outputs
+#[derive(IntoPyObject)]
+enum StepOutput {
+    Map(HashMap<String, Option<String>>),
+    Single(Option<String>),
+}
+
 #[pymethods]
 impl PyDecodeStream {
     #[new]
-    #[pyo3(signature = (skip_special_tokens), text_signature = "(self, skip_special_tokens)")]
-    fn new(skip_special_tokens: bool) -> Self {
+    #[pyo3(signature = (skip_special_tokens=true), text_signature = "(self, skip_special_tokens=True)")]
+    fn new(skip_special_tokens: Option<bool>) -> Self {
+        let skip_special_tokens = skip_special_tokens.unwrap_or(true);
         PyDecodeStream {
             skip_special_tokens,
-            ids: vec![],
-            prefix: "".to_string(),
-            prefix_index: 0,
+            prefills: Vec::new(),
+            states: Vec::new(),
+            hash_to_index: HashMap::new(),
+            prefill_hashes: Vec::new(),
         }
     }

-    #[pyo3(signature = (tokenizer, id), text_signature = "(self, tokenizer, id)")]
-    fn step(&mut self, tokenizer: &PyTokenizer, id: u32) -> PyResult<Option<String>> {
-        ToPyResult(tk::tokenizer::step_decode_stream(
-            &tokenizer.tokenizer,
-            id,
-            self.skip_special_tokens,
-            &mut self.ids,
-            &mut self.prefix,
-            &mut self.prefix_index,
-        ))
-        .into()
+    #[pyo3(signature = (tokenizer, inputs), text_signature = "(self, tokenizer, inputs)")]
+    fn step(&mut self, tokenizer: &PyTokenizer, inputs: StepInput) -> PyResult<StepOutput> {
+        match inputs {
+            StepInput::Map(map) => {
+                let mut output = HashMap::new();
+                for (hash, tokens) in map {
+                    let (prefix_ids, last_token) = match tokens.split_last() {
+                        Some((last, prefix)) => (prefix.to_vec(), *last),
+                        None => {
+                            return Err(PyErr::new::<pyo3::exceptions::PyValueError, _>(
+                                "Empty token sequence",
+                            ))
+                        }
+                    };
+                    let idx = if let Some(&idx) = self.hash_to_index.get(&hash) {
+                        idx
+                    } else {
+                        let state = PyDecodeState {
+                            ids: prefix_ids.clone(),
+                            prefix: String::new(),
+                            prefix_index: prefix_ids.len(),
+                        };
+                        self.states.push(state);
+                        self.prefill_hashes.push(hash.clone());
+                        let new_idx = self.states.len() - 1;
+                        self.hash_to_index.insert(hash.clone(), new_idx);
+                        new_idx
+                    };
+
+                    let state = &mut self.states[idx];
+                    let res = tk::tokenizer::step_decode_stream(
+                        &tokenizer.tokenizer,
+                        last_token,
+                        self.skip_special_tokens,
+                        &mut state.ids,
+                        &mut state.prefix,
+                        &mut state.prefix_index,
+                    )
+                    .map_err(|e| {
+                        PyErr::new::<pyo3::exceptions::PyValueError, _>(format!(
+                            "decode error: {}",
+                            e
+                        ))
+                    })?;
+                    output.insert(hash, res);
+                }
+                Ok(StepOutput::Map(output))
+            }
+
+            StepInput::Single(tokens) => {
+                let (prefix_ids, last_token) = match tokens.split_last() {
+                    Some((last, prefix)) => (prefix.to_vec(), *last),
+                    None => {
+                        return Err(PyErr::new::<pyo3::exceptions::PyValueError, _>(
+                            "Empty token sequence",
+                        ))
+                    }
+                };
+                if self.states.is_empty() {
+                    let state = PyDecodeState {
+                        ids: prefix_ids.clone(),
+                        prefix: String::new(),
+                        prefix_index: 0,
+                    };
+                    self.states.push(state);
+                    self.prefill_hashes.push("".into());
+                    self.hash_to_index.insert("".into(), 0);
+                }
+                let state = &mut self.states[0];
+                let res = tk::tokenizer::step_decode_stream(
+                    &tokenizer.tokenizer,
+                    last_token,
+                    self.skip_special_tokens,
+                    &mut state.ids,
+                    &mut state.prefix,
+                    &mut state.prefix_index,
+                )
+                .map_err(|e| {
+                    PyErr::new::<pyo3::exceptions::PyValueError, _>(format!(
+                        "decode error: {}",
+                        e
+                    ))
+                })?;
+
+                Ok(StepOutput::Single(res))
+            }
+        }
+    }
+
+    #[pyo3(signature = (hashes=None), text_signature = "(self, hashes=None)")]
+    fn finish(&mut self, hashes: Option<Vec<String>>) {
+        if let Some(hashes) = hashes {
+            // Build a set of hashes to remove
+            use std::collections::HashSet;
+            let remove_set: HashSet<String> = hashes.into_iter().collect();
+            for hash in remove_set {
+                if self.hash_to_index.contains_key(&hash) {
+                    // Remove the state and hash
+                    let idx = self.hash_to_index[&hash];
+                    self.states.remove(idx);
+                    self.prefill_hashes.remove(idx);
+                    self.hash_to_index.remove(&hash);
+                }
+            }
+        } else {
+            // If no hashes are provided, we clear all states
+            self.states.clear();
+            self.prefill_hashes.clear();
+            self.hash_to_index.clear();
+        }
+    }
+
+    #[getter]
+    fn prefill_hashes(&self) -> Vec<String> {
+        self.prefill_hashes.clone()
     }
 }

+/// Decoders Module
+#[pymodule]
+pub fn decoders(m: &Bound<'_, PyModule>) -> PyResult<()> {
+    m.add_class::<PyDecoder>()?;
+    m.add_class::<PyByteLevelDec>()?;
+    m.add_class::<PyReplaceDec>()?;
+    m.add_class::<PyWordPieceDec>()?;
+    m.add_class::<PyByteFallbackDec>()?;
+    m.add_class::<PyFuseDec>()?;
+    m.add_class::<PyStrip>()?;
+    m.add_class::<PyMetaspaceDec>()?;
+    m.add_class::<PyBPEDecoder>()?;
+    m.add_class::<PyCTCDecoder>()?;
+    m.add_class::<PySequenceDecoder>()?;
+    m.add_class::<PyDecodeStream>()?;
+    Ok(())
+}
+
 #[cfg(test)]
 mod test {
     use std::sync::{Arc, RwLock};
diff --git a/bindings/python/tests/bindings/test_tokenizer.py b/bindings/python/tests/bindings/test_tokenizer.py
index d50f283e7..de9266970 100644
--- a/bindings/python/tests/bindings/test_tokenizer.py
+++ b/bindings/python/tests/bindings/test_tokenizer.py
@@ -12,6 +12,15 @@
 from tokenizers.decoders import ByteFallback, DecodeStream, Metaspace as DecoderMetaspace


+def prefill_hash(ids):
+    h = 0xCBF29CE484222325
+    prime = 0x100000001B3
+    for i in ids:
+        h ^= i
+        h = (h * prime) & 0xFFFFFFFFFFFFFFFF
+    return f"{h:016x}"
+
+
 from ..utils import bert_files, data_dir, multiprocessing_with_parallelism, roberta_files


@@ -366,10 +375,11 @@ def test_decode(self):

         # Can decode stream
         stream = DecodeStream(skip_special_tokens=False)
-        assert stream.step(tokenizer, 0) == "my"
-        assert stream.step(tokenizer, 1) == " name"
-        assert stream.step(tokenizer, 2) == " is"
-        assert stream.step(tokenizer, 3) == " john"
+        h = prefill_hash([])
+        assert stream.step(tokenizer, {h: [0]})[h] == "my"
+        assert stream.step(tokenizer, {h: [1]})[h] == " name"
+        assert stream.step(tokenizer, {h: [2]})[h] == " is"
+        assert stream.step(tokenizer, {h: [3]})[h] == " john"

     def test_decode_stream(self):
         vocab = [
@@ -381,19 +391,44 @@ def test_decode_stream(self):
         tokenizer = Tokenizer(Unigram(vocab, 0, byte_fallback=True))
         tokenizer.decoder = ByteFallback()
         stream = DecodeStream(skip_special_tokens=False)
-        assert stream.step(tokenizer, 1) == " "
-        assert stream.step(tokenizer, 2) == None
-        assert stream.step(tokenizer, 3) == "é"
+        h = prefill_hash([])
+        assert stream.step(tokenizer, {h: [1]})[h] == " "
+        assert stream.step(tokenizer, {h: [2]})[h] is None
+        assert stream.step(tokenizer, {h: [3]})[h] == "é"
+
+        vocab = [
+            ("<unk>", 0.0),
+            ("▁This", -0.1),
+        ]
+        tokenizer = Tokenizer(Unigram(vocab, 0, byte_fallback=False))
+        tokenizer.decoder = DecoderMetaspace()
+        stream = DecodeStream(skip_special_tokens=False)
+        h = prefill_hash([])
+        assert stream.step(tokenizer, {h: [1]})[h] == "This"
+        assert stream.step(tokenizer, {h: [1]})[h] == " This"

+    def test_decode_stream_prefills(self):
         vocab = [
             ("<unk>", 0.0),
             ("▁This", -0.1),
+            ("▁is", -0.2),
+            ("▁a", -0.3),
("▁test", -0.4), + ("▁sentence", -0.5), + ("<0x20>", -0.6), + ("<0xC3>", -0.6), + ("<0xA9>", -0.6), ] tokenizer = Tokenizer(Unigram(vocab, 0, byte_fallback=False)) tokenizer.decoder = DecoderMetaspace() stream = DecodeStream(skip_special_tokens=False) - assert stream.step(tokenizer, 1) == "This" - assert stream.step(tokenizer, 1) == " This" + assert stream.step(tokenizer, [1]) == "This" + assert stream.step(tokenizer, [1]) == " This" + assert stream.step(tokenizer, {"": [1]}) == {"": " This"} + assert stream.step(tokenizer, {"": [1, 0, 3, 4, 5, 6]}) == {"": "<0x20>"} + stream.finish() + print(stream) + assert stream.step(tokenizer, {"": [6, 7, 8]}) == {"": "ë"} def test_get_vocab(self): tokenizer = Tokenizer(BPE()) diff --git a/tokenizers/src/tokenizer/mod.rs b/tokenizers/src/tokenizer/mod.rs index f4a136091..0c9788fb2 100644 --- a/tokenizers/src/tokenizer/mod.rs +++ b/tokenizers/src/tokenizer/mod.rs @@ -1012,6 +1012,7 @@ where /// assert_eq!(decode_stream.step(0).unwrap(), Some("This".to_string())); /// assert_eq!(decode_stream.step(0).unwrap(), Some(" This".to_string())); /// ``` +#[derive(Clone, Serialize)] pub struct DecodeStream<'tok, M, N, PT, PP, D> { /// A reference to the tokenizer tokenizer: &'tok TokenizerImpl, @@ -1064,9 +1065,9 @@ where /// See [`DecodeStream`] pub fn step(&mut self, id: u32) -> Result> { step_decode_stream( - self.tokenizer, - id, - self.skip_special_tokens, + self.tokenizer, + id, + self.skip_special_tokens, &mut self.ids, &mut self.prefix, &mut self.prefix_index,