diff --git a/rustfst-cli/src/cmds/compose.rs b/rustfst-cli/src/cmds/compose.rs index 569b35975..644595c2d 100644 --- a/rustfst-cli/src/cmds/compose.rs +++ b/rustfst-cli/src/cmds/compose.rs @@ -2,13 +2,6 @@ use std::sync::Arc; use anyhow::Result; -use rustfst::algorithms::compose::{ - compose, ComposeFst, ComposeFstOpOptions, LabelReachableData, MatcherFst, -}; -use rustfst::fst_impls::VectorFst; -use rustfst::semirings::TropicalWeight; - -use crate::binary_fst_algorithm::BinaryFstAlgorithm; use rustfst::algorithms::compose::compose_filters::{ AltSequenceComposeFilterBuilder, ComposeFilterBuilder, }; @@ -20,9 +13,16 @@ use rustfst::algorithms::compose::lookahead_matchers::{ LabelLookAheadMatcher, LookaheadMatcher, MatcherFlagsTrait, }; use rustfst::algorithms::compose::matchers::{MatchType, Matcher, MatcherFlags, SortedMatcher}; -use rustfst::algorithms::lazy::SimpleHashMapCache; +use rustfst::algorithms::compose::{ + compose, ComposeFst, ComposeFstOpOptions, LabelReachableData, MatcherFst, +}; +use rustfst::algorithms::lazy::SimpleVecCache; use rustfst::algorithms::tr_compares::ILabelCompare; use rustfst::algorithms::tr_sort; +use rustfst::fst_impls::VectorFst; +use rustfst::semirings::TropicalWeight; + +use crate::binary_fst_algorithm::BinaryFstAlgorithm; #[derive(Debug, Clone, Copy)] pub enum ComposeType { @@ -136,13 +136,13 @@ impl BinaryFstAlgorithm for ComposeAlgorithm { None, ); - let dyn_fst = ComposeFst::<_, _, SimpleHashMapCache<_>>::new_with_options( + let dyn_fst = ComposeFst::<_, _, SimpleVecCache<_>>::new_with_options( graph1look, fst_2, compose_options, )?; - dyn_fst.compute() + Ok(dyn_fst.compute_2()) } } } diff --git a/rustfst/src/algorithms/closure/closure_fst.rs b/rustfst/src/algorithms/closure/closure_fst.rs index 46dc116b0..a982e632c 100644 --- a/rustfst/src/algorithms/closure/closure_fst.rs +++ b/rustfst/src/algorithms/closure/closure_fst.rs @@ -95,6 +95,14 @@ where self.0.num_trs_unchecked(s) } + fn is_final(&self, state_id: usize) -> Result { + self.0.is_final(state_id) + } + + unsafe fn is_final_unchecked(&self, state_id: usize) -> bool { + self.0.is_final_unchecked(state_id) + } + fn get_trs(&self, state_id: usize) -> Result { self.0.get_trs(state_id) } @@ -111,9 +119,17 @@ where self.0.num_input_epsilons(state) } + unsafe fn num_input_epsilons_unchecked(&self, state: usize) -> usize { + self.0.num_input_epsilons_unchecked(state) + } + fn num_output_epsilons(&self, state: usize) -> Result { self.0.num_output_epsilons(state) } + + unsafe fn num_output_epsilons_unchecked(&self, state: usize) -> usize { + self.0.num_output_epsilons_unchecked(state) + } } impl<'a, W, F> StateIterator<'a> for ClosureFst diff --git a/rustfst/src/algorithms/compose/add_on.rs b/rustfst/src/algorithms/compose/add_on.rs index 68bb33995..3665f8688 100644 --- a/rustfst/src/algorithms/compose/add_on.rs +++ b/rustfst/src/algorithms/compose/add_on.rs @@ -57,6 +57,14 @@ impl, T> CoreFst for FstAddOn { self.fst.num_trs_unchecked(s) } + fn is_final(&self, state_id: usize) -> Result { + self.fst.is_final(state_id) + } + + unsafe fn is_final_unchecked(&self, state_id: usize) -> bool { + self.fst.is_final_unchecked(state_id) + } + fn get_trs(&self, state_id: usize) -> Result { self.fst.get_trs(state_id) } @@ -73,9 +81,17 @@ impl, T> CoreFst for FstAddOn { self.fst.num_input_epsilons(state) } + unsafe fn num_input_epsilons_unchecked(&self, state: usize) -> usize { + self.fst.num_input_epsilons_unchecked(state) + } + fn num_output_epsilons(&self, state: usize) -> Result { self.fst.num_output_epsilons(state) } + + unsafe fn num_output_epsilons_unchecked(&self, state: usize) -> usize { + self.fst.num_output_epsilons_unchecked(state) + } } impl<'a, F: StateIterator<'a>, T> StateIterator<'a> for FstAddOn { diff --git a/rustfst/src/algorithms/compose/compose.rs b/rustfst/src/algorithms/compose/compose.rs index e52519d2a..6c6c35a33 100644 --- a/rustfst/src/algorithms/compose/compose.rs +++ b/rustfst/src/algorithms/compose/compose.rs @@ -48,37 +48,37 @@ pub fn compose_with_config< config: ComposeConfig, ) -> Result { let mut ofst: F3 = match config.compose_filter { - ComposeFilterEnum::AutoFilter => ComposeFst::new_auto(fst1, fst2)?.compute()?, + ComposeFilterEnum::AutoFilter => ComposeFst::new_auto(fst1, fst2)?.compute_2(), ComposeFilterEnum::NullFilter => ComposeFst::< _, NullComposeFilterBuilder<_, SortedMatcher<_, _>, SortedMatcher<_, _>>, >::new(fst1, fst2)? - .compute()?, + .compute_2(), ComposeFilterEnum::SequenceFilter => ComposeFst::< _, SequenceComposeFilterBuilder<_, SortedMatcher<_, _>, SortedMatcher<_, _>>, >::new(fst1, fst2)? - .compute()?, + .compute_2(), ComposeFilterEnum::AltSequenceFilter => ComposeFst::< _, AltSequenceComposeFilterBuilder<_, SortedMatcher<_, _>, SortedMatcher<_, _>>, >::new(fst1, fst2)? - .compute()?, + .compute_2(), ComposeFilterEnum::MatchFilter => ComposeFst::< _, MatchComposeFilterBuilder<_, SortedMatcher<_, _>, SortedMatcher<_, _>>, >::new(fst1, fst2)? - .compute()?, + .compute_2(), ComposeFilterEnum::NoMatchFilter => ComposeFst::< _, NoMatchComposeFilterBuilder<_, SortedMatcher<_, _>, SortedMatcher<_, _>>, >::new(fst1, fst2)? - .compute()?, + .compute_2(), ComposeFilterEnum::TrivialFilter => ComposeFst::< _, TrivialComposeFilterBuilder<_, SortedMatcher<_, _>, SortedMatcher<_, _>>, >::new(fst1, fst2)? - .compute()?, + .compute_2(), }; if config.connect { diff --git a/rustfst/src/algorithms/compose/compose_filters/alt_sequence_compose_filter.rs b/rustfst/src/algorithms/compose/compose_filters/alt_sequence_compose_filter.rs index 81633eb40..29a0b75ac 100644 --- a/rustfst/src/algorithms/compose/compose_filters/alt_sequence_compose_filter.rs +++ b/rustfst/src/algorithms/compose/compose_filters/alt_sequence_compose_filter.rs @@ -92,13 +92,14 @@ impl, M2: Matcher> ComposeFilter self.s1 = s1; self.s2 = s2; self.fs = filter_state.clone(); - // TODO: Could probably use unchecked here as the state should exist. - let fst2 = self.fst2(); - let na2 = fst2.num_trs(self.s2)?; - let ne2 = fst2.num_input_epsilons(self.s2)?; - let fin2 = fst2.is_final(self.s2)?; - self.alleps2 = na2 == ne2 && !fin2; - self.noeps2 = ne2 == 0; + unsafe { + let fst2 = self.fst2(); + let na2 = fst2.num_trs_unchecked(self.s2); + let ne2 = fst2.num_input_epsilons_unchecked(self.s2); + let fin2 = fst2.is_final_unchecked(self.s2); + self.alleps2 = na2 == ne2 && !fin2; + self.noeps2 = ne2 == 0; + } } Ok(()) } diff --git a/rustfst/src/algorithms/compose/compose_filters/match_compose_filter.rs b/rustfst/src/algorithms/compose/compose_filters/match_compose_filter.rs index 81ba7b40b..f426d19eb 100644 --- a/rustfst/src/algorithms/compose/compose_filters/match_compose_filter.rs +++ b/rustfst/src/algorithms/compose/compose_filters/match_compose_filter.rs @@ -96,23 +96,25 @@ impl, M2: Matcher> ComposeFilter self.s2 = s2; self.fs = filter_state.clone(); - let fst1 = self.fst1(); - let fst2 = self.fst2(); + unsafe { + let fst1 = self.fst1(); + let fst2 = self.fst2(); - let na1 = fst1.num_trs(s1)?; - let na2 = fst2.num_trs(s2)?; + let na1 = fst1.num_trs_unchecked(s1); + let na2 = fst2.num_trs_unchecked(s2); - let ne1 = fst1.num_output_epsilons(s1)?; - let ne2 = fst2.num_input_epsilons(s2)?; + let ne1 = fst1.num_output_epsilons_unchecked(s1); + let ne2 = fst2.num_input_epsilons_unchecked(s2); - let f1 = fst1.is_final(s1)?; - let f2 = fst2.is_final(s2)?; + let f1 = fst1.is_final_unchecked(s1); + let f2 = fst2.is_final_unchecked(s2); - self.alleps1 = na1 == ne1 && !f1; - self.alleps2 = na2 == ne2 && !f2; + self.alleps1 = na1 == ne1 && !f1; + self.alleps2 = na2 == ne2 && !f2; - self.noeps1 = ne1 == 0; - self.noeps2 = ne2 == 0; + self.noeps1 = ne1 == 0; + self.noeps2 = ne2 == 0; + } } Ok(()) } diff --git a/rustfst/src/algorithms/compose/compose_filters/sequence_compose_filter.rs b/rustfst/src/algorithms/compose/compose_filters/sequence_compose_filter.rs index 61dd30345..4935e22f4 100644 --- a/rustfst/src/algorithms/compose/compose_filters/sequence_compose_filter.rs +++ b/rustfst/src/algorithms/compose/compose_filters/sequence_compose_filter.rs @@ -90,13 +90,14 @@ impl, M2: Matcher> ComposeFilter self.s1 = s1; self.s2 = s2; self.fs = filter_state.clone(); - // TODO: Could probably use unchecked here as the state should exist. - let fst1 = self.fst1(); - let na1 = fst1.num_trs(self.s1)?; - let ne1 = fst1.num_output_epsilons(self.s1)?; - let fin1 = fst1.is_final(self.s1)?; - self.alleps1 = na1 == ne1 && !fin1; - self.noeps1 = ne1 == 0; + unsafe { + let fst1 = self.fst1(); + let na1 = fst1.num_trs_unchecked(self.s1); + let ne1 = fst1.num_output_epsilons_unchecked(self.s1); + let fin1 = fst1.is_final_unchecked(self.s1); + self.alleps1 = na1 == ne1 && !fin1; + self.noeps1 = ne1 == 0; + } } Ok(()) } diff --git a/rustfst/src/algorithms/compose/compose_fst.rs b/rustfst/src/algorithms/compose/compose_fst.rs index 3b9474128..12a6098e9 100644 --- a/rustfst/src/algorithms/compose/compose_fst.rs +++ b/rustfst/src/algorithms/compose/compose_fst.rs @@ -5,6 +5,7 @@ use crate::algorithms::compose::compose_filters::{ }; use crate::algorithms::compose::matchers::{GenericMatcher, Matcher}; use crate::algorithms::compose::{ComposeFstOp, ComposeFstOpOptions, ComposeStateTuple}; +use crate::algorithms::lazy::cache::fst_cache::FillableFstCache; use crate::algorithms::lazy::{FstCache, LazyFst, SimpleVecCache, StateTable}; use crate::fst_properties::FstProperties; use crate::fst_traits::{ @@ -86,6 +87,13 @@ impl, Cache: FstCache> ComposeFst + AllocableFst>(&self) -> Result { self.0.compute() } + + pub fn compute_2 + AllocableFst>(self) -> F2 + where + Cache: FillableFstCache, + { + self.0.into_static_fst() + } } impl, F2: ExpandedFst> @@ -129,6 +137,14 @@ where self.0.num_trs_unchecked(s) } + fn is_final(&self, state_id: usize) -> Result { + self.0.is_final(state_id) + } + + unsafe fn is_final_unchecked(&self, state_id: usize) -> bool { + self.0.is_final_unchecked(state_id) + } + fn get_trs(&self, state_id: usize) -> Result { self.0.get_trs(state_id) } @@ -145,9 +161,17 @@ where self.0.num_input_epsilons(state) } + unsafe fn num_input_epsilons_unchecked(&self, state: usize) -> usize { + self.0.num_input_epsilons_unchecked(state) + } + fn num_output_epsilons(&self, state: usize) -> Result { self.0.num_output_epsilons(state) } + + unsafe fn num_output_epsilons_unchecked(&self, state: usize) -> usize { + self.0.num_output_epsilons_unchecked(state) + } } impl<'a, W, CFB, Cache> StateIterator<'a> for ComposeFst diff --git a/rustfst/src/algorithms/compose/compose_fst_op.rs b/rustfst/src/algorithms/compose/compose_fst_op.rs index a9617141d..089c8cc5e 100644 --- a/rustfst/src/algorithms/compose/compose_fst_op.rs +++ b/rustfst/src/algorithms/compose/compose_fst_op.rs @@ -137,7 +137,9 @@ impl> ComposeFstOp { } else { Tr::new(NO_LABEL, EPS_LABEL, W::one(), sb) }; - let mut trs = vec![]; + let mut trs_vec = TrsVec(Arc::new(vec![])); + // Safe as there is no other Arc to the same allocation. + let trs = Arc::get_mut(&mut trs_vec.0).unwrap(); match selector { Selector::Fst1Matcher2 => { @@ -147,10 +149,10 @@ impl> ComposeFstOp { match_input, &mut compose_filter, selector, - &mut trs, + trs, )?; for tr in self.fst1.get_trs(sb)?.trs() { - self.match_tr(sa, tr, match_input, &mut compose_filter, selector, &mut trs)?; + self.match_tr(sa, tr, match_input, &mut compose_filter, selector, trs)?; } } Selector::Fst2Matcher1 => { @@ -160,14 +162,14 @@ impl> ComposeFstOp { match_input, &mut compose_filter, selector, - &mut trs, + trs, )?; for tr in self.fst2.get_trs(sb)?.trs() { - self.match_tr(sa, tr, match_input, &mut compose_filter, selector, &mut trs)?; + self.match_tr(sa, tr, match_input, &mut compose_filter, selector, trs)?; } } } - Ok(TrsVec(Arc::new(trs))) + Ok(trs_vec) } fn add_tr( diff --git a/rustfst/src/algorithms/compose/lookahead_filters/push_labels_compose_filter.rs b/rustfst/src/algorithms/compose/lookahead_filters/push_labels_compose_filter.rs index e1e8e14ba..465e804aa 100644 --- a/rustfst/src/algorithms/compose/lookahead_filters/push_labels_compose_filter.rs +++ b/rustfst/src/algorithms/compose/lookahead_filters/push_labels_compose_filter.rs @@ -47,6 +47,8 @@ where filter_builder: CFB, w: PhantomData, smt: PhantomData, + flags_matcher1: MultiEpsMatcherFlags, + flags_matcher2: MultiEpsMatcherFlags, } impl ComposeFilterBuilder for PushLabelsComposeFilterBuilder @@ -72,10 +74,24 @@ where Self: Sized, { let filter_builder = CFB::new(fst1, fst2, matcher1, matcher2)?; + let filter = filter_builder.build()?; + let flags_matcher1 = if filter.lookahead_output() { + MultiEpsMatcherFlags::MULTI_EPS_LIST + } else { + MultiEpsMatcherFlags::MULTI_EPS_LOOP + }; + let flags_matcher2 = if filter.lookahead_output() { + MultiEpsMatcherFlags::MULTI_EPS_LOOP + } else { + MultiEpsMatcherFlags::MULTI_EPS_LIST + }; + Ok(Self { filter_builder, w: PhantomData, smt: PhantomData, + flags_matcher1, + flags_matcher2, }) } @@ -85,21 +101,13 @@ where let matcher1 = MultiEpsMatcher::new_with_opts( Arc::clone(filter.fst1()), MatchType::MatchOutput, - if filter.lookahead_output() { - MultiEpsMatcherFlags::MULTI_EPS_LIST - } else { - MultiEpsMatcherFlags::MULTI_EPS_LOOP - }, + self.flags_matcher1, Arc::clone(filter.matcher1_shared()), )?; let matcher2 = MultiEpsMatcher::new_with_opts( Arc::clone(filter.fst2()), MatchType::MatchInput, - if filter.lookahead_output() { - MultiEpsMatcherFlags::MULTI_EPS_LOOP - } else { - MultiEpsMatcherFlags::MULTI_EPS_LIST - }, + self.flags_matcher2, Arc::clone(filter.matcher2_shared()), )?; Ok(Self::CF { @@ -137,11 +145,13 @@ where { return Ok(()); } - self.ntrsa = if self.lookahead_output() { - self.filter.fst1().num_trs(s1)? - } else { - self.filter.fst2().num_trs(s2)? - }; + unsafe { + self.ntrsa = if self.lookahead_output() { + self.filter.fst1().num_trs_unchecked(s1) + } else { + self.filter.fst2().num_trs_unchecked(s2) + }; + } let fs2 = filter_state.state2(); let flabel = fs2.state(); self.matcher1.clear_multi_eps_labels(); diff --git a/rustfst/src/algorithms/compose/matcher_fst.rs b/rustfst/src/algorithms/compose/matcher_fst.rs index f3eb94dc7..761788df9 100644 --- a/rustfst/src/algorithms/compose/matcher_fst.rs +++ b/rustfst/src/algorithms/compose/matcher_fst.rs @@ -112,6 +112,14 @@ impl, M, T> CoreFst for MatcherFst { self.fst_add_on.num_trs_unchecked(s) } + fn is_final(&self, state_id: usize) -> Result { + self.fst_add_on.is_final(state_id) + } + + unsafe fn is_final_unchecked(&self, state_id: usize) -> bool { + self.fst_add_on.is_final_unchecked(state_id) + } + fn get_trs(&self, state_id: usize) -> Result { self.fst_add_on.get_trs(state_id) } @@ -128,9 +136,17 @@ impl, M, T> CoreFst for MatcherFst { self.fst_add_on.num_input_epsilons(state) } + unsafe fn num_input_epsilons_unchecked(&self, state: usize) -> usize { + self.fst_add_on.num_input_epsilons_unchecked(state) + } + fn num_output_epsilons(&self, state: usize) -> Result { self.fst_add_on.num_output_epsilons(state) } + + unsafe fn num_output_epsilons_unchecked(&self, state: usize) -> usize { + self.fst_add_on.num_output_epsilons_unchecked(state) + } } impl<'a, W, F: StateIterator<'a>, M, T> StateIterator<'a> for MatcherFst { diff --git a/rustfst/src/algorithms/compose/matchers/multi_eps_matcher.rs b/rustfst/src/algorithms/compose/matchers/multi_eps_matcher.rs index b83187f7b..e30b6e26e 100644 --- a/rustfst/src/algorithms/compose/matchers/multi_eps_matcher.rs +++ b/rustfst/src/algorithms/compose/matchers/multi_eps_matcher.rs @@ -5,7 +5,7 @@ use std::sync::Arc; use anyhow::Result; use itertools::Itertools; -use nom::lib::std::collections::BTreeSet; +use std::collections::BTreeSet; use bitflags::bitflags; diff --git a/rustfst/src/algorithms/concat/concat_fst.rs b/rustfst/src/algorithms/concat/concat_fst.rs index 9bc574d5a..e92a8a7ab 100644 --- a/rustfst/src/algorithms/concat/concat_fst.rs +++ b/rustfst/src/algorithms/concat/concat_fst.rs @@ -83,6 +83,14 @@ where self.0.num_trs_unchecked(s) } + fn is_final(&self, state_id: usize) -> Result { + self.0.is_final(state_id) + } + + unsafe fn is_final_unchecked(&self, state_id: usize) -> bool { + self.0.is_final_unchecked(state_id) + } + fn get_trs(&self, state_id: usize) -> Result { self.0.get_trs(state_id) } @@ -99,9 +107,17 @@ where self.0.num_input_epsilons(state) } + unsafe fn num_input_epsilons_unchecked(&self, state: usize) -> usize { + self.0.num_input_epsilons_unchecked(state) + } + fn num_output_epsilons(&self, state: usize) -> Result { self.0.num_output_epsilons(state) } + + unsafe fn num_output_epsilons_unchecked(&self, state: usize) -> usize { + self.0.num_output_epsilons_unchecked(state) + } } impl<'a, W, F> StateIterator<'a> for ConcatFst diff --git a/rustfst/src/algorithms/determinize/determinize_fsa.rs b/rustfst/src/algorithms/determinize/determinize_fsa.rs index 39e6254fb..0eb639ff1 100644 --- a/rustfst/src/algorithms/determinize/determinize_fsa.rs +++ b/rustfst/src/algorithms/determinize/determinize_fsa.rs @@ -43,6 +43,14 @@ where self.0.num_trs_unchecked(s) } + fn is_final(&self, state_id: usize) -> Result { + self.0.is_final(state_id) + } + + unsafe fn is_final_unchecked(&self, state_id: usize) -> bool { + self.0.is_final_unchecked(state_id) + } + fn get_trs(&self, state_id: usize) -> Result { self.0.get_trs(state_id) } @@ -59,9 +67,17 @@ where self.0.num_input_epsilons(state) } + unsafe fn num_input_epsilons_unchecked(&self, state: usize) -> usize { + self.0.num_input_epsilons_unchecked(state) + } + fn num_output_epsilons(&self, state: usize) -> Result { self.0.num_output_epsilons(state) } + + unsafe fn num_output_epsilons_unchecked(&self, state: usize) -> usize { + self.0.num_output_epsilons_unchecked(state) + } } impl<'a, W, F, CD> StateIterator<'a> for DeterminizeFsa diff --git a/rustfst/src/algorithms/factor_weight/factor_weight_fst.rs b/rustfst/src/algorithms/factor_weight/factor_weight_fst.rs index 45e30e996..22c47791c 100644 --- a/rustfst/src/algorithms/factor_weight/factor_weight_fst.rs +++ b/rustfst/src/algorithms/factor_weight/factor_weight_fst.rs @@ -50,6 +50,14 @@ where self.0.num_trs_unchecked(s) } + fn is_final(&self, state_id: usize) -> Result { + self.0.is_final(state_id) + } + + unsafe fn is_final_unchecked(&self, state_id: usize) -> bool { + self.0.is_final_unchecked(state_id) + } + fn get_trs(&self, state_id: usize) -> Result { self.0.get_trs(state_id) } @@ -66,9 +74,17 @@ where self.0.num_input_epsilons(state) } + unsafe fn num_input_epsilons_unchecked(&self, state: usize) -> usize { + self.0.num_input_epsilons_unchecked(state) + } + fn num_output_epsilons(&self, state: usize) -> Result { self.0.num_output_epsilons(state) } + + unsafe fn num_output_epsilons_unchecked(&self, state: usize) -> usize { + self.0.num_output_epsilons_unchecked(state) + } } impl<'a, W, F, B, FI> StateIterator<'a> for FactorWeightFst diff --git a/rustfst/src/algorithms/lazy/cache/arc_cache.rs b/rustfst/src/algorithms/lazy/cache/arc_cache.rs index 5d3153b48..97a2a1fda 100644 --- a/rustfst/src/algorithms/lazy/cache/arc_cache.rs +++ b/rustfst/src/algorithms/lazy/cache/arc_cache.rs @@ -37,14 +37,22 @@ impl> FstCache for Arc { self.deref().num_trs(id) } - fn num_input_epsilons(&self, id: usize) -> Option { + fn num_input_epsilons(&self, id: usize) -> CacheStatus { self.deref().num_input_epsilons(id) } - fn num_output_epsilons(&self, id: usize) -> Option { + unsafe fn num_input_epsilons_unchecked(&self, id: usize) -> usize { + self.deref().num_input_epsilons_unchecked(id) + } + + fn num_output_epsilons(&self, id: usize) -> CacheStatus { self.deref().num_output_epsilons(id) } + unsafe fn num_output_epsilons_unchecked(&self, id: usize) -> usize { + self.deref().num_output_epsilons_unchecked(id) + } + fn len_trs(&self) -> usize { self.deref().len_trs() } @@ -52,4 +60,12 @@ impl> FstCache for Arc { fn len_final_weights(&self) -> usize { self.deref().len_final_weights() } + + fn is_final(&self, state_id: usize) -> CacheStatus { + self.deref().is_final(state_id) + } + + unsafe fn is_final_unchecked(&self, state_id: usize) -> bool { + self.deref().is_final_unchecked(state_id) + } } diff --git a/rustfst/src/algorithms/lazy/cache/cache_internal_types.rs b/rustfst/src/algorithms/lazy/cache/cache_internal_types.rs index 8f2304c67..632d41b28 100644 --- a/rustfst/src/algorithms/lazy/cache/cache_internal_types.rs +++ b/rustfst/src/algorithms/lazy/cache/cache_internal_types.rs @@ -62,6 +62,13 @@ impl CachedData>> { None => CacheStatus::NotComputed, } } + + pub unsafe fn get_unchecked(&self, idx: usize) -> &T { + match self.data.get_unchecked(idx) { + CacheStatus::Computed(e) => e, + CacheStatus::NotComputed => unreachable!(), + } + } } impl Default for CachedData> { diff --git a/rustfst/src/algorithms/lazy/cache/cache_status.rs b/rustfst/src/algorithms/lazy/cache/cache_status.rs index 7c6ae0892..cf8d00dc5 100644 --- a/rustfst/src/algorithms/lazy/cache/cache_status.rs +++ b/rustfst/src/algorithms/lazy/cache/cache_status.rs @@ -1,4 +1,4 @@ -#[derive(Debug, Clone, Copy)] +#[derive(Debug, Clone, Copy, PartialOrd, PartialEq)] pub enum CacheStatus { NotComputed, Computed(T), @@ -12,10 +12,38 @@ impl CacheStatus { } } + pub fn ok_or(self, err: E) -> Result { + match self { + CacheStatus::Computed(v) => Ok(v), + CacheStatus::NotComputed => Err(err), + } + } + + pub fn ok_or_else E>(self, err: F) -> Result { + match self { + CacheStatus::Computed(v) => Ok(v), + CacheStatus::NotComputed => Err(err()), + } + } + + pub fn unwrap(self) -> T { + match self { + CacheStatus::Computed(e) => e, + CacheStatus::NotComputed => unreachable!(), + } + } + pub fn into_option(self) -> Option { match self { CacheStatus::Computed(e) => Some(e), CacheStatus::NotComputed => None, } } + + pub fn is_computed(&self) -> bool { + match self { + CacheStatus::Computed(_) => true, + CacheStatus::NotComputed => false, + } + } } diff --git a/rustfst/src/algorithms/lazy/cache/first_cache.rs b/rustfst/src/algorithms/lazy/cache/first_cache.rs index b503bfbd2..b8db30104 100644 --- a/rustfst/src/algorithms/lazy/cache/first_cache.rs +++ b/rustfst/src/algorithms/lazy/cache/first_cache.rs @@ -68,14 +68,22 @@ impl> FstCache for FirstCache { self.cache.num_trs(id) } - fn num_input_epsilons(&self, id: usize) -> Option { + fn num_input_epsilons(&self, id: usize) -> CacheStatus { self.cache.num_input_epsilons(id) } - fn num_output_epsilons(&self, id: usize) -> Option { + unsafe fn num_input_epsilons_unchecked(&self, id: usize) -> usize { + self.cache.num_input_epsilons_unchecked(id) + } + + fn num_output_epsilons(&self, id: usize) -> CacheStatus { self.cache.num_output_epsilons(id) } + unsafe fn num_output_epsilons_unchecked(&self, id: usize) -> usize { + self.cache.num_output_epsilons_unchecked(id) + } + fn len_trs(&self) -> usize { self.cache.len_trs() } @@ -83,4 +91,12 @@ impl> FstCache for FirstCache { fn len_final_weights(&self) -> usize { self.cache.len_final_weights() } + + fn is_final(&self, state_id: usize) -> CacheStatus { + self.cache.is_final(state_id) + } + + unsafe fn is_final_unchecked(&self, state_id: usize) -> bool { + self.cache.is_final_unchecked(state_id) + } } diff --git a/rustfst/src/algorithms/lazy/cache/fst_cache.rs b/rustfst/src/algorithms/lazy/cache/fst_cache.rs index 179018f29..ca46bb015 100644 --- a/rustfst/src/algorithms/lazy/cache/fst_cache.rs +++ b/rustfst/src/algorithms/lazy/cache/fst_cache.rs @@ -1,6 +1,7 @@ use std::fmt::Debug; use crate::algorithms::lazy::CacheStatus; +use crate::fst_traits::MutableFst; use crate::semirings::Semiring; use crate::{StateId, TrsVec}; @@ -17,9 +18,20 @@ pub trait FstCache: Debug { fn num_known_states(&self) -> usize; fn num_trs(&self, id: StateId) -> Option; - fn num_input_epsilons(&self, id: usize) -> Option; - fn num_output_epsilons(&self, id: usize) -> Option; + fn num_input_epsilons(&self, id: usize) -> CacheStatus; + unsafe fn num_input_epsilons_unchecked(&self, id: usize) -> usize; + + fn num_output_epsilons(&self, id: usize) -> CacheStatus; + unsafe fn num_output_epsilons_unchecked(&self, id: usize) -> usize; fn len_trs(&self) -> usize; fn len_final_weights(&self) -> usize; + + fn is_final(&self, state_id: StateId) -> CacheStatus; + unsafe fn is_final_unchecked(&self, state_id: StateId) -> bool; +} + +pub trait FillableFstCache: FstCache { + /// Given a cache that is full. Turn it into a MutableFst. + fn into_fst>(self) -> F; } diff --git a/rustfst/src/algorithms/lazy/cache/mod.rs b/rustfst/src/algorithms/lazy/cache/mod.rs index 0fd314e32..af69eabd1 100644 --- a/rustfst/src/algorithms/lazy/cache/mod.rs +++ b/rustfst/src/algorithms/lazy/cache/mod.rs @@ -8,6 +8,6 @@ pub mod simple_vec_cache; pub use self::cache_status::CacheStatus; pub use self::first_cache::FirstCache; -pub use self::fst_cache::FstCache; +pub use self::fst_cache::{FstCache, FillableFstCache}; pub use self::simple_hash_map_cache::SimpleHashMapCache; pub use self::simple_vec_cache::SimpleVecCache; diff --git a/rustfst/src/algorithms/lazy/cache/simple_hash_map_cache.rs b/rustfst/src/algorithms/lazy/cache/simple_hash_map_cache.rs index ad3005bdc..48bce01da 100644 --- a/rustfst/src/algorithms/lazy/cache/simple_hash_map_cache.rs +++ b/rustfst/src/algorithms/lazy/cache/simple_hash_map_cache.rs @@ -5,6 +5,7 @@ use crate::algorithms::lazy::cache::cache_internal_types::{CachedData, StartStat use crate::algorithms::lazy::{CacheStatus, FstCache}; use crate::semirings::Semiring; use crate::{StateId, Trs, TrsVec, EPS_LABEL}; +use unsafe_unwrap::UnsafeUnwrap; #[derive(Debug)] pub struct SimpleHashMapCache { @@ -126,14 +127,30 @@ impl FstCache for SimpleHashMapCache { cached_data.data.get(&id).map(|v| v.trs.len()) } - fn num_input_epsilons(&self, id: usize) -> Option { + fn num_input_epsilons(&self, id: usize) -> CacheStatus { let cached_data = self.trs.lock().unwrap(); - cached_data.data.get(&id).map(|v| v.niepsilons) + match cached_data.data.get(&id) { + Some(e) => CacheStatus::Computed(e.niepsilons), + None => CacheStatus::NotComputed, + } + } + + unsafe fn num_input_epsilons_unchecked(&self, id: usize) -> usize { + let cached_data = self.trs.lock().unwrap(); + cached_data.data.get(&id).unsafe_unwrap().niepsilons + } + + fn num_output_epsilons(&self, id: usize) -> CacheStatus { + let cached_data = self.trs.lock().unwrap(); + match cached_data.data.get(&id) { + Some(e) => CacheStatus::Computed(e.noepsilons), + None => CacheStatus::NotComputed, + } } - fn num_output_epsilons(&self, id: usize) -> Option { + unsafe fn num_output_epsilons_unchecked(&self, id: usize) -> usize { let cached_data = self.trs.lock().unwrap(); - cached_data.data.get(&id).map(|v| v.noepsilons) + cached_data.data.get(&id).unsafe_unwrap().noepsilons } fn len_trs(&self) -> usize { @@ -145,4 +162,21 @@ impl FstCache for SimpleHashMapCache { let cached_data = self.final_weights.lock().unwrap(); cached_data.data.len() } + + fn is_final(&self, state_id: usize) -> CacheStatus { + match self.final_weights.lock().unwrap().data.get(&state_id) { + Some(e) => CacheStatus::Computed(e.is_some()), + None => CacheStatus::NotComputed, + } + } + + unsafe fn is_final_unchecked(&self, state_id: usize) -> bool { + self.final_weights + .lock() + .unwrap() + .data + .get(&state_id) + .unwrap() + .is_some() + } } diff --git a/rustfst/src/algorithms/lazy/cache/simple_vec_cache.rs b/rustfst/src/algorithms/lazy/cache/simple_vec_cache.rs index e1c35c5c0..c0b6a42b5 100644 --- a/rustfst/src/algorithms/lazy/cache/simple_vec_cache.rs +++ b/rustfst/src/algorithms/lazy/cache/simple_vec_cache.rs @@ -1,7 +1,9 @@ use std::sync::Mutex; use crate::algorithms::lazy::cache::cache_internal_types::{CachedData, FinalWeight, StartState}; +use crate::algorithms::lazy::cache::fst_cache::FillableFstCache; use crate::algorithms::lazy::{CacheStatus, FstCache}; +use crate::fst_traits::MutableFst; use crate::semirings::Semiring; use crate::{StateId, Trs, TrsVec, EPS_LABEL}; @@ -128,14 +130,24 @@ impl FstCache for SimpleVecCache { cached_data.get(id).map(|e| e.trs.len()).into_option() } - fn num_input_epsilons(&self, id: usize) -> Option { + fn num_input_epsilons(&self, id: usize) -> CacheStatus { let cached_data = self.trs.lock().unwrap(); - cached_data.get(id).map(|e| e.niepsilons).into_option() + cached_data.get(id).map(|e| e.niepsilons) } - fn num_output_epsilons(&self, id: usize) -> Option { + unsafe fn num_input_epsilons_unchecked(&self, id: usize) -> usize { let cached_data = self.trs.lock().unwrap(); - cached_data.get(id).map(|e| e.noepsilons).into_option() + cached_data.get_unchecked(id).niepsilons + } + + fn num_output_epsilons(&self, id: usize) -> CacheStatus { + let cached_data = self.trs.lock().unwrap(); + cached_data.get(id).map(|e| e.noepsilons) + } + + unsafe fn num_output_epsilons_unchecked(&self, id: usize) -> usize { + let cached_data = self.trs.lock().unwrap(); + cached_data.get_unchecked(id).noepsilons } fn len_trs(&self) -> usize { @@ -147,4 +159,61 @@ impl FstCache for SimpleVecCache { let cached_data = self.final_weights.lock().unwrap(); cached_data.data.len() } + + fn is_final(&self, state_id: usize) -> CacheStatus { + let cached_data = self.final_weights.lock().unwrap(); + match cached_data.data.get(state_id) { + Some(e) => match e { + CacheStatus::Computed(v) => CacheStatus::Computed(v.is_some()), + CacheStatus::NotComputed => CacheStatus::NotComputed, + }, + None => CacheStatus::NotComputed, + } + } + + unsafe fn is_final_unchecked(&self, state_id: usize) -> bool { + let cached_data = self.final_weights.lock().unwrap(); + match cached_data.data.get_unchecked(state_id) { + CacheStatus::Computed(e) => e.is_some(), + CacheStatus::NotComputed => unreachable!(), + } + } +} + +impl FillableFstCache for SimpleVecCache { + fn into_fst>(self) -> F { + let mut fst_out = F::new(); + + // Safe because computed + if let Some(start) = self.get_start().unwrap() { + let nstates = self.num_known_states(); + fst_out.add_states(nstates); + + unsafe { fst_out.set_start_unchecked(start) }; + + let final_weights = self.final_weights.into_inner().unwrap().data; + let trs = self.trs.into_inner().unwrap().data; + + for (state_id, cache_trs) in trs.into_iter().enumerate() { + let cache_trs = cache_trs.unwrap(); + unsafe { + fst_out.set_state_unchecked_noprops( + state_id, + cache_trs.trs, + cache_trs.niepsilons, + cache_trs.noepsilons, + ) + }; + } + + for (state_id, final_weight) in final_weights.into_iter().enumerate() { + // Safe as computed + if let Some(final_weight) = final_weight.unwrap() { + unsafe { fst_out.set_final_unchecked(state_id, final_weight) }; + } + } + } + + fst_out + } } diff --git a/rustfst/src/algorithms/lazy/fst_op.rs b/rustfst/src/algorithms/lazy/fst_op.rs index 47d23aee9..40641039c 100644 --- a/rustfst/src/algorithms/lazy/fst_op.rs +++ b/rustfst/src/algorithms/lazy/fst_op.rs @@ -9,7 +9,7 @@ use crate::{StateId, TrsVec}; pub trait FstOp: Debug { // was FstImpl fn compute_start(&self) -> Result>; - fn compute_trs(&self, id: usize) -> Result>; + fn compute_trs(&self, id: StateId) -> Result>; fn compute_final_weight(&self, id: StateId) -> Result>; // Computed at construction time diff --git a/rustfst/src/algorithms/lazy/lazy_fst.rs b/rustfst/src/algorithms/lazy/lazy_fst.rs index b5b3befe2..8dcad0f60 100644 --- a/rustfst/src/algorithms/lazy/lazy_fst.rs +++ b/rustfst/src/algorithms/lazy/lazy_fst.rs @@ -7,6 +7,7 @@ use anyhow::Result; use itertools::izip; use unsafe_unwrap::UnsafeUnwrap; +use crate::algorithms::lazy::cache::fst_cache::FillableFstCache; use crate::algorithms::lazy::cache::CacheStatus; use crate::algorithms::lazy::fst_op::FstOp; use crate::algorithms::lazy::FstCache; @@ -67,6 +68,16 @@ impl, Cache: FstCache> CoreFst for LazyFst Result { + self.cache + .is_final(state_id) + .ok_or_else(|| format_err!("Final weight for state {} not computed yet", state_id)) + } + + unsafe fn is_final_unchecked(&self, state_id: usize) -> bool { + self.cache.is_final_unchecked(state_id) + } + fn get_trs(&self, state_id: usize) -> Result { match self.cache.get_trs(state_id) { CacheStatus::Computed(trs) => Ok(trs), @@ -92,11 +103,19 @@ impl, Cache: FstCache> CoreFst for LazyFst usize { + self.cache.num_input_epsilons_unchecked(state) + } + fn num_output_epsilons(&self, state: usize) -> Result { self.cache .num_output_epsilons(state) .ok_or_else(|| format_err!("State {:?} doesn't exist", state)) } + + unsafe fn num_output_epsilons_unchecked(&self, state: usize) -> usize { + self.cache.num_output_epsilons_unchecked(state) + } } impl<'a, W, Op, Cache> StateIterator<'a> for LazyFst @@ -249,12 +268,56 @@ where } } unsafe { fst_out.set_trs_unchecked(s, trs_owner.trs().to_vec()) }; - if let Some(f_w) = self.final_weight(s)? { - fst_out.set_final(s, f_w)?; + if let Some(f_w) = unsafe { self.final_weight_unchecked(s) } { + unsafe { fst_out.set_final_unchecked(s, f_w) }; } } fst_out.set_properties(self.properties()); // TODO: Symbol tables should be set here Ok(fst_out) } + + fn fill_cache(&self) { + let start_state = self.start(); + if start_state.is_none() { + return; + } + let start_state = start_state.unwrap(); + let mut queue = VecDeque::new(); + let mut visited_states = vec![]; + visited_states.resize(start_state + 1, false); + visited_states[start_state] = true; + queue.push_back(start_state); + while !queue.is_empty() { + let s = queue.pop_front().unwrap(); + let trs_owner = unsafe { self.get_trs_unchecked(s) }; + for tr in trs_owner.trs() { + if tr.nextstate >= visited_states.len() { + visited_states.resize(tr.nextstate + 1, false); + } + if !visited_states[tr.nextstate] { + queue.push_back(tr.nextstate); + visited_states[tr.nextstate] = true; + } + } + + // Force computation final weight + unsafe { self.final_weight_unchecked(s) }; + } + } + + /// Turns the Lazy FST into a static one. + pub fn into_static_fst + AllocableFst>(self) -> F2 + where + Cache: FillableFstCache, + { + self.fill_cache(); + + let props = self.properties(); + let mut fst: F2 = self.cache.into_fst(); + + fst.set_properties(props); + + fst + } } diff --git a/rustfst/src/algorithms/lazy/lazy_fst_2.rs b/rustfst/src/algorithms/lazy/lazy_fst_2.rs index 68f0e7a2c..dcec09a65 100644 --- a/rustfst/src/algorithms/lazy/lazy_fst_2.rs +++ b/rustfst/src/algorithms/lazy/lazy_fst_2.rs @@ -65,6 +65,19 @@ impl, Cache: FstCache> CoreFst for LazyFst2 Result { + match self.cache.is_final(state_id) { + CacheStatus::Computed(e) => Ok(e), + CacheStatus::NotComputed => { + bail!("Final weight for state {} not computed yet", state_id) + } + } + } + + unsafe fn is_final_unchecked(&self, state_id: usize) -> bool { + self.cache.is_final_unchecked(state_id) + } + fn get_trs(&self, state_id: usize) -> Result { match self.cache.get_trs(state_id) { CacheStatus::Computed(trs) => Ok(trs), @@ -91,11 +104,19 @@ impl, Cache: FstCache> CoreFst for LazyFst2 usize { + self.cache.num_input_epsilons_unchecked(state) + } + fn num_output_epsilons(&self, state: usize) -> Result { self.cache .num_output_epsilons(state) .ok_or_else(|| format_err!("State {:?} doesn't exist", state)) } + + unsafe fn num_output_epsilons_unchecked(&self, state: usize) -> usize { + self.cache.num_output_epsilons_unchecked(state) + } } impl<'a, W, Op, Cache> StateIterator<'a> for LazyFst2 diff --git a/rustfst/src/algorithms/replace/replace_fst.rs b/rustfst/src/algorithms/replace/replace_fst.rs index 355df5fa1..23e8a930e 100644 --- a/rustfst/src/algorithms/replace/replace_fst.rs +++ b/rustfst/src/algorithms/replace/replace_fst.rs @@ -74,6 +74,14 @@ where self.0.num_trs_unchecked(s) } + fn is_final(&self, state_id: usize) -> Result { + self.0.is_final(state_id) + } + + unsafe fn is_final_unchecked(&self, state_id: usize) -> bool { + self.0.is_final_unchecked(state_id) + } + fn get_trs(&self, state_id: usize) -> Result { self.0.get_trs(state_id) } @@ -90,9 +98,17 @@ where self.0.num_input_epsilons(state) } + unsafe fn num_input_epsilons_unchecked(&self, state: usize) -> usize { + self.0.num_input_epsilons_unchecked(state) + } + fn num_output_epsilons(&self, state: usize) -> Result { self.0.num_output_epsilons(state) } + + unsafe fn num_output_epsilons_unchecked(&self, state: usize) -> usize { + self.0.num_output_epsilons_unchecked(state) + } } impl<'a, W, F, B> StateIterator<'a> for ReplaceFst diff --git a/rustfst/src/algorithms/rm_epsilon/rm_epsilon_fst.rs b/rustfst/src/algorithms/rm_epsilon/rm_epsilon_fst.rs index 52ad4bd5d..57a299e6b 100644 --- a/rustfst/src/algorithms/rm_epsilon/rm_epsilon_fst.rs +++ b/rustfst/src/algorithms/rm_epsilon/rm_epsilon_fst.rs @@ -60,6 +60,14 @@ where self.0.num_trs_unchecked(s) } + fn is_final(&self, state_id: usize) -> Result { + self.0.is_final(state_id) + } + + unsafe fn is_final_unchecked(&self, state_id: usize) -> bool { + self.0.is_final_unchecked(state_id) + } + fn get_trs(&self, state_id: usize) -> Result { self.0.get_trs(state_id) } @@ -76,9 +84,17 @@ where self.0.num_input_epsilons(state) } + unsafe fn num_input_epsilons_unchecked(&self, state: usize) -> usize { + self.0.num_input_epsilons_unchecked(state) + } + fn num_output_epsilons(&self, state: usize) -> Result { self.0.num_output_epsilons(state) } + + unsafe fn num_output_epsilons_unchecked(&self, state: usize) -> usize { + self.0.num_output_epsilons_unchecked(state) + } } impl<'a, W, F, B> StateIterator<'a> for RmEpsilonFst diff --git a/rustfst/src/algorithms/union/union_fst.rs b/rustfst/src/algorithms/union/union_fst.rs index 2096db2b3..0de23c759 100644 --- a/rustfst/src/algorithms/union/union_fst.rs +++ b/rustfst/src/algorithms/union/union_fst.rs @@ -84,6 +84,14 @@ where self.0.num_trs_unchecked(s) } + fn is_final(&self, state_id: usize) -> Result { + self.0.is_final(state_id) + } + + unsafe fn is_final_unchecked(&self, state_id: usize) -> bool { + self.0.is_final_unchecked(state_id) + } + fn get_trs(&self, state_id: usize) -> Result { self.0.get_trs(state_id) } @@ -100,9 +108,17 @@ where self.0.num_input_epsilons(state) } + unsafe fn num_input_epsilons_unchecked(&self, state: usize) -> usize { + self.0.num_input_epsilons_unchecked(state) + } + fn num_output_epsilons(&self, state: usize) -> Result { self.0.num_output_epsilons(state) } + + unsafe fn num_output_epsilons_unchecked(&self, state: usize) -> usize { + self.0.num_output_epsilons_unchecked(state) + } } impl<'a, W, F> StateIterator<'a> for UnionFst diff --git a/rustfst/src/fst_impls/arc.rs b/rustfst/src/fst_impls/arc.rs index c223c61d0..7209dbfda 100644 --- a/rustfst/src/fst_impls/arc.rs +++ b/rustfst/src/fst_impls/arc.rs @@ -63,6 +63,14 @@ impl> CoreFst for Arc { self.deref().num_trs_unchecked(s) } + fn is_final(&self, state_id: usize) -> Result { + self.deref().is_final(state_id) + } + + unsafe fn is_final_unchecked(&self, state_id: usize) -> bool { + self.deref().is_final_unchecked(state_id) + } + fn get_trs(&self, state_id: usize) -> Result { self.deref().get_trs(state_id) } @@ -79,9 +87,17 @@ impl> CoreFst for Arc { self.deref().num_input_epsilons(state) } + unsafe fn num_input_epsilons_unchecked(&self, state: usize) -> usize { + self.deref().num_input_epsilons_unchecked(state) + } + fn num_output_epsilons(&self, state: usize) -> Result { self.deref().num_output_epsilons(state) } + + unsafe fn num_output_epsilons_unchecked(&self, state: usize) -> usize { + self.deref().num_output_epsilons_unchecked(state) + } } impl<'a, W: Semiring + 'a, F: FstIterator<'a, W>> FstIterator<'a, W> for Arc { diff --git a/rustfst/src/fst_impls/const_fst/fst.rs b/rustfst/src/fst_impls/const_fst/fst.rs index 23ea55ecb..2191db373 100644 --- a/rustfst/src/fst_impls/const_fst/fst.rs +++ b/rustfst/src/fst_impls/const_fst/fst.rs @@ -65,6 +65,19 @@ impl CoreFst for ConstFst { self.states.get_unchecked(s).ntrs } + fn is_final(&self, state_id: usize) -> Result { + let s = self + .states + .get(state_id) + .ok_or_else(|| format_err!("State {:?} doesn't exist", state_id))?; + Ok(s.final_weight.is_some()) + } + + unsafe fn is_final_unchecked(&self, state_id: usize) -> bool { + let s = self.states.get_unchecked(state_id); + s.final_weight.is_some() + } + fn get_trs(&self, state_id: usize) -> Result { let state = self .states @@ -98,6 +111,10 @@ impl CoreFst for ConstFst { .niepsilons) } + unsafe fn num_input_epsilons_unchecked(&self, state: usize) -> usize { + self.states.get_unchecked(state).niepsilons + } + fn num_output_epsilons(&self, state: usize) -> Result { Ok(self .states @@ -105,4 +122,8 @@ impl CoreFst for ConstFst { .ok_or_else(|| format_err!("State {:?} doesn't exist", state))? .noepsilons) } + + unsafe fn num_output_epsilons_unchecked(&self, state: usize) -> usize { + self.states.get_unchecked(state).noepsilons + } } diff --git a/rustfst/src/fst_impls/vector_fst/fst.rs b/rustfst/src/fst_impls/vector_fst/fst.rs index 8272f2464..268bc7e09 100644 --- a/rustfst/src/fst_impls/vector_fst/fst.rs +++ b/rustfst/src/fst_impls/vector_fst/fst.rs @@ -67,6 +67,19 @@ impl CoreFst for VectorFst { self.states.get_unchecked(s).trs.len() } + fn is_final(&self, state_id: usize) -> Result { + let s = self + .states + .get(state_id) + .ok_or_else(|| format_err!("State {:?} doesn't exist", state_id))?; + Ok(s.final_weight.is_some()) + } + + unsafe fn is_final_unchecked(&self, state_id: usize) -> bool { + let s = self.states.get_unchecked(state_id); + s.final_weight.is_some() + } + fn get_trs(&self, state_id: usize) -> Result { let state = self .states @@ -94,6 +107,10 @@ impl CoreFst for VectorFst { .niepsilons) } + unsafe fn num_input_epsilons_unchecked(&self, state: usize) -> usize { + self.states.get_unchecked(state).niepsilons + } + fn num_output_epsilons(&self, state: usize) -> Result { Ok(self .states @@ -101,4 +118,8 @@ impl CoreFst for VectorFst { .ok_or_else(|| format_err!("State {:?} doesn't exist", state))? .noepsilons) } + + unsafe fn num_output_epsilons_unchecked(&self, state: usize) -> usize { + self.states.get_unchecked(state).noepsilons + } } diff --git a/rustfst/src/fst_impls/vector_fst/mutable_fst.rs b/rustfst/src/fst_impls/vector_fst/mutable_fst.rs index 138672d17..51c562f87 100644 --- a/rustfst/src/fst_impls/vector_fst/mutable_fst.rs +++ b/rustfst/src/fst_impls/vector_fst/mutable_fst.rs @@ -14,7 +14,7 @@ use crate::fst_traits::CoreFst; use crate::fst_traits::MutableFst; use crate::semirings::Semiring; use crate::trs_iter_mut::TrsIterMut; -use crate::{StateId, Tr, Trs, EPS_LABEL}; +use crate::{StateId, Tr, Trs, TrsVec, EPS_LABEL}; #[inline] fn equal_tr(tr_1: &Tr, tr_2: &Tr) -> bool { @@ -410,4 +410,17 @@ impl MutableFst for VectorFst { self.properties &= !mask; self.properties |= props & mask; } + + unsafe fn set_state_unchecked_noprops( + &mut self, + source: usize, + trs: TrsVec, + niepsilons: usize, + noepsilons: usize, + ) { + let state = &mut self.states.get_unchecked_mut(source); + state.trs = trs; + state.niepsilons = niepsilons; + state.noepsilons = noepsilons; + } } diff --git a/rustfst/src/fst_traits/fst.rs b/rustfst/src/fst_traits/fst.rs index 49b44dddf..16abcd0cb 100644 --- a/rustfst/src/fst_traits/fst.rs +++ b/rustfst/src/fst_traits/fst.rs @@ -110,11 +110,7 @@ pub trait CoreFst { /// assert_eq!(fst.is_final(s2).unwrap(), true); /// assert!(fst.is_final(s2 + 1).is_err()); /// ``` - #[inline] - fn is_final(&self, state_id: StateId) -> Result { - let w = self.final_weight(state_id)?; - Ok(w.is_some()) - } + fn is_final(&self, state_id: StateId) -> Result; /// Returns whether or not the state with identifier passed as parameters is a final state. /// @@ -122,10 +118,7 @@ pub trait CoreFst { /// /// Unsafe behaviour if `state` is not present in Fst. /// - #[inline] - unsafe fn is_final_unchecked(&self, state: StateId) -> bool { - self.final_weight_unchecked(state).is_some() - } + unsafe fn is_final_unchecked(&self, state_id: StateId) -> bool; /// Check whether a state is the start state or not. #[inline] @@ -194,6 +187,14 @@ pub trait CoreFst { /// ``` fn num_input_epsilons(&self, state: StateId) -> Result; + /// Returns the number of trs with epsilon input labels leaving a state. + /// + /// # Safety + /// + /// Unsafe behaviour if `state` is not present in Fst. + /// + unsafe fn num_input_epsilons_unchecked(&self, state: StateId) -> usize; + /// Returns the number of trs with epsilon output labels leaving a state. /// /// # Example : @@ -217,6 +218,14 @@ pub trait CoreFst { /// assert_eq!(fst.num_output_epsilons(s1).unwrap(), 0); /// ``` fn num_output_epsilons(&self, state: StateId) -> Result; + + /// Returns the number of trs with epsilon output labels leaving a state. + /// + /// # Safety + /// + /// Unsafe behaviour if `state` is not present in Fst. + /// + unsafe fn num_output_epsilons_unchecked(&self, state: StateId) -> usize; } /// Trait defining the minimum interface necessary for a wFST. diff --git a/rustfst/src/fst_traits/mutable_fst.rs b/rustfst/src/fst_traits/mutable_fst.rs index 6ae20c088..1760553de 100644 --- a/rustfst/src/fst_traits/mutable_fst.rs +++ b/rustfst/src/fst_traits/mutable_fst.rs @@ -9,7 +9,7 @@ use crate::fst_traits::ExpandedFst; use crate::semirings::Semiring; use crate::tr::Tr; use crate::trs_iter_mut::TrsIterMut; -use crate::{Label, StateId}; +use crate::{Label, StateId, TrsVec}; /// Trait defining the methods to modify a wFST. pub trait MutableFst: ExpandedFst { @@ -444,4 +444,12 @@ pub trait MutableFst: ExpandedFst { fn compute_and_update_properties_all(&mut self) -> Result { self.compute_and_update_properties(FstProperties::all_properties()) } + + unsafe fn set_state_unchecked_noprops( + &mut self, + source: StateId, + trs: TrsVec, + niepsilons: usize, + noepsilons: usize, + ); }