From 5baef7b28a564addf798cc823748521bcfeebf5f Mon Sep 17 00:00:00 2001 From: Garvys Date: Wed, 23 Sep 2020 16:38:04 +0200 Subject: [PATCH 01/12] is_final and is_final_unchecked methods of the Fst trait are now required --- rustfst-cli/src/cmds/compose.rs | 18 +++++------ .../algorithms/supported_algorithms.py | 2 +- rustfst/src/algorithms/closure/closure_fst.rs | 8 +++++ rustfst/src/algorithms/compose/add_on.rs | 8 +++++ rustfst/src/algorithms/compose/compose_fst.rs | 8 +++++ rustfst/src/algorithms/compose/matcher_fst.rs | 8 +++++ rustfst/src/algorithms/concat/concat_fst.rs | 8 +++++ .../algorithms/determinize/determinize_fsa.rs | 8 +++++ .../factor_weight/factor_weight_fst.rs | 8 +++++ .../src/algorithms/lazy/cache/arc_cache.rs | 8 +++++ .../src/algorithms/lazy/cache/cache_status.rs | 30 ++++++++++++++++++- .../src/algorithms/lazy/cache/first_cache.rs | 8 +++++ .../src/algorithms/lazy/cache/fst_cache.rs | 3 ++ .../lazy/cache/simple_hash_map_cache.rs | 17 +++++++++++ .../algorithms/lazy/cache/simple_vec_cache.rs | 19 ++++++++++++ rustfst/src/algorithms/lazy/fst_op.rs | 2 +- rustfst/src/algorithms/lazy/lazy_fst.rs | 10 +++++++ rustfst/src/algorithms/lazy/lazy_fst_2.rs | 13 ++++++++ rustfst/src/algorithms/replace/replace_fst.rs | 8 +++++ .../algorithms/rm_epsilon/rm_epsilon_fst.rs | 8 +++++ rustfst/src/algorithms/union/union_fst.rs | 8 +++++ rustfst/src/fst_impls/arc.rs | 8 +++++ rustfst/src/fst_impls/const_fst/fst.rs | 13 ++++++++ rustfst/src/fst_impls/vector_fst/fst.rs | 13 ++++++++ rustfst/src/fst_traits/fst.rs | 11 ++----- 25 files changed, 234 insertions(+), 21 deletions(-) diff --git a/rustfst-cli/src/cmds/compose.rs b/rustfst-cli/src/cmds/compose.rs index 569b35975..709d7ab11 100644 --- a/rustfst-cli/src/cmds/compose.rs +++ b/rustfst-cli/src/cmds/compose.rs @@ -2,13 +2,6 @@ use std::sync::Arc; use anyhow::Result; -use rustfst::algorithms::compose::{ - compose, ComposeFst, ComposeFstOpOptions, LabelReachableData, MatcherFst, -}; -use rustfst::fst_impls::VectorFst; -use rustfst::semirings::TropicalWeight; - -use crate::binary_fst_algorithm::BinaryFstAlgorithm; use rustfst::algorithms::compose::compose_filters::{ AltSequenceComposeFilterBuilder, ComposeFilterBuilder, }; @@ -20,9 +13,16 @@ use rustfst::algorithms::compose::lookahead_matchers::{ LabelLookAheadMatcher, LookaheadMatcher, MatcherFlagsTrait, }; use rustfst::algorithms::compose::matchers::{MatchType, Matcher, MatcherFlags, SortedMatcher}; -use rustfst::algorithms::lazy::SimpleHashMapCache; +use rustfst::algorithms::compose::{ + compose, ComposeFst, ComposeFstOpOptions, LabelReachableData, MatcherFst, +}; +use rustfst::algorithms::lazy::SimpleVecCache; use rustfst::algorithms::tr_compares::ILabelCompare; use rustfst::algorithms::tr_sort; +use rustfst::fst_impls::VectorFst; +use rustfst::semirings::TropicalWeight; + +use crate::binary_fst_algorithm::BinaryFstAlgorithm; #[derive(Debug, Clone, Copy)] pub enum ComposeType { @@ -136,7 +136,7 @@ impl BinaryFstAlgorithm for ComposeAlgorithm { None, ); - let dyn_fst = ComposeFst::<_, _, SimpleHashMapCache<_>>::new_with_options( + let dyn_fst = ComposeFst::<_, _, SimpleVecCache<_>>::new_with_options( graph1look, fst_2, compose_options, diff --git a/rustfst-python-bench/rustfst_python_bench/algorithms/supported_algorithms.py b/rustfst-python-bench/rustfst_python_bench/algorithms/supported_algorithms.py index 793a4e3ef..b36c9b477 100644 --- a/rustfst-python-bench/rustfst_python_bench/algorithms/supported_algorithms.py +++ b/rustfst-python-bench/rustfst_python_bench/algorithms/supported_algorithms.py @@ -37,5 +37,5 @@ def get(cls, algoname): SupportedAlgorithms.register("reverse", ReverseAlgorithm) # SupportedAlgorithms.register("rmfinalepsilon", RmFinalEpsilonAlgorithm) SupportedAlgorithms.register("shortestpath", ShortestPathAlgorithm) -# SupportedAlgorithms.register("compose", ComposeAlgorithm) +SupportedAlgorithms.register("compose", ComposeAlgorithm) diff --git a/rustfst/src/algorithms/closure/closure_fst.rs b/rustfst/src/algorithms/closure/closure_fst.rs index 46dc116b0..f639d97ac 100644 --- a/rustfst/src/algorithms/closure/closure_fst.rs +++ b/rustfst/src/algorithms/closure/closure_fst.rs @@ -95,6 +95,14 @@ where self.0.num_trs_unchecked(s) } + fn is_final(&self, state_id: usize) -> Result { + self.0.is_final(state_id) + } + + unsafe fn is_final_unchecked(&self, state_id: usize) -> bool { + self.0.is_final_unchecked(state_id) + } + fn get_trs(&self, state_id: usize) -> Result { self.0.get_trs(state_id) } diff --git a/rustfst/src/algorithms/compose/add_on.rs b/rustfst/src/algorithms/compose/add_on.rs index 68bb33995..8d2d4df67 100644 --- a/rustfst/src/algorithms/compose/add_on.rs +++ b/rustfst/src/algorithms/compose/add_on.rs @@ -57,6 +57,14 @@ impl, T> CoreFst for FstAddOn { self.fst.num_trs_unchecked(s) } + fn is_final(&self, state_id: usize) -> Result { + self.fst.is_final(state_id) + } + + unsafe fn is_final_unchecked(&self, state_id: usize) -> bool { + self.fst.is_final_unchecked(state_id) + } + fn get_trs(&self, state_id: usize) -> Result { self.fst.get_trs(state_id) } diff --git a/rustfst/src/algorithms/compose/compose_fst.rs b/rustfst/src/algorithms/compose/compose_fst.rs index 3b9474128..964578ee0 100644 --- a/rustfst/src/algorithms/compose/compose_fst.rs +++ b/rustfst/src/algorithms/compose/compose_fst.rs @@ -129,6 +129,14 @@ where self.0.num_trs_unchecked(s) } + fn is_final(&self, state_id: usize) -> Result { + self.0.is_final(state_id) + } + + unsafe fn is_final_unchecked(&self, state_id: usize) -> bool { + self.0.is_final_unchecked(state_id) + } + fn get_trs(&self, state_id: usize) -> Result { self.0.get_trs(state_id) } diff --git a/rustfst/src/algorithms/compose/matcher_fst.rs b/rustfst/src/algorithms/compose/matcher_fst.rs index f3eb94dc7..8520f2a22 100644 --- a/rustfst/src/algorithms/compose/matcher_fst.rs +++ b/rustfst/src/algorithms/compose/matcher_fst.rs @@ -112,6 +112,14 @@ impl, M, T> CoreFst for MatcherFst { self.fst_add_on.num_trs_unchecked(s) } + fn is_final(&self, state_id: usize) -> Result { + self.fst_add_on.is_final(state_id) + } + + unsafe fn is_final_unchecked(&self, state_id: usize) -> bool { + self.fst_add_on.is_final_unchecked(state_id) + } + fn get_trs(&self, state_id: usize) -> Result { self.fst_add_on.get_trs(state_id) } diff --git a/rustfst/src/algorithms/concat/concat_fst.rs b/rustfst/src/algorithms/concat/concat_fst.rs index 9bc574d5a..4f3c7db16 100644 --- a/rustfst/src/algorithms/concat/concat_fst.rs +++ b/rustfst/src/algorithms/concat/concat_fst.rs @@ -83,6 +83,14 @@ where self.0.num_trs_unchecked(s) } + fn is_final(&self, state_id: usize) -> Result { + self.0.is_final(state_id) + } + + unsafe fn is_final_unchecked(&self, state_id: usize) -> bool { + self.0.is_final_unchecked(state_id) + } + fn get_trs(&self, state_id: usize) -> Result { self.0.get_trs(state_id) } diff --git a/rustfst/src/algorithms/determinize/determinize_fsa.rs b/rustfst/src/algorithms/determinize/determinize_fsa.rs index 39e6254fb..18c884015 100644 --- a/rustfst/src/algorithms/determinize/determinize_fsa.rs +++ b/rustfst/src/algorithms/determinize/determinize_fsa.rs @@ -43,6 +43,14 @@ where self.0.num_trs_unchecked(s) } + fn is_final(&self, state_id: usize) -> Result { + self.0.is_final(state_id) + } + + unsafe fn is_final_unchecked(&self, state_id: usize) -> bool { + self.0.is_final_unchecked(state_id) + } + fn get_trs(&self, state_id: usize) -> Result { self.0.get_trs(state_id) } diff --git a/rustfst/src/algorithms/factor_weight/factor_weight_fst.rs b/rustfst/src/algorithms/factor_weight/factor_weight_fst.rs index 45e30e996..7966ae140 100644 --- a/rustfst/src/algorithms/factor_weight/factor_weight_fst.rs +++ b/rustfst/src/algorithms/factor_weight/factor_weight_fst.rs @@ -50,6 +50,14 @@ where self.0.num_trs_unchecked(s) } + fn is_final(&self, state_id: usize) -> Result { + self.0.is_final(state_id) + } + + unsafe fn is_final_unchecked(&self, state_id: usize) -> bool { + self.0.is_final_unchecked(state_id) + } + fn get_trs(&self, state_id: usize) -> Result { self.0.get_trs(state_id) } diff --git a/rustfst/src/algorithms/lazy/cache/arc_cache.rs b/rustfst/src/algorithms/lazy/cache/arc_cache.rs index 5d3153b48..a4470acad 100644 --- a/rustfst/src/algorithms/lazy/cache/arc_cache.rs +++ b/rustfst/src/algorithms/lazy/cache/arc_cache.rs @@ -52,4 +52,12 @@ impl> FstCache for Arc { fn len_final_weights(&self) -> usize { self.deref().len_final_weights() } + + fn is_final(&self, state_id: usize) -> CacheStatus { + self.deref().is_final(state_id) + } + + unsafe fn is_final_unchecked(&self, state_id: usize) -> bool { + self.deref().is_final_unchecked(state_id) + } } diff --git a/rustfst/src/algorithms/lazy/cache/cache_status.rs b/rustfst/src/algorithms/lazy/cache/cache_status.rs index 7c6ae0892..cf8d00dc5 100644 --- a/rustfst/src/algorithms/lazy/cache/cache_status.rs +++ b/rustfst/src/algorithms/lazy/cache/cache_status.rs @@ -1,4 +1,4 @@ -#[derive(Debug, Clone, Copy)] +#[derive(Debug, Clone, Copy, PartialOrd, PartialEq)] pub enum CacheStatus { NotComputed, Computed(T), @@ -12,10 +12,38 @@ impl CacheStatus { } } + pub fn ok_or(self, err: E) -> Result { + match self { + CacheStatus::Computed(v) => Ok(v), + CacheStatus::NotComputed => Err(err), + } + } + + pub fn ok_or_else E>(self, err: F) -> Result { + match self { + CacheStatus::Computed(v) => Ok(v), + CacheStatus::NotComputed => Err(err()), + } + } + + pub fn unwrap(self) -> T { + match self { + CacheStatus::Computed(e) => e, + CacheStatus::NotComputed => unreachable!(), + } + } + pub fn into_option(self) -> Option { match self { CacheStatus::Computed(e) => Some(e), CacheStatus::NotComputed => None, } } + + pub fn is_computed(&self) -> bool { + match self { + CacheStatus::Computed(_) => true, + CacheStatus::NotComputed => false, + } + } } diff --git a/rustfst/src/algorithms/lazy/cache/first_cache.rs b/rustfst/src/algorithms/lazy/cache/first_cache.rs index b503bfbd2..8d09dd799 100644 --- a/rustfst/src/algorithms/lazy/cache/first_cache.rs +++ b/rustfst/src/algorithms/lazy/cache/first_cache.rs @@ -83,4 +83,12 @@ impl> FstCache for FirstCache { fn len_final_weights(&self) -> usize { self.cache.len_final_weights() } + + fn is_final(&self, state_id: usize) -> CacheStatus { + self.cache.is_final(state_id) + } + + unsafe fn is_final_unchecked(&self, state_id: usize) -> bool { + self.cache.is_final_unchecked(state_id) + } } diff --git a/rustfst/src/algorithms/lazy/cache/fst_cache.rs b/rustfst/src/algorithms/lazy/cache/fst_cache.rs index 179018f29..f73bdcb16 100644 --- a/rustfst/src/algorithms/lazy/cache/fst_cache.rs +++ b/rustfst/src/algorithms/lazy/cache/fst_cache.rs @@ -22,4 +22,7 @@ pub trait FstCache: Debug { fn len_trs(&self) -> usize; fn len_final_weights(&self) -> usize; + + fn is_final(&self, state_id: StateId) -> CacheStatus; + unsafe fn is_final_unchecked(&self, state_id: StateId) -> bool; } diff --git a/rustfst/src/algorithms/lazy/cache/simple_hash_map_cache.rs b/rustfst/src/algorithms/lazy/cache/simple_hash_map_cache.rs index ad3005bdc..acad80284 100644 --- a/rustfst/src/algorithms/lazy/cache/simple_hash_map_cache.rs +++ b/rustfst/src/algorithms/lazy/cache/simple_hash_map_cache.rs @@ -145,4 +145,21 @@ impl FstCache for SimpleHashMapCache { let cached_data = self.final_weights.lock().unwrap(); cached_data.data.len() } + + fn is_final(&self, state_id: usize) -> CacheStatus { + match self.final_weights.lock().unwrap().data.get(&state_id) { + Some(e) => CacheStatus::Computed(e.is_some()), + None => CacheStatus::NotComputed, + } + } + + unsafe fn is_final_unchecked(&self, state_id: usize) -> bool { + self.final_weights + .lock() + .unwrap() + .data + .get(&state_id) + .unwrap() + .is_some() + } } diff --git a/rustfst/src/algorithms/lazy/cache/simple_vec_cache.rs b/rustfst/src/algorithms/lazy/cache/simple_vec_cache.rs index e1c35c5c0..fea3bb68a 100644 --- a/rustfst/src/algorithms/lazy/cache/simple_vec_cache.rs +++ b/rustfst/src/algorithms/lazy/cache/simple_vec_cache.rs @@ -147,4 +147,23 @@ impl FstCache for SimpleVecCache { let cached_data = self.final_weights.lock().unwrap(); cached_data.data.len() } + + fn is_final(&self, state_id: usize) -> CacheStatus { + let cached_data = self.final_weights.lock().unwrap(); + match cached_data.data.get(state_id) { + Some(e) => match e { + CacheStatus::Computed(v) => CacheStatus::Computed(v.is_some()), + CacheStatus::NotComputed => CacheStatus::NotComputed, + }, + None => CacheStatus::NotComputed, + } + } + + unsafe fn is_final_unchecked(&self, state_id: usize) -> bool { + let cached_data = self.final_weights.lock().unwrap(); + match cached_data.data.get_unchecked(state_id) { + CacheStatus::Computed(e) => e.is_some(), + CacheStatus::NotComputed => unreachable!(), + } + } } diff --git a/rustfst/src/algorithms/lazy/fst_op.rs b/rustfst/src/algorithms/lazy/fst_op.rs index 47d23aee9..40641039c 100644 --- a/rustfst/src/algorithms/lazy/fst_op.rs +++ b/rustfst/src/algorithms/lazy/fst_op.rs @@ -9,7 +9,7 @@ use crate::{StateId, TrsVec}; pub trait FstOp: Debug { // was FstImpl fn compute_start(&self) -> Result>; - fn compute_trs(&self, id: usize) -> Result>; + fn compute_trs(&self, id: StateId) -> Result>; fn compute_final_weight(&self, id: StateId) -> Result>; // Computed at construction time diff --git a/rustfst/src/algorithms/lazy/lazy_fst.rs b/rustfst/src/algorithms/lazy/lazy_fst.rs index b5b3befe2..b97272477 100644 --- a/rustfst/src/algorithms/lazy/lazy_fst.rs +++ b/rustfst/src/algorithms/lazy/lazy_fst.rs @@ -67,6 +67,16 @@ impl, Cache: FstCache> CoreFst for LazyFst Result { + self.cache + .is_final(state_id) + .ok_or_else(|| format_err!("Final weight for state {} not computed yet", state_id)) + } + + unsafe fn is_final_unchecked(&self, state_id: usize) -> bool { + self.cache.is_final_unchecked(state_id) + } + fn get_trs(&self, state_id: usize) -> Result { match self.cache.get_trs(state_id) { CacheStatus::Computed(trs) => Ok(trs), diff --git a/rustfst/src/algorithms/lazy/lazy_fst_2.rs b/rustfst/src/algorithms/lazy/lazy_fst_2.rs index 68f0e7a2c..a6bc6f355 100644 --- a/rustfst/src/algorithms/lazy/lazy_fst_2.rs +++ b/rustfst/src/algorithms/lazy/lazy_fst_2.rs @@ -65,6 +65,19 @@ impl, Cache: FstCache> CoreFst for LazyFst2 Result { + match self.cache.is_final(state_id) { + CacheStatus::Computed(e) => Ok(e), + CacheStatus::NotComputed => { + bail!("Final weight for state {} not computed yet", state_id) + } + } + } + + unsafe fn is_final_unchecked(&self, state_id: usize) -> bool { + self.cache.is_final_unchecked(state_id) + } + fn get_trs(&self, state_id: usize) -> Result { match self.cache.get_trs(state_id) { CacheStatus::Computed(trs) => Ok(trs), diff --git a/rustfst/src/algorithms/replace/replace_fst.rs b/rustfst/src/algorithms/replace/replace_fst.rs index 355df5fa1..96e75c87b 100644 --- a/rustfst/src/algorithms/replace/replace_fst.rs +++ b/rustfst/src/algorithms/replace/replace_fst.rs @@ -74,6 +74,14 @@ where self.0.num_trs_unchecked(s) } + fn is_final(&self, state_id: usize) -> Result { + self.0.is_final(state_id) + } + + unsafe fn is_final_unchecked(&self, state_id: usize) -> bool { + self.0.is_final_unchecked(state_id) + } + fn get_trs(&self, state_id: usize) -> Result { self.0.get_trs(state_id) } diff --git a/rustfst/src/algorithms/rm_epsilon/rm_epsilon_fst.rs b/rustfst/src/algorithms/rm_epsilon/rm_epsilon_fst.rs index 52ad4bd5d..11d078315 100644 --- a/rustfst/src/algorithms/rm_epsilon/rm_epsilon_fst.rs +++ b/rustfst/src/algorithms/rm_epsilon/rm_epsilon_fst.rs @@ -60,6 +60,14 @@ where self.0.num_trs_unchecked(s) } + fn is_final(&self, state_id: usize) -> Result { + self.0.is_final(state_id) + } + + unsafe fn is_final_unchecked(&self, state_id: usize) -> bool { + self.0.is_final_unchecked(state_id) + } + fn get_trs(&self, state_id: usize) -> Result { self.0.get_trs(state_id) } diff --git a/rustfst/src/algorithms/union/union_fst.rs b/rustfst/src/algorithms/union/union_fst.rs index 2096db2b3..07e9af039 100644 --- a/rustfst/src/algorithms/union/union_fst.rs +++ b/rustfst/src/algorithms/union/union_fst.rs @@ -84,6 +84,14 @@ where self.0.num_trs_unchecked(s) } + fn is_final(&self, state_id: usize) -> Result { + self.0.is_final(state_id) + } + + unsafe fn is_final_unchecked(&self, state_id: usize) -> bool { + self.0.is_final_unchecked(state_id) + } + fn get_trs(&self, state_id: usize) -> Result { self.0.get_trs(state_id) } diff --git a/rustfst/src/fst_impls/arc.rs b/rustfst/src/fst_impls/arc.rs index c223c61d0..b747adde2 100644 --- a/rustfst/src/fst_impls/arc.rs +++ b/rustfst/src/fst_impls/arc.rs @@ -63,6 +63,14 @@ impl> CoreFst for Arc { self.deref().num_trs_unchecked(s) } + fn is_final(&self, state_id: usize) -> Result { + self.deref().is_final(state_id) + } + + unsafe fn is_final_unchecked(&self, state_id: usize) -> bool { + self.deref().is_final_unchecked(state_id) + } + fn get_trs(&self, state_id: usize) -> Result { self.deref().get_trs(state_id) } diff --git a/rustfst/src/fst_impls/const_fst/fst.rs b/rustfst/src/fst_impls/const_fst/fst.rs index 23ea55ecb..4ece6446c 100644 --- a/rustfst/src/fst_impls/const_fst/fst.rs +++ b/rustfst/src/fst_impls/const_fst/fst.rs @@ -65,6 +65,19 @@ impl CoreFst for ConstFst { self.states.get_unchecked(s).ntrs } + fn is_final(&self, state_id: usize) -> Result { + let s = self + .states + .get(state_id) + .ok_or_else(|| format_err!("State {:?} doesn't exist", state_id))?; + Ok(s.final_weight.is_some()) + } + + unsafe fn is_final_unchecked(&self, state_id: usize) -> bool { + let s = self.states.get_unchecked(state_id); + s.final_weight.is_some() + } + fn get_trs(&self, state_id: usize) -> Result { let state = self .states diff --git a/rustfst/src/fst_impls/vector_fst/fst.rs b/rustfst/src/fst_impls/vector_fst/fst.rs index 8272f2464..5d0fc0c8b 100644 --- a/rustfst/src/fst_impls/vector_fst/fst.rs +++ b/rustfst/src/fst_impls/vector_fst/fst.rs @@ -67,6 +67,19 @@ impl CoreFst for VectorFst { self.states.get_unchecked(s).trs.len() } + fn is_final(&self, state_id: usize) -> Result { + let s = self + .states + .get(state_id) + .ok_or_else(|| format_err!("State {:?} doesn't exist", state_id))?; + Ok(s.final_weight.is_some()) + } + + unsafe fn is_final_unchecked(&self, state_id: usize) -> bool { + let s = self.states.get_unchecked(state_id); + s.final_weight.is_some() + } + fn get_trs(&self, state_id: usize) -> Result { let state = self .states diff --git a/rustfst/src/fst_traits/fst.rs b/rustfst/src/fst_traits/fst.rs index 49b44dddf..ceab6afad 100644 --- a/rustfst/src/fst_traits/fst.rs +++ b/rustfst/src/fst_traits/fst.rs @@ -110,11 +110,7 @@ pub trait CoreFst { /// assert_eq!(fst.is_final(s2).unwrap(), true); /// assert!(fst.is_final(s2 + 1).is_err()); /// ``` - #[inline] - fn is_final(&self, state_id: StateId) -> Result { - let w = self.final_weight(state_id)?; - Ok(w.is_some()) - } + fn is_final(&self, state_id: StateId) -> Result; /// Returns whether or not the state with identifier passed as parameters is a final state. /// @@ -122,10 +118,7 @@ pub trait CoreFst { /// /// Unsafe behaviour if `state` is not present in Fst. /// - #[inline] - unsafe fn is_final_unchecked(&self, state: StateId) -> bool { - self.final_weight_unchecked(state).is_some() - } + unsafe fn is_final_unchecked(&self, state_id: StateId) -> bool; /// Check whether a state is the start state or not. #[inline] From fb8382fc14cfbcbc76aac123b36dbaea8048bb51 Mon Sep 17 00:00:00 2001 From: Garvys Date: Wed, 23 Sep 2020 17:36:24 +0200 Subject: [PATCH 02/12] Add num_input_epsilons_unchecked and num_output_epsilons_unchecked --- rustfst/src/algorithms/closure/closure_fst.rs | 8 ++++++++ rustfst/src/algorithms/compose/add_on.rs | 8 ++++++++ .../alt_sequence_compose_filter.rs | 14 ++++++++------ rustfst/src/algorithms/compose/compose_fst.rs | 8 ++++++++ rustfst/src/algorithms/compose/matcher_fst.rs | 8 ++++++++ rustfst/src/algorithms/concat/concat_fst.rs | 8 ++++++++ .../algorithms/determinize/determinize_fsa.rs | 8 ++++++++ .../factor_weight/factor_weight_fst.rs | 8 ++++++++ rustfst/src/algorithms/lazy/cache/arc_cache.rs | 8 ++++++++ .../lazy/cache/cache_internal_types.rs | 7 +++++++ rustfst/src/algorithms/lazy/cache/first_cache.rs | 8 ++++++++ rustfst/src/algorithms/lazy/cache/fst_cache.rs | 3 +++ .../lazy/cache/simple_hash_map_cache.rs | 11 +++++++++++ .../algorithms/lazy/cache/simple_vec_cache.rs | 10 ++++++++++ rustfst/src/algorithms/lazy/lazy_fst.rs | 8 ++++++++ rustfst/src/algorithms/lazy/lazy_fst_2.rs | 8 ++++++++ rustfst/src/algorithms/replace/replace_fst.rs | 8 ++++++++ .../src/algorithms/rm_epsilon/rm_epsilon_fst.rs | 8 ++++++++ rustfst/src/algorithms/union/union_fst.rs | 8 ++++++++ rustfst/src/fst_impls/arc.rs | 8 ++++++++ rustfst/src/fst_impls/const_fst/fst.rs | 8 ++++++++ rustfst/src/fst_impls/vector_fst/fst.rs | 8 ++++++++ rustfst/src/fst_traits/fst.rs | 16 ++++++++++++++++ 23 files changed, 191 insertions(+), 6 deletions(-) diff --git a/rustfst/src/algorithms/closure/closure_fst.rs b/rustfst/src/algorithms/closure/closure_fst.rs index f639d97ac..a982e632c 100644 --- a/rustfst/src/algorithms/closure/closure_fst.rs +++ b/rustfst/src/algorithms/closure/closure_fst.rs @@ -119,9 +119,17 @@ where self.0.num_input_epsilons(state) } + unsafe fn num_input_epsilons_unchecked(&self, state: usize) -> usize { + self.0.num_input_epsilons_unchecked(state) + } + fn num_output_epsilons(&self, state: usize) -> Result { self.0.num_output_epsilons(state) } + + unsafe fn num_output_epsilons_unchecked(&self, state: usize) -> usize { + self.0.num_output_epsilons_unchecked(state) + } } impl<'a, W, F> StateIterator<'a> for ClosureFst diff --git a/rustfst/src/algorithms/compose/add_on.rs b/rustfst/src/algorithms/compose/add_on.rs index 8d2d4df67..3665f8688 100644 --- a/rustfst/src/algorithms/compose/add_on.rs +++ b/rustfst/src/algorithms/compose/add_on.rs @@ -81,9 +81,17 @@ impl, T> CoreFst for FstAddOn { self.fst.num_input_epsilons(state) } + unsafe fn num_input_epsilons_unchecked(&self, state: usize) -> usize { + self.fst.num_input_epsilons_unchecked(state) + } + fn num_output_epsilons(&self, state: usize) -> Result { self.fst.num_output_epsilons(state) } + + unsafe fn num_output_epsilons_unchecked(&self, state: usize) -> usize { + self.fst.num_output_epsilons_unchecked(state) + } } impl<'a, F: StateIterator<'a>, T> StateIterator<'a> for FstAddOn { diff --git a/rustfst/src/algorithms/compose/compose_filters/alt_sequence_compose_filter.rs b/rustfst/src/algorithms/compose/compose_filters/alt_sequence_compose_filter.rs index 81633eb40..94bf66b48 100644 --- a/rustfst/src/algorithms/compose/compose_filters/alt_sequence_compose_filter.rs +++ b/rustfst/src/algorithms/compose/compose_filters/alt_sequence_compose_filter.rs @@ -93,12 +93,14 @@ impl, M2: Matcher> ComposeFilter self.s2 = s2; self.fs = filter_state.clone(); // TODO: Could probably use unchecked here as the state should exist. - let fst2 = self.fst2(); - let na2 = fst2.num_trs(self.s2)?; - let ne2 = fst2.num_input_epsilons(self.s2)?; - let fin2 = fst2.is_final(self.s2)?; - self.alleps2 = na2 == ne2 && !fin2; - self.noeps2 = ne2 == 0; + unsafe { + let fst2 = self.fst2(); + let na2 = fst2.num_trs_unchecked(self.s2); + let ne2 = fst2.num_input_epsilons_unchecked(self.s2); + let fin2 = fst2.is_final_unchecked(self.s2); + self.alleps2 = na2 == ne2 && !fin2; + self.noeps2 = ne2 == 0; + } } Ok(()) } diff --git a/rustfst/src/algorithms/compose/compose_fst.rs b/rustfst/src/algorithms/compose/compose_fst.rs index 964578ee0..0e0fc6538 100644 --- a/rustfst/src/algorithms/compose/compose_fst.rs +++ b/rustfst/src/algorithms/compose/compose_fst.rs @@ -153,9 +153,17 @@ where self.0.num_input_epsilons(state) } + unsafe fn num_input_epsilons_unchecked(&self, state: usize) -> usize { + self.0.num_input_epsilons_unchecked(state) + } + fn num_output_epsilons(&self, state: usize) -> Result { self.0.num_output_epsilons(state) } + + unsafe fn num_output_epsilons_unchecked(&self, state: usize) -> usize { + self.0.num_output_epsilons_unchecked(state) + } } impl<'a, W, CFB, Cache> StateIterator<'a> for ComposeFst diff --git a/rustfst/src/algorithms/compose/matcher_fst.rs b/rustfst/src/algorithms/compose/matcher_fst.rs index 8520f2a22..761788df9 100644 --- a/rustfst/src/algorithms/compose/matcher_fst.rs +++ b/rustfst/src/algorithms/compose/matcher_fst.rs @@ -136,9 +136,17 @@ impl, M, T> CoreFst for MatcherFst { self.fst_add_on.num_input_epsilons(state) } + unsafe fn num_input_epsilons_unchecked(&self, state: usize) -> usize { + self.fst_add_on.num_input_epsilons_unchecked(state) + } + fn num_output_epsilons(&self, state: usize) -> Result { self.fst_add_on.num_output_epsilons(state) } + + unsafe fn num_output_epsilons_unchecked(&self, state: usize) -> usize { + self.fst_add_on.num_output_epsilons_unchecked(state) + } } impl<'a, W, F: StateIterator<'a>, M, T> StateIterator<'a> for MatcherFst { diff --git a/rustfst/src/algorithms/concat/concat_fst.rs b/rustfst/src/algorithms/concat/concat_fst.rs index 4f3c7db16..e92a8a7ab 100644 --- a/rustfst/src/algorithms/concat/concat_fst.rs +++ b/rustfst/src/algorithms/concat/concat_fst.rs @@ -107,9 +107,17 @@ where self.0.num_input_epsilons(state) } + unsafe fn num_input_epsilons_unchecked(&self, state: usize) -> usize { + self.0.num_input_epsilons_unchecked(state) + } + fn num_output_epsilons(&self, state: usize) -> Result { self.0.num_output_epsilons(state) } + + unsafe fn num_output_epsilons_unchecked(&self, state: usize) -> usize { + self.0.num_output_epsilons_unchecked(state) + } } impl<'a, W, F> StateIterator<'a> for ConcatFst diff --git a/rustfst/src/algorithms/determinize/determinize_fsa.rs b/rustfst/src/algorithms/determinize/determinize_fsa.rs index 18c884015..0eb639ff1 100644 --- a/rustfst/src/algorithms/determinize/determinize_fsa.rs +++ b/rustfst/src/algorithms/determinize/determinize_fsa.rs @@ -67,9 +67,17 @@ where self.0.num_input_epsilons(state) } + unsafe fn num_input_epsilons_unchecked(&self, state: usize) -> usize { + self.0.num_input_epsilons_unchecked(state) + } + fn num_output_epsilons(&self, state: usize) -> Result { self.0.num_output_epsilons(state) } + + unsafe fn num_output_epsilons_unchecked(&self, state: usize) -> usize { + self.0.num_output_epsilons_unchecked(state) + } } impl<'a, W, F, CD> StateIterator<'a> for DeterminizeFsa diff --git a/rustfst/src/algorithms/factor_weight/factor_weight_fst.rs b/rustfst/src/algorithms/factor_weight/factor_weight_fst.rs index 7966ae140..22c47791c 100644 --- a/rustfst/src/algorithms/factor_weight/factor_weight_fst.rs +++ b/rustfst/src/algorithms/factor_weight/factor_weight_fst.rs @@ -74,9 +74,17 @@ where self.0.num_input_epsilons(state) } + unsafe fn num_input_epsilons_unchecked(&self, state: usize) -> usize { + self.0.num_input_epsilons_unchecked(state) + } + fn num_output_epsilons(&self, state: usize) -> Result { self.0.num_output_epsilons(state) } + + unsafe fn num_output_epsilons_unchecked(&self, state: usize) -> usize { + self.0.num_output_epsilons_unchecked(state) + } } impl<'a, W, F, B, FI> StateIterator<'a> for FactorWeightFst diff --git a/rustfst/src/algorithms/lazy/cache/arc_cache.rs b/rustfst/src/algorithms/lazy/cache/arc_cache.rs index a4470acad..0b8e6ad3a 100644 --- a/rustfst/src/algorithms/lazy/cache/arc_cache.rs +++ b/rustfst/src/algorithms/lazy/cache/arc_cache.rs @@ -41,10 +41,18 @@ impl> FstCache for Arc { self.deref().num_input_epsilons(id) } + unsafe fn num_input_epsilons_unchecked(&self, id: usize) -> usize { + self.deref().num_input_epsilons_unchecked(id) + } + fn num_output_epsilons(&self, id: usize) -> Option { self.deref().num_output_epsilons(id) } + unsafe fn num_output_epsilons_unchecked(&self, id: usize) -> usize { + self.deref().num_output_epsilons_unchecked(id) + } + fn len_trs(&self) -> usize { self.deref().len_trs() } diff --git a/rustfst/src/algorithms/lazy/cache/cache_internal_types.rs b/rustfst/src/algorithms/lazy/cache/cache_internal_types.rs index 8f2304c67..632d41b28 100644 --- a/rustfst/src/algorithms/lazy/cache/cache_internal_types.rs +++ b/rustfst/src/algorithms/lazy/cache/cache_internal_types.rs @@ -62,6 +62,13 @@ impl CachedData>> { None => CacheStatus::NotComputed, } } + + pub unsafe fn get_unchecked(&self, idx: usize) -> &T { + match self.data.get_unchecked(idx) { + CacheStatus::Computed(e) => e, + CacheStatus::NotComputed => unreachable!(), + } + } } impl Default for CachedData> { diff --git a/rustfst/src/algorithms/lazy/cache/first_cache.rs b/rustfst/src/algorithms/lazy/cache/first_cache.rs index 8d09dd799..73fa6956b 100644 --- a/rustfst/src/algorithms/lazy/cache/first_cache.rs +++ b/rustfst/src/algorithms/lazy/cache/first_cache.rs @@ -72,10 +72,18 @@ impl> FstCache for FirstCache { self.cache.num_input_epsilons(id) } + unsafe fn num_input_epsilons_unchecked(&self, id: usize) -> usize { + self.cache.num_input_epsilons_unchecked(id) + } + fn num_output_epsilons(&self, id: usize) -> Option { self.cache.num_output_epsilons(id) } + unsafe fn num_output_epsilons_unchecked(&self, id: usize) -> usize { + self.cache.num_output_epsilons_unchecked(id) + } + fn len_trs(&self) -> usize { self.cache.len_trs() } diff --git a/rustfst/src/algorithms/lazy/cache/fst_cache.rs b/rustfst/src/algorithms/lazy/cache/fst_cache.rs index f73bdcb16..25fe164aa 100644 --- a/rustfst/src/algorithms/lazy/cache/fst_cache.rs +++ b/rustfst/src/algorithms/lazy/cache/fst_cache.rs @@ -18,7 +18,10 @@ pub trait FstCache: Debug { fn num_trs(&self, id: StateId) -> Option; fn num_input_epsilons(&self, id: usize) -> Option; + unsafe fn num_input_epsilons_unchecked(&self, id: usize) -> usize; + fn num_output_epsilons(&self, id: usize) -> Option; + unsafe fn num_output_epsilons_unchecked(&self, id: usize) -> usize; fn len_trs(&self) -> usize; fn len_final_weights(&self) -> usize; diff --git a/rustfst/src/algorithms/lazy/cache/simple_hash_map_cache.rs b/rustfst/src/algorithms/lazy/cache/simple_hash_map_cache.rs index acad80284..0bfdd4869 100644 --- a/rustfst/src/algorithms/lazy/cache/simple_hash_map_cache.rs +++ b/rustfst/src/algorithms/lazy/cache/simple_hash_map_cache.rs @@ -5,6 +5,7 @@ use crate::algorithms::lazy::cache::cache_internal_types::{CachedData, StartStat use crate::algorithms::lazy::{CacheStatus, FstCache}; use crate::semirings::Semiring; use crate::{StateId, Trs, TrsVec, EPS_LABEL}; +use unsafe_unwrap::UnsafeUnwrap; #[derive(Debug)] pub struct SimpleHashMapCache { @@ -131,11 +132,21 @@ impl FstCache for SimpleHashMapCache { cached_data.data.get(&id).map(|v| v.niepsilons) } + unsafe fn num_input_epsilons_unchecked(&self, id: usize) -> usize { + let cached_data = self.trs.lock().unwrap(); + cached_data.data.get(&id).unsafe_unwrap().niepsilons + } + fn num_output_epsilons(&self, id: usize) -> Option { let cached_data = self.trs.lock().unwrap(); cached_data.data.get(&id).map(|v| v.noepsilons) } + unsafe fn num_output_epsilons_unchecked(&self, id: usize) -> usize { + let cached_data = self.trs.lock().unwrap(); + cached_data.data.get(&id).unsafe_unwrap().noepsilons + } + fn len_trs(&self) -> usize { let cached_data = self.trs.lock().unwrap(); cached_data.data.len() diff --git a/rustfst/src/algorithms/lazy/cache/simple_vec_cache.rs b/rustfst/src/algorithms/lazy/cache/simple_vec_cache.rs index fea3bb68a..b0b6e504e 100644 --- a/rustfst/src/algorithms/lazy/cache/simple_vec_cache.rs +++ b/rustfst/src/algorithms/lazy/cache/simple_vec_cache.rs @@ -133,11 +133,21 @@ impl FstCache for SimpleVecCache { cached_data.get(id).map(|e| e.niepsilons).into_option() } + unsafe fn num_input_epsilons_unchecked(&self, id: usize) -> usize { + let cached_data = self.trs.lock().unwrap(); + cached_data.get_unchecked(id).niepsilons + } + fn num_output_epsilons(&self, id: usize) -> Option { let cached_data = self.trs.lock().unwrap(); cached_data.get(id).map(|e| e.noepsilons).into_option() } + unsafe fn num_output_epsilons_unchecked(&self, id: usize) -> usize { + let cached_data = self.trs.lock().unwrap(); + cached_data.get_unchecked(id).noepsilons + } + fn len_trs(&self) -> usize { let cached_data = self.trs.lock().unwrap(); cached_data.data.len() diff --git a/rustfst/src/algorithms/lazy/lazy_fst.rs b/rustfst/src/algorithms/lazy/lazy_fst.rs index b97272477..0c1576f93 100644 --- a/rustfst/src/algorithms/lazy/lazy_fst.rs +++ b/rustfst/src/algorithms/lazy/lazy_fst.rs @@ -102,11 +102,19 @@ impl, Cache: FstCache> CoreFst for LazyFst usize { + self.cache.num_input_epsilons_unchecked(state) + } + fn num_output_epsilons(&self, state: usize) -> Result { self.cache .num_output_epsilons(state) .ok_or_else(|| format_err!("State {:?} doesn't exist", state)) } + + unsafe fn num_output_epsilons_unchecked(&self, state: usize) -> usize { + self.cache.num_output_epsilons_unchecked(state) + } } impl<'a, W, Op, Cache> StateIterator<'a> for LazyFst diff --git a/rustfst/src/algorithms/lazy/lazy_fst_2.rs b/rustfst/src/algorithms/lazy/lazy_fst_2.rs index a6bc6f355..dcec09a65 100644 --- a/rustfst/src/algorithms/lazy/lazy_fst_2.rs +++ b/rustfst/src/algorithms/lazy/lazy_fst_2.rs @@ -104,11 +104,19 @@ impl, Cache: FstCache> CoreFst for LazyFst2 usize { + self.cache.num_input_epsilons_unchecked(state) + } + fn num_output_epsilons(&self, state: usize) -> Result { self.cache .num_output_epsilons(state) .ok_or_else(|| format_err!("State {:?} doesn't exist", state)) } + + unsafe fn num_output_epsilons_unchecked(&self, state: usize) -> usize { + self.cache.num_output_epsilons_unchecked(state) + } } impl<'a, W, Op, Cache> StateIterator<'a> for LazyFst2 diff --git a/rustfst/src/algorithms/replace/replace_fst.rs b/rustfst/src/algorithms/replace/replace_fst.rs index 96e75c87b..23e8a930e 100644 --- a/rustfst/src/algorithms/replace/replace_fst.rs +++ b/rustfst/src/algorithms/replace/replace_fst.rs @@ -98,9 +98,17 @@ where self.0.num_input_epsilons(state) } + unsafe fn num_input_epsilons_unchecked(&self, state: usize) -> usize { + self.0.num_input_epsilons_unchecked(state) + } + fn num_output_epsilons(&self, state: usize) -> Result { self.0.num_output_epsilons(state) } + + unsafe fn num_output_epsilons_unchecked(&self, state: usize) -> usize { + self.0.num_output_epsilons_unchecked(state) + } } impl<'a, W, F, B> StateIterator<'a> for ReplaceFst diff --git a/rustfst/src/algorithms/rm_epsilon/rm_epsilon_fst.rs b/rustfst/src/algorithms/rm_epsilon/rm_epsilon_fst.rs index 11d078315..57a299e6b 100644 --- a/rustfst/src/algorithms/rm_epsilon/rm_epsilon_fst.rs +++ b/rustfst/src/algorithms/rm_epsilon/rm_epsilon_fst.rs @@ -84,9 +84,17 @@ where self.0.num_input_epsilons(state) } + unsafe fn num_input_epsilons_unchecked(&self, state: usize) -> usize { + self.0.num_input_epsilons_unchecked(state) + } + fn num_output_epsilons(&self, state: usize) -> Result { self.0.num_output_epsilons(state) } + + unsafe fn num_output_epsilons_unchecked(&self, state: usize) -> usize { + self.0.num_output_epsilons_unchecked(state) + } } impl<'a, W, F, B> StateIterator<'a> for RmEpsilonFst diff --git a/rustfst/src/algorithms/union/union_fst.rs b/rustfst/src/algorithms/union/union_fst.rs index 07e9af039..0de23c759 100644 --- a/rustfst/src/algorithms/union/union_fst.rs +++ b/rustfst/src/algorithms/union/union_fst.rs @@ -108,9 +108,17 @@ where self.0.num_input_epsilons(state) } + unsafe fn num_input_epsilons_unchecked(&self, state: usize) -> usize { + self.0.num_input_epsilons_unchecked(state) + } + fn num_output_epsilons(&self, state: usize) -> Result { self.0.num_output_epsilons(state) } + + unsafe fn num_output_epsilons_unchecked(&self, state: usize) -> usize { + self.0.num_output_epsilons_unchecked(state) + } } impl<'a, W, F> StateIterator<'a> for UnionFst diff --git a/rustfst/src/fst_impls/arc.rs b/rustfst/src/fst_impls/arc.rs index b747adde2..7209dbfda 100644 --- a/rustfst/src/fst_impls/arc.rs +++ b/rustfst/src/fst_impls/arc.rs @@ -87,9 +87,17 @@ impl> CoreFst for Arc { self.deref().num_input_epsilons(state) } + unsafe fn num_input_epsilons_unchecked(&self, state: usize) -> usize { + self.deref().num_input_epsilons_unchecked(state) + } + fn num_output_epsilons(&self, state: usize) -> Result { self.deref().num_output_epsilons(state) } + + unsafe fn num_output_epsilons_unchecked(&self, state: usize) -> usize { + self.deref().num_output_epsilons_unchecked(state) + } } impl<'a, W: Semiring + 'a, F: FstIterator<'a, W>> FstIterator<'a, W> for Arc { diff --git a/rustfst/src/fst_impls/const_fst/fst.rs b/rustfst/src/fst_impls/const_fst/fst.rs index 4ece6446c..2191db373 100644 --- a/rustfst/src/fst_impls/const_fst/fst.rs +++ b/rustfst/src/fst_impls/const_fst/fst.rs @@ -111,6 +111,10 @@ impl CoreFst for ConstFst { .niepsilons) } + unsafe fn num_input_epsilons_unchecked(&self, state: usize) -> usize { + self.states.get_unchecked(state).niepsilons + } + fn num_output_epsilons(&self, state: usize) -> Result { Ok(self .states @@ -118,4 +122,8 @@ impl CoreFst for ConstFst { .ok_or_else(|| format_err!("State {:?} doesn't exist", state))? .noepsilons) } + + unsafe fn num_output_epsilons_unchecked(&self, state: usize) -> usize { + self.states.get_unchecked(state).noepsilons + } } diff --git a/rustfst/src/fst_impls/vector_fst/fst.rs b/rustfst/src/fst_impls/vector_fst/fst.rs index 5d0fc0c8b..268bc7e09 100644 --- a/rustfst/src/fst_impls/vector_fst/fst.rs +++ b/rustfst/src/fst_impls/vector_fst/fst.rs @@ -107,6 +107,10 @@ impl CoreFst for VectorFst { .niepsilons) } + unsafe fn num_input_epsilons_unchecked(&self, state: usize) -> usize { + self.states.get_unchecked(state).niepsilons + } + fn num_output_epsilons(&self, state: usize) -> Result { Ok(self .states @@ -114,4 +118,8 @@ impl CoreFst for VectorFst { .ok_or_else(|| format_err!("State {:?} doesn't exist", state))? .noepsilons) } + + unsafe fn num_output_epsilons_unchecked(&self, state: usize) -> usize { + self.states.get_unchecked(state).noepsilons + } } diff --git a/rustfst/src/fst_traits/fst.rs b/rustfst/src/fst_traits/fst.rs index ceab6afad..16abcd0cb 100644 --- a/rustfst/src/fst_traits/fst.rs +++ b/rustfst/src/fst_traits/fst.rs @@ -187,6 +187,14 @@ pub trait CoreFst { /// ``` fn num_input_epsilons(&self, state: StateId) -> Result; + /// Returns the number of trs with epsilon input labels leaving a state. + /// + /// # Safety + /// + /// Unsafe behaviour if `state` is not present in Fst. + /// + unsafe fn num_input_epsilons_unchecked(&self, state: StateId) -> usize; + /// Returns the number of trs with epsilon output labels leaving a state. /// /// # Example : @@ -210,6 +218,14 @@ pub trait CoreFst { /// assert_eq!(fst.num_output_epsilons(s1).unwrap(), 0); /// ``` fn num_output_epsilons(&self, state: StateId) -> Result; + + /// Returns the number of trs with epsilon output labels leaving a state. + /// + /// # Safety + /// + /// Unsafe behaviour if `state` is not present in Fst. + /// + unsafe fn num_output_epsilons_unchecked(&self, state: StateId) -> usize; } /// Trait defining the minimum interface necessary for a wFST. From 5152540f8b88120aa9e25754bd5daa2d4dce1332 Mon Sep 17 00:00:00 2001 From: Garvys Date: Wed, 23 Sep 2020 17:47:04 +0200 Subject: [PATCH 03/12] num_{in/out}epsilons now return CacheStatus --- rustfst/src/algorithms/lazy/cache/arc_cache.rs | 4 ++-- rustfst/src/algorithms/lazy/cache/first_cache.rs | 4 ++-- rustfst/src/algorithms/lazy/cache/fst_cache.rs | 4 ++-- .../algorithms/lazy/cache/simple_hash_map_cache.rs | 14 ++++++++++---- .../src/algorithms/lazy/cache/simple_vec_cache.rs | 8 ++++---- 5 files changed, 20 insertions(+), 14 deletions(-) diff --git a/rustfst/src/algorithms/lazy/cache/arc_cache.rs b/rustfst/src/algorithms/lazy/cache/arc_cache.rs index 0b8e6ad3a..97a2a1fda 100644 --- a/rustfst/src/algorithms/lazy/cache/arc_cache.rs +++ b/rustfst/src/algorithms/lazy/cache/arc_cache.rs @@ -37,7 +37,7 @@ impl> FstCache for Arc { self.deref().num_trs(id) } - fn num_input_epsilons(&self, id: usize) -> Option { + fn num_input_epsilons(&self, id: usize) -> CacheStatus { self.deref().num_input_epsilons(id) } @@ -45,7 +45,7 @@ impl> FstCache for Arc { self.deref().num_input_epsilons_unchecked(id) } - fn num_output_epsilons(&self, id: usize) -> Option { + fn num_output_epsilons(&self, id: usize) -> CacheStatus { self.deref().num_output_epsilons(id) } diff --git a/rustfst/src/algorithms/lazy/cache/first_cache.rs b/rustfst/src/algorithms/lazy/cache/first_cache.rs index 73fa6956b..b8db30104 100644 --- a/rustfst/src/algorithms/lazy/cache/first_cache.rs +++ b/rustfst/src/algorithms/lazy/cache/first_cache.rs @@ -68,7 +68,7 @@ impl> FstCache for FirstCache { self.cache.num_trs(id) } - fn num_input_epsilons(&self, id: usize) -> Option { + fn num_input_epsilons(&self, id: usize) -> CacheStatus { self.cache.num_input_epsilons(id) } @@ -76,7 +76,7 @@ impl> FstCache for FirstCache { self.cache.num_input_epsilons_unchecked(id) } - fn num_output_epsilons(&self, id: usize) -> Option { + fn num_output_epsilons(&self, id: usize) -> CacheStatus { self.cache.num_output_epsilons(id) } diff --git a/rustfst/src/algorithms/lazy/cache/fst_cache.rs b/rustfst/src/algorithms/lazy/cache/fst_cache.rs index 25fe164aa..cc1cccee6 100644 --- a/rustfst/src/algorithms/lazy/cache/fst_cache.rs +++ b/rustfst/src/algorithms/lazy/cache/fst_cache.rs @@ -17,10 +17,10 @@ pub trait FstCache: Debug { fn num_known_states(&self) -> usize; fn num_trs(&self, id: StateId) -> Option; - fn num_input_epsilons(&self, id: usize) -> Option; + fn num_input_epsilons(&self, id: usize) -> CacheStatus; unsafe fn num_input_epsilons_unchecked(&self, id: usize) -> usize; - fn num_output_epsilons(&self, id: usize) -> Option; + fn num_output_epsilons(&self, id: usize) -> CacheStatus; unsafe fn num_output_epsilons_unchecked(&self, id: usize) -> usize; fn len_trs(&self) -> usize; diff --git a/rustfst/src/algorithms/lazy/cache/simple_hash_map_cache.rs b/rustfst/src/algorithms/lazy/cache/simple_hash_map_cache.rs index 0bfdd4869..48bce01da 100644 --- a/rustfst/src/algorithms/lazy/cache/simple_hash_map_cache.rs +++ b/rustfst/src/algorithms/lazy/cache/simple_hash_map_cache.rs @@ -127,9 +127,12 @@ impl FstCache for SimpleHashMapCache { cached_data.data.get(&id).map(|v| v.trs.len()) } - fn num_input_epsilons(&self, id: usize) -> Option { + fn num_input_epsilons(&self, id: usize) -> CacheStatus { let cached_data = self.trs.lock().unwrap(); - cached_data.data.get(&id).map(|v| v.niepsilons) + match cached_data.data.get(&id) { + Some(e) => CacheStatus::Computed(e.niepsilons), + None => CacheStatus::NotComputed, + } } unsafe fn num_input_epsilons_unchecked(&self, id: usize) -> usize { @@ -137,9 +140,12 @@ impl FstCache for SimpleHashMapCache { cached_data.data.get(&id).unsafe_unwrap().niepsilons } - fn num_output_epsilons(&self, id: usize) -> Option { + fn num_output_epsilons(&self, id: usize) -> CacheStatus { let cached_data = self.trs.lock().unwrap(); - cached_data.data.get(&id).map(|v| v.noepsilons) + match cached_data.data.get(&id) { + Some(e) => CacheStatus::Computed(e.noepsilons), + None => CacheStatus::NotComputed, + } } unsafe fn num_output_epsilons_unchecked(&self, id: usize) -> usize { diff --git a/rustfst/src/algorithms/lazy/cache/simple_vec_cache.rs b/rustfst/src/algorithms/lazy/cache/simple_vec_cache.rs index b0b6e504e..85f433d2c 100644 --- a/rustfst/src/algorithms/lazy/cache/simple_vec_cache.rs +++ b/rustfst/src/algorithms/lazy/cache/simple_vec_cache.rs @@ -128,9 +128,9 @@ impl FstCache for SimpleVecCache { cached_data.get(id).map(|e| e.trs.len()).into_option() } - fn num_input_epsilons(&self, id: usize) -> Option { + fn num_input_epsilons(&self, id: usize) -> CacheStatus { let cached_data = self.trs.lock().unwrap(); - cached_data.get(id).map(|e| e.niepsilons).into_option() + cached_data.get(id).map(|e| e.niepsilons) } unsafe fn num_input_epsilons_unchecked(&self, id: usize) -> usize { @@ -138,9 +138,9 @@ impl FstCache for SimpleVecCache { cached_data.get_unchecked(id).niepsilons } - fn num_output_epsilons(&self, id: usize) -> Option { + fn num_output_epsilons(&self, id: usize) -> CacheStatus { let cached_data = self.trs.lock().unwrap(); - cached_data.get(id).map(|e| e.noepsilons).into_option() + cached_data.get(id).map(|e| e.noepsilons) } unsafe fn num_output_epsilons_unchecked(&self, id: usize) -> usize { From 6cb4cf314193ceb670884e9e7cd8f772db7b1819 Mon Sep 17 00:00:00 2001 From: Garvys Date: Wed, 23 Sep 2020 17:55:08 +0200 Subject: [PATCH 04/12] Fix ci --- .../rustfst_python_bench/algorithms/supported_algorithms.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/rustfst-python-bench/rustfst_python_bench/algorithms/supported_algorithms.py b/rustfst-python-bench/rustfst_python_bench/algorithms/supported_algorithms.py index b36c9b477..793a4e3ef 100644 --- a/rustfst-python-bench/rustfst_python_bench/algorithms/supported_algorithms.py +++ b/rustfst-python-bench/rustfst_python_bench/algorithms/supported_algorithms.py @@ -37,5 +37,5 @@ def get(cls, algoname): SupportedAlgorithms.register("reverse", ReverseAlgorithm) # SupportedAlgorithms.register("rmfinalepsilon", RmFinalEpsilonAlgorithm) SupportedAlgorithms.register("shortestpath", ShortestPathAlgorithm) -SupportedAlgorithms.register("compose", ComposeAlgorithm) +# SupportedAlgorithms.register("compose", ComposeAlgorithm) From 23a749839b92d930795f313b135409a5cc79dfb0 Mon Sep 17 00:00:00 2001 From: Garvys Date: Wed, 23 Sep 2020 18:11:36 +0200 Subject: [PATCH 05/12] Move to unchecked for AltSequenceComposeFilter's and SequenceComposeFilter's set_state --- .../alt_sequence_compose_filter.rs | 1 - .../compose_filters/sequence_compose_filter.rs | 15 ++++++++------- 2 files changed, 8 insertions(+), 8 deletions(-) diff --git a/rustfst/src/algorithms/compose/compose_filters/alt_sequence_compose_filter.rs b/rustfst/src/algorithms/compose/compose_filters/alt_sequence_compose_filter.rs index 94bf66b48..29a0b75ac 100644 --- a/rustfst/src/algorithms/compose/compose_filters/alt_sequence_compose_filter.rs +++ b/rustfst/src/algorithms/compose/compose_filters/alt_sequence_compose_filter.rs @@ -92,7 +92,6 @@ impl, M2: Matcher> ComposeFilter self.s1 = s1; self.s2 = s2; self.fs = filter_state.clone(); - // TODO: Could probably use unchecked here as the state should exist. unsafe { let fst2 = self.fst2(); let na2 = fst2.num_trs_unchecked(self.s2); diff --git a/rustfst/src/algorithms/compose/compose_filters/sequence_compose_filter.rs b/rustfst/src/algorithms/compose/compose_filters/sequence_compose_filter.rs index 61dd30345..924a97bd5 100644 --- a/rustfst/src/algorithms/compose/compose_filters/sequence_compose_filter.rs +++ b/rustfst/src/algorithms/compose/compose_filters/sequence_compose_filter.rs @@ -90,13 +90,14 @@ impl, M2: Matcher> ComposeFilter self.s1 = s1; self.s2 = s2; self.fs = filter_state.clone(); - // TODO: Could probably use unchecked here as the state should exist. - let fst1 = self.fst1(); - let na1 = fst1.num_trs(self.s1)?; - let ne1 = fst1.num_output_epsilons(self.s1)?; - let fin1 = fst1.is_final(self.s1)?; - self.alleps1 = na1 == ne1 && !fin1; - self.noeps1 = ne1 == 0; + unsafe { + let fst1 = self.fst1(); + let na1 = fst1.num_trs_unchecked(self.s1)?; + let ne1 = fst1.num_output_epsilons_unchecked(self.s1)?; + let fin1 = fst1.is_final_unchecked(self.s1)?; + self.alleps1 = na1 == ne1 && !fin1; + self.noeps1 = ne1 == 0; + } } Ok(()) } From ca82fe62a7e4e55afb0d3030f8b39af9b7ef8fdb Mon Sep 17 00:00:00 2001 From: Garvys Date: Wed, 23 Sep 2020 18:18:55 +0200 Subject: [PATCH 06/12] Move to unchecked for all compose filter --- .../compose_filters/match_compose_filter.rs | 26 ++++++++++--------- .../sequence_compose_filter.rs | 6 ++--- 2 files changed, 17 insertions(+), 15 deletions(-) diff --git a/rustfst/src/algorithms/compose/compose_filters/match_compose_filter.rs b/rustfst/src/algorithms/compose/compose_filters/match_compose_filter.rs index 81ba7b40b..f426d19eb 100644 --- a/rustfst/src/algorithms/compose/compose_filters/match_compose_filter.rs +++ b/rustfst/src/algorithms/compose/compose_filters/match_compose_filter.rs @@ -96,23 +96,25 @@ impl, M2: Matcher> ComposeFilter self.s2 = s2; self.fs = filter_state.clone(); - let fst1 = self.fst1(); - let fst2 = self.fst2(); + unsafe { + let fst1 = self.fst1(); + let fst2 = self.fst2(); - let na1 = fst1.num_trs(s1)?; - let na2 = fst2.num_trs(s2)?; + let na1 = fst1.num_trs_unchecked(s1); + let na2 = fst2.num_trs_unchecked(s2); - let ne1 = fst1.num_output_epsilons(s1)?; - let ne2 = fst2.num_input_epsilons(s2)?; + let ne1 = fst1.num_output_epsilons_unchecked(s1); + let ne2 = fst2.num_input_epsilons_unchecked(s2); - let f1 = fst1.is_final(s1)?; - let f2 = fst2.is_final(s2)?; + let f1 = fst1.is_final_unchecked(s1); + let f2 = fst2.is_final_unchecked(s2); - self.alleps1 = na1 == ne1 && !f1; - self.alleps2 = na2 == ne2 && !f2; + self.alleps1 = na1 == ne1 && !f1; + self.alleps2 = na2 == ne2 && !f2; - self.noeps1 = ne1 == 0; - self.noeps2 = ne2 == 0; + self.noeps1 = ne1 == 0; + self.noeps2 = ne2 == 0; + } } Ok(()) } diff --git a/rustfst/src/algorithms/compose/compose_filters/sequence_compose_filter.rs b/rustfst/src/algorithms/compose/compose_filters/sequence_compose_filter.rs index 924a97bd5..4935e22f4 100644 --- a/rustfst/src/algorithms/compose/compose_filters/sequence_compose_filter.rs +++ b/rustfst/src/algorithms/compose/compose_filters/sequence_compose_filter.rs @@ -92,9 +92,9 @@ impl, M2: Matcher> ComposeFilter self.fs = filter_state.clone(); unsafe { let fst1 = self.fst1(); - let na1 = fst1.num_trs_unchecked(self.s1)?; - let ne1 = fst1.num_output_epsilons_unchecked(self.s1)?; - let fin1 = fst1.is_final_unchecked(self.s1)?; + let na1 = fst1.num_trs_unchecked(self.s1); + let ne1 = fst1.num_output_epsilons_unchecked(self.s1); + let fin1 = fst1.is_final_unchecked(self.s1); self.alleps1 = na1 == ne1 && !fin1; self.noeps1 = ne1 == 0; } From cafbfab5de28062f88c5d2fdd4d2ba52a9f5ddab Mon Sep 17 00:00:00 2001 From: Garvys Date: Thu, 24 Sep 2020 10:58:45 +0200 Subject: [PATCH 07/12] Move to unchecked in PushLabelsComposeFilter --- .../lookahead_filters/push_labels_compose_filter.rs | 12 +++++++----- rustfst/src/algorithms/lazy/lazy_fst.rs | 4 ++-- 2 files changed, 9 insertions(+), 7 deletions(-) diff --git a/rustfst/src/algorithms/compose/lookahead_filters/push_labels_compose_filter.rs b/rustfst/src/algorithms/compose/lookahead_filters/push_labels_compose_filter.rs index e1e8e14ba..c3acf8e81 100644 --- a/rustfst/src/algorithms/compose/lookahead_filters/push_labels_compose_filter.rs +++ b/rustfst/src/algorithms/compose/lookahead_filters/push_labels_compose_filter.rs @@ -137,11 +137,13 @@ where { return Ok(()); } - self.ntrsa = if self.lookahead_output() { - self.filter.fst1().num_trs(s1)? - } else { - self.filter.fst2().num_trs(s2)? - }; + unsafe { + self.ntrsa = if self.lookahead_output() { + self.filter.fst1().num_trs_unchecked(s1) + } else { + self.filter.fst2().num_trs_unchecked(s2) + }; + } let fs2 = filter_state.state2(); let flabel = fs2.state(); self.matcher1.clear_multi_eps_labels(); diff --git a/rustfst/src/algorithms/lazy/lazy_fst.rs b/rustfst/src/algorithms/lazy/lazy_fst.rs index 0c1576f93..001d6bcee 100644 --- a/rustfst/src/algorithms/lazy/lazy_fst.rs +++ b/rustfst/src/algorithms/lazy/lazy_fst.rs @@ -267,8 +267,8 @@ where } } unsafe { fst_out.set_trs_unchecked(s, trs_owner.trs().to_vec()) }; - if let Some(f_w) = self.final_weight(s)? { - fst_out.set_final(s, f_w)?; + if let Some(f_w) = unsafe {self.final_weight_unchecked(s)} { + unsafe {fst_out.set_final_unchecked(s, f_w)}; } } fst_out.set_properties(self.properties()); From 89d58bf8d0bf866a888aef4b62da07db03f724a4 Mon Sep 17 00:00:00 2001 From: Garvys Date: Thu, 24 Sep 2020 16:00:17 +0200 Subject: [PATCH 08/12] Avoid move when creating Arc in ordered_expand --- rustfst/src/algorithms/compose/compose_fst_op.rs | 14 ++++++++------ 1 file changed, 8 insertions(+), 6 deletions(-) diff --git a/rustfst/src/algorithms/compose/compose_fst_op.rs b/rustfst/src/algorithms/compose/compose_fst_op.rs index a9617141d..089c8cc5e 100644 --- a/rustfst/src/algorithms/compose/compose_fst_op.rs +++ b/rustfst/src/algorithms/compose/compose_fst_op.rs @@ -137,7 +137,9 @@ impl> ComposeFstOp { } else { Tr::new(NO_LABEL, EPS_LABEL, W::one(), sb) }; - let mut trs = vec![]; + let mut trs_vec = TrsVec(Arc::new(vec![])); + // Safe as there is no other Arc to the same allocation. + let trs = Arc::get_mut(&mut trs_vec.0).unwrap(); match selector { Selector::Fst1Matcher2 => { @@ -147,10 +149,10 @@ impl> ComposeFstOp { match_input, &mut compose_filter, selector, - &mut trs, + trs, )?; for tr in self.fst1.get_trs(sb)?.trs() { - self.match_tr(sa, tr, match_input, &mut compose_filter, selector, &mut trs)?; + self.match_tr(sa, tr, match_input, &mut compose_filter, selector, trs)?; } } Selector::Fst2Matcher1 => { @@ -160,14 +162,14 @@ impl> ComposeFstOp { match_input, &mut compose_filter, selector, - &mut trs, + trs, )?; for tr in self.fst2.get_trs(sb)?.trs() { - self.match_tr(sa, tr, match_input, &mut compose_filter, selector, &mut trs)?; + self.match_tr(sa, tr, match_input, &mut compose_filter, selector, trs)?; } } } - Ok(TrsVec(Arc::new(trs))) + Ok(trs_vec) } fn add_tr( From b8b79f42238c97297c9d6110c6e06438f65872ed Mon Sep 17 00:00:00 2001 From: Garvys Date: Thu, 24 Sep 2020 16:52:01 +0200 Subject: [PATCH 09/12] Factorize a bit flags computation in PushLabelsComposeFilter --- .../push_labels_compose_filter.rs | 28 ++++++++++++------- .../compose/matchers/multi_eps_matcher.rs | 2 +- rustfst/src/algorithms/lazy/lazy_fst.rs | 4 +-- 3 files changed, 21 insertions(+), 13 deletions(-) diff --git a/rustfst/src/algorithms/compose/lookahead_filters/push_labels_compose_filter.rs b/rustfst/src/algorithms/compose/lookahead_filters/push_labels_compose_filter.rs index c3acf8e81..465e804aa 100644 --- a/rustfst/src/algorithms/compose/lookahead_filters/push_labels_compose_filter.rs +++ b/rustfst/src/algorithms/compose/lookahead_filters/push_labels_compose_filter.rs @@ -47,6 +47,8 @@ where filter_builder: CFB, w: PhantomData, smt: PhantomData, + flags_matcher1: MultiEpsMatcherFlags, + flags_matcher2: MultiEpsMatcherFlags, } impl ComposeFilterBuilder for PushLabelsComposeFilterBuilder @@ -72,10 +74,24 @@ where Self: Sized, { let filter_builder = CFB::new(fst1, fst2, matcher1, matcher2)?; + let filter = filter_builder.build()?; + let flags_matcher1 = if filter.lookahead_output() { + MultiEpsMatcherFlags::MULTI_EPS_LIST + } else { + MultiEpsMatcherFlags::MULTI_EPS_LOOP + }; + let flags_matcher2 = if filter.lookahead_output() { + MultiEpsMatcherFlags::MULTI_EPS_LOOP + } else { + MultiEpsMatcherFlags::MULTI_EPS_LIST + }; + Ok(Self { filter_builder, w: PhantomData, smt: PhantomData, + flags_matcher1, + flags_matcher2, }) } @@ -85,21 +101,13 @@ where let matcher1 = MultiEpsMatcher::new_with_opts( Arc::clone(filter.fst1()), MatchType::MatchOutput, - if filter.lookahead_output() { - MultiEpsMatcherFlags::MULTI_EPS_LIST - } else { - MultiEpsMatcherFlags::MULTI_EPS_LOOP - }, + self.flags_matcher1, Arc::clone(filter.matcher1_shared()), )?; let matcher2 = MultiEpsMatcher::new_with_opts( Arc::clone(filter.fst2()), MatchType::MatchInput, - if filter.lookahead_output() { - MultiEpsMatcherFlags::MULTI_EPS_LOOP - } else { - MultiEpsMatcherFlags::MULTI_EPS_LIST - }, + self.flags_matcher2, Arc::clone(filter.matcher2_shared()), )?; Ok(Self::CF { diff --git a/rustfst/src/algorithms/compose/matchers/multi_eps_matcher.rs b/rustfst/src/algorithms/compose/matchers/multi_eps_matcher.rs index b83187f7b..e30b6e26e 100644 --- a/rustfst/src/algorithms/compose/matchers/multi_eps_matcher.rs +++ b/rustfst/src/algorithms/compose/matchers/multi_eps_matcher.rs @@ -5,7 +5,7 @@ use std::sync::Arc; use anyhow::Result; use itertools::Itertools; -use nom::lib::std::collections::BTreeSet; +use std::collections::BTreeSet; use bitflags::bitflags; diff --git a/rustfst/src/algorithms/lazy/lazy_fst.rs b/rustfst/src/algorithms/lazy/lazy_fst.rs index 001d6bcee..e26f4fa56 100644 --- a/rustfst/src/algorithms/lazy/lazy_fst.rs +++ b/rustfst/src/algorithms/lazy/lazy_fst.rs @@ -267,8 +267,8 @@ where } } unsafe { fst_out.set_trs_unchecked(s, trs_owner.trs().to_vec()) }; - if let Some(f_w) = unsafe {self.final_weight_unchecked(s)} { - unsafe {fst_out.set_final_unchecked(s, f_w)}; + if let Some(f_w) = unsafe { self.final_weight_unchecked(s) } { + unsafe { fst_out.set_final_unchecked(s, f_w) }; } } fst_out.set_properties(self.properties()); From cfbb298f5de44759f6b8ae3ac89969952b4e1f6d Mon Sep 17 00:00:00 2001 From: Garvys Date: Fri, 25 Sep 2020 16:41:20 +0200 Subject: [PATCH 10/12] Implement compute_consuming --- rustfst/src/algorithms/compose/compose.rs | 2 +- rustfst/src/algorithms/compose/compose_fst.rs | 8 ++++ .../src/algorithms/lazy/cache/fst_cache.rs | 5 +++ .../algorithms/lazy/cache/simple_vec_cache.rs | 40 +++++++++++++++++ rustfst/src/algorithms/lazy/lazy_fst.rs | 45 +++++++++++++++++++ .../src/fst_impls/vector_fst/mutable_fst.rs | 15 ++++++- rustfst/src/fst_traits/mutable_fst.rs | 10 ++++- 7 files changed, 122 insertions(+), 3 deletions(-) diff --git a/rustfst/src/algorithms/compose/compose.rs b/rustfst/src/algorithms/compose/compose.rs index e52519d2a..00898fe37 100644 --- a/rustfst/src/algorithms/compose/compose.rs +++ b/rustfst/src/algorithms/compose/compose.rs @@ -48,7 +48,7 @@ pub fn compose_with_config< config: ComposeConfig, ) -> Result { let mut ofst: F3 = match config.compose_filter { - ComposeFilterEnum::AutoFilter => ComposeFst::new_auto(fst1, fst2)?.compute()?, + ComposeFilterEnum::AutoFilter => ComposeFst::new_auto(fst1, fst2)?.compute_2(), ComposeFilterEnum::NullFilter => ComposeFst::< _, NullComposeFilterBuilder<_, SortedMatcher<_, _>, SortedMatcher<_, _>>, diff --git a/rustfst/src/algorithms/compose/compose_fst.rs b/rustfst/src/algorithms/compose/compose_fst.rs index 0e0fc6538..44bce4fe7 100644 --- a/rustfst/src/algorithms/compose/compose_fst.rs +++ b/rustfst/src/algorithms/compose/compose_fst.rs @@ -5,6 +5,7 @@ use crate::algorithms::compose::compose_filters::{ }; use crate::algorithms::compose::matchers::{GenericMatcher, Matcher}; use crate::algorithms::compose::{ComposeFstOp, ComposeFstOpOptions, ComposeStateTuple}; +use crate::algorithms::lazy::cache::fst_cache::FullFstCache; use crate::algorithms::lazy::{FstCache, LazyFst, SimpleVecCache, StateTable}; use crate::fst_properties::FstProperties; use crate::fst_traits::{ @@ -86,6 +87,13 @@ impl, Cache: FstCache> ComposeFst + AllocableFst>(&self) -> Result { self.0.compute() } + + pub fn compute_2 + AllocableFst>(self) -> F2 + where + Cache: FullFstCache, + { + self.0.compute_2() + } } impl, F2: ExpandedFst> diff --git a/rustfst/src/algorithms/lazy/cache/fst_cache.rs b/rustfst/src/algorithms/lazy/cache/fst_cache.rs index cc1cccee6..d594bbcb8 100644 --- a/rustfst/src/algorithms/lazy/cache/fst_cache.rs +++ b/rustfst/src/algorithms/lazy/cache/fst_cache.rs @@ -1,6 +1,7 @@ use std::fmt::Debug; use crate::algorithms::lazy::CacheStatus; +use crate::fst_traits::MutableFst; use crate::semirings::Semiring; use crate::{StateId, TrsVec}; @@ -29,3 +30,7 @@ pub trait FstCache: Debug { fn is_final(&self, state_id: StateId) -> CacheStatus; unsafe fn is_final_unchecked(&self, state_id: StateId) -> bool; } + +pub trait FullFstCache: FstCache { + fn into_fst>(self) -> F; +} diff --git a/rustfst/src/algorithms/lazy/cache/simple_vec_cache.rs b/rustfst/src/algorithms/lazy/cache/simple_vec_cache.rs index 85f433d2c..33b4df738 100644 --- a/rustfst/src/algorithms/lazy/cache/simple_vec_cache.rs +++ b/rustfst/src/algorithms/lazy/cache/simple_vec_cache.rs @@ -1,7 +1,9 @@ use std::sync::Mutex; use crate::algorithms::lazy::cache::cache_internal_types::{CachedData, FinalWeight, StartState}; +use crate::algorithms::lazy::cache::fst_cache::FullFstCache; use crate::algorithms::lazy::{CacheStatus, FstCache}; +use crate::fst_traits::MutableFst; use crate::semirings::Semiring; use crate::{StateId, Trs, TrsVec, EPS_LABEL}; @@ -177,3 +179,41 @@ impl FstCache for SimpleVecCache { } } } + +impl FullFstCache for SimpleVecCache { + fn into_fst>(self) -> F { + let mut fst_out = F::new(); + + // Safe because computed + if let Some(start) = self.get_start().unwrap() { + let nstates = self.num_known_states(); + fst_out.add_states(nstates); + + unsafe { fst_out.set_start_unchecked(start) }; + + let final_weights = self.final_weights.into_inner().unwrap().data; + let trs = self.trs.into_inner().unwrap().data; + + for (state_id, cache_trs) in trs.into_iter().enumerate() { + let cache_trs = cache_trs.unwrap(); + unsafe { + fst_out.set_state_unchecked_noprops( + state_id, + cache_trs.trs, + cache_trs.niepsilons, + cache_trs.noepsilons, + ) + }; + } + + for (state_id, final_weight) in final_weights.into_iter().enumerate() { + // Safe as computed + if let Some(final_weight) = final_weight.unwrap() { + unsafe { fst_out.set_final_unchecked(state_id, final_weight) }; + } + } + } + + fst_out + } +} diff --git a/rustfst/src/algorithms/lazy/lazy_fst.rs b/rustfst/src/algorithms/lazy/lazy_fst.rs index e26f4fa56..57888be5d 100644 --- a/rustfst/src/algorithms/lazy/lazy_fst.rs +++ b/rustfst/src/algorithms/lazy/lazy_fst.rs @@ -7,6 +7,7 @@ use anyhow::Result; use itertools::izip; use unsafe_unwrap::UnsafeUnwrap; +use crate::algorithms::lazy::cache::fst_cache::FullFstCache; use crate::algorithms::lazy::cache::CacheStatus; use crate::algorithms::lazy::fst_op::FstOp; use crate::algorithms::lazy::FstCache; @@ -275,4 +276,48 @@ where // TODO: Symbol tables should be set here Ok(fst_out) } + + fn fill_cache(&self) { + let start_state = self.start(); + if start_state.is_none() { + return; + } + let start_state = start_state.unwrap(); + let mut queue = VecDeque::new(); + let mut visited_states = vec![]; + visited_states.resize(start_state + 1, false); + visited_states[start_state] = true; + queue.push_back(start_state); + while !queue.is_empty() { + let s = queue.pop_front().unwrap(); + let trs_owner = unsafe { self.get_trs_unchecked(s) }; + for tr in trs_owner.trs() { + if tr.nextstate >= visited_states.len() { + visited_states.resize(tr.nextstate + 1, false); + } + if !visited_states[tr.nextstate] { + queue.push_back(tr.nextstate); + visited_states[tr.nextstate] = true; + } + } + + // Force computation final weight + unsafe { self.final_weight_unchecked(s) }; + } + } + + /// Turns the Lazy FST into a static one. + pub fn compute_2 + AllocableFst>(self) -> F2 + where + Cache: FullFstCache, + { + self.fill_cache(); + + let props = self.properties(); + let mut fst: F2 = self.cache.into_fst(); + + fst.set_properties(props); + + fst + } } diff --git a/rustfst/src/fst_impls/vector_fst/mutable_fst.rs b/rustfst/src/fst_impls/vector_fst/mutable_fst.rs index 138672d17..51c562f87 100644 --- a/rustfst/src/fst_impls/vector_fst/mutable_fst.rs +++ b/rustfst/src/fst_impls/vector_fst/mutable_fst.rs @@ -14,7 +14,7 @@ use crate::fst_traits::CoreFst; use crate::fst_traits::MutableFst; use crate::semirings::Semiring; use crate::trs_iter_mut::TrsIterMut; -use crate::{StateId, Tr, Trs, EPS_LABEL}; +use crate::{StateId, Tr, Trs, TrsVec, EPS_LABEL}; #[inline] fn equal_tr(tr_1: &Tr, tr_2: &Tr) -> bool { @@ -410,4 +410,17 @@ impl MutableFst for VectorFst { self.properties &= !mask; self.properties |= props & mask; } + + unsafe fn set_state_unchecked_noprops( + &mut self, + source: usize, + trs: TrsVec, + niepsilons: usize, + noepsilons: usize, + ) { + let state = &mut self.states.get_unchecked_mut(source); + state.trs = trs; + state.niepsilons = niepsilons; + state.noepsilons = noepsilons; + } } diff --git a/rustfst/src/fst_traits/mutable_fst.rs b/rustfst/src/fst_traits/mutable_fst.rs index 6ae20c088..1760553de 100644 --- a/rustfst/src/fst_traits/mutable_fst.rs +++ b/rustfst/src/fst_traits/mutable_fst.rs @@ -9,7 +9,7 @@ use crate::fst_traits::ExpandedFst; use crate::semirings::Semiring; use crate::tr::Tr; use crate::trs_iter_mut::TrsIterMut; -use crate::{Label, StateId}; +use crate::{Label, StateId, TrsVec}; /// Trait defining the methods to modify a wFST. pub trait MutableFst: ExpandedFst { @@ -444,4 +444,12 @@ pub trait MutableFst: ExpandedFst { fn compute_and_update_properties_all(&mut self) -> Result { self.compute_and_update_properties(FstProperties::all_properties()) } + + unsafe fn set_state_unchecked_noprops( + &mut self, + source: StateId, + trs: TrsVec, + niepsilons: usize, + noepsilons: usize, + ); } From a604719a15bc07613105a084fed012686baa77cc Mon Sep 17 00:00:00 2001 From: Garvys Date: Fri, 25 Sep 2020 17:08:25 +0200 Subject: [PATCH 11/12] Move to consuming compute for all compose types --- rustfst-cli/src/cmds/compose.rs | 2 +- rustfst/src/algorithms/compose/compose.rs | 12 ++++++------ 2 files changed, 7 insertions(+), 7 deletions(-) diff --git a/rustfst-cli/src/cmds/compose.rs b/rustfst-cli/src/cmds/compose.rs index 709d7ab11..644595c2d 100644 --- a/rustfst-cli/src/cmds/compose.rs +++ b/rustfst-cli/src/cmds/compose.rs @@ -142,7 +142,7 @@ impl BinaryFstAlgorithm for ComposeAlgorithm { compose_options, )?; - dyn_fst.compute() + Ok(dyn_fst.compute_2()) } } } diff --git a/rustfst/src/algorithms/compose/compose.rs b/rustfst/src/algorithms/compose/compose.rs index 00898fe37..6c6c35a33 100644 --- a/rustfst/src/algorithms/compose/compose.rs +++ b/rustfst/src/algorithms/compose/compose.rs @@ -53,32 +53,32 @@ pub fn compose_with_config< _, NullComposeFilterBuilder<_, SortedMatcher<_, _>, SortedMatcher<_, _>>, >::new(fst1, fst2)? - .compute()?, + .compute_2(), ComposeFilterEnum::SequenceFilter => ComposeFst::< _, SequenceComposeFilterBuilder<_, SortedMatcher<_, _>, SortedMatcher<_, _>>, >::new(fst1, fst2)? - .compute()?, + .compute_2(), ComposeFilterEnum::AltSequenceFilter => ComposeFst::< _, AltSequenceComposeFilterBuilder<_, SortedMatcher<_, _>, SortedMatcher<_, _>>, >::new(fst1, fst2)? - .compute()?, + .compute_2(), ComposeFilterEnum::MatchFilter => ComposeFst::< _, MatchComposeFilterBuilder<_, SortedMatcher<_, _>, SortedMatcher<_, _>>, >::new(fst1, fst2)? - .compute()?, + .compute_2(), ComposeFilterEnum::NoMatchFilter => ComposeFst::< _, NoMatchComposeFilterBuilder<_, SortedMatcher<_, _>, SortedMatcher<_, _>>, >::new(fst1, fst2)? - .compute()?, + .compute_2(), ComposeFilterEnum::TrivialFilter => ComposeFst::< _, TrivialComposeFilterBuilder<_, SortedMatcher<_, _>, SortedMatcher<_, _>>, >::new(fst1, fst2)? - .compute()?, + .compute_2(), }; if config.connect { From 2f179f6afe71500a82a394f2ed19c07ef26da429 Mon Sep 17 00:00:00 2001 From: Garvys Date: Mon, 28 Sep 2020 11:44:37 +0200 Subject: [PATCH 12/12] FullFstCache -> FillableFstCache --- rustfst/src/algorithms/compose/compose_fst.rs | 6 +++--- rustfst/src/algorithms/lazy/cache/fst_cache.rs | 3 ++- rustfst/src/algorithms/lazy/cache/mod.rs | 2 +- rustfst/src/algorithms/lazy/cache/simple_vec_cache.rs | 4 ++-- rustfst/src/algorithms/lazy/lazy_fst.rs | 6 +++--- 5 files changed, 11 insertions(+), 10 deletions(-) diff --git a/rustfst/src/algorithms/compose/compose_fst.rs b/rustfst/src/algorithms/compose/compose_fst.rs index 44bce4fe7..12a6098e9 100644 --- a/rustfst/src/algorithms/compose/compose_fst.rs +++ b/rustfst/src/algorithms/compose/compose_fst.rs @@ -5,7 +5,7 @@ use crate::algorithms::compose::compose_filters::{ }; use crate::algorithms::compose::matchers::{GenericMatcher, Matcher}; use crate::algorithms::compose::{ComposeFstOp, ComposeFstOpOptions, ComposeStateTuple}; -use crate::algorithms::lazy::cache::fst_cache::FullFstCache; +use crate::algorithms::lazy::cache::fst_cache::FillableFstCache; use crate::algorithms::lazy::{FstCache, LazyFst, SimpleVecCache, StateTable}; use crate::fst_properties::FstProperties; use crate::fst_traits::{ @@ -90,9 +90,9 @@ impl, Cache: FstCache> ComposeFst + AllocableFst>(self) -> F2 where - Cache: FullFstCache, + Cache: FillableFstCache, { - self.0.compute_2() + self.0.into_static_fst() } } diff --git a/rustfst/src/algorithms/lazy/cache/fst_cache.rs b/rustfst/src/algorithms/lazy/cache/fst_cache.rs index d594bbcb8..ca46bb015 100644 --- a/rustfst/src/algorithms/lazy/cache/fst_cache.rs +++ b/rustfst/src/algorithms/lazy/cache/fst_cache.rs @@ -31,6 +31,7 @@ pub trait FstCache: Debug { unsafe fn is_final_unchecked(&self, state_id: StateId) -> bool; } -pub trait FullFstCache: FstCache { +pub trait FillableFstCache: FstCache { + /// Given a cache that is full. Turn it into a MutableFst. fn into_fst>(self) -> F; } diff --git a/rustfst/src/algorithms/lazy/cache/mod.rs b/rustfst/src/algorithms/lazy/cache/mod.rs index 0fd314e32..af69eabd1 100644 --- a/rustfst/src/algorithms/lazy/cache/mod.rs +++ b/rustfst/src/algorithms/lazy/cache/mod.rs @@ -8,6 +8,6 @@ pub mod simple_vec_cache; pub use self::cache_status::CacheStatus; pub use self::first_cache::FirstCache; -pub use self::fst_cache::FstCache; +pub use self::fst_cache::{FstCache, FillableFstCache}; pub use self::simple_hash_map_cache::SimpleHashMapCache; pub use self::simple_vec_cache::SimpleVecCache; diff --git a/rustfst/src/algorithms/lazy/cache/simple_vec_cache.rs b/rustfst/src/algorithms/lazy/cache/simple_vec_cache.rs index 33b4df738..c0b6a42b5 100644 --- a/rustfst/src/algorithms/lazy/cache/simple_vec_cache.rs +++ b/rustfst/src/algorithms/lazy/cache/simple_vec_cache.rs @@ -1,7 +1,7 @@ use std::sync::Mutex; use crate::algorithms::lazy::cache::cache_internal_types::{CachedData, FinalWeight, StartState}; -use crate::algorithms::lazy::cache::fst_cache::FullFstCache; +use crate::algorithms::lazy::cache::fst_cache::FillableFstCache; use crate::algorithms::lazy::{CacheStatus, FstCache}; use crate::fst_traits::MutableFst; use crate::semirings::Semiring; @@ -180,7 +180,7 @@ impl FstCache for SimpleVecCache { } } -impl FullFstCache for SimpleVecCache { +impl FillableFstCache for SimpleVecCache { fn into_fst>(self) -> F { let mut fst_out = F::new(); diff --git a/rustfst/src/algorithms/lazy/lazy_fst.rs b/rustfst/src/algorithms/lazy/lazy_fst.rs index 57888be5d..8dcad0f60 100644 --- a/rustfst/src/algorithms/lazy/lazy_fst.rs +++ b/rustfst/src/algorithms/lazy/lazy_fst.rs @@ -7,7 +7,7 @@ use anyhow::Result; use itertools::izip; use unsafe_unwrap::UnsafeUnwrap; -use crate::algorithms::lazy::cache::fst_cache::FullFstCache; +use crate::algorithms::lazy::cache::fst_cache::FillableFstCache; use crate::algorithms::lazy::cache::CacheStatus; use crate::algorithms::lazy::fst_op::FstOp; use crate::algorithms::lazy::FstCache; @@ -307,9 +307,9 @@ where } /// Turns the Lazy FST into a static one. - pub fn compute_2 + AllocableFst>(self) -> F2 + pub fn into_static_fst + AllocableFst>(self) -> F2 where - Cache: FullFstCache, + Cache: FillableFstCache, { self.fill_cache();