Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -39,10 +39,11 @@ class CardinalFst(GraphFst):

Args:
input_case: accepting either "lower_cased" or "cased" input.
project_input: if True, adds input projection for mapping original text.
"""

def __init__(self, input_case: str = INPUT_LOWER_CASED):
super().__init__(name="cardinal", kind="classify")
def __init__(self, input_case: str = INPUT_LOWER_CASED, project_input: bool = False):
super().__init__(name="cardinal", kind="classify", project_input=project_input)
self.input_case = input_case
graph_zero = pynini.string_file(get_abs_path("data/numbers/zero.tsv"))
graph_digit = pynini.string_file(get_abs_path("data/numbers/digit.tsv"))
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -147,8 +147,8 @@ class DateFst(GraphFst):
input_case: accepting either "lower_cased" or "cased" input.
"""

def __init__(self, ordinal: GraphFst, input_case: str):
super().__init__(name="date", kind="classify")
def __init__(self, ordinal: GraphFst, input_case: str, project_input: bool = False):
super().__init__(name="date", kind="classify", project_input=project_input)

ordinal_graph = ordinal.graph
year_graph = _get_year_graph(input_case=input_case)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -83,8 +83,8 @@ class DecimalFst(GraphFst):
input_case: accepting either "lower_cased" or "cased" input.
"""

def __init__(self, cardinal: GraphFst, input_case: str = INPUT_LOWER_CASED):
super().__init__(name="decimal", kind="classify")
def __init__(self, cardinal: GraphFst, input_case: str = INPUT_LOWER_CASED, project_input: bool = False):
super().__init__(name="decimal", kind="classify", project_input=project_input)

cardinal_graph = cardinal.graph_no_exception

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -38,8 +38,8 @@ class ElectronicFst(GraphFst):
input_case: accepting either "lower_cased" or "cased" input.
"""

def __init__(self, input_case: str = INPUT_LOWER_CASED):
super().__init__(name="electronic", kind="classify")
def __init__(self, input_case: str = INPUT_LOWER_CASED, project_input: bool = False):
super().__init__(name="electronic", kind="classify", project_input=project_input)

delete_extra_space = pynutil.delete(" ")

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,6 @@ class FractionFst(GraphFst):
input_case: accepting either "lower_cased" or "cased" input.
"""

def __init__(self, input_case: str = INPUT_LOWER_CASED):
super().__init__(name="fraction", kind="classify")
def __init__(self, input_case: str = INPUT_LOWER_CASED, project_input: bool = False):
super().__init__(name="fraction", kind="classify", project_input=project_input)
# integer_part # numerator # denominator
Original file line number Diff line number Diff line change
Expand Up @@ -42,8 +42,10 @@ class MeasureFst(GraphFst):
input_case: accepting either "lower_cased" or "cased" input.
"""

def __init__(self, cardinal: GraphFst, decimal: GraphFst, input_case: str = INPUT_LOWER_CASED):
super().__init__(name="measure", kind="classify")
def __init__(
self, cardinal: GraphFst, decimal: GraphFst, input_case: str = INPUT_LOWER_CASED, project_input: bool = False
):
super().__init__(name="measure", kind="classify", project_input=project_input)

cardinal_graph = cardinal.graph_no_exception

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -44,8 +44,10 @@ class MoneyFst(GraphFst):
input_case: accepting either "lower_cased" or "cased" input.
"""

def __init__(self, cardinal: GraphFst, decimal: GraphFst, input_case: str = INPUT_LOWER_CASED):
super().__init__(name="money", kind="classify")
def __init__(
self, cardinal: GraphFst, decimal: GraphFst, input_case: str = INPUT_LOWER_CASED, project_input: bool = False
):
super().__init__(name="money", kind="classify", project_input=project_input)
# quantity, integer_part, fractional_part, currency

cardinal_graph = cardinal.graph_no_exception
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -36,8 +36,8 @@ class OrdinalFst(GraphFst):
input_case: accepting either "lower_cased" or "cased" input.
"""

def __init__(self, cardinal: GraphFst, input_case: str = INPUT_LOWER_CASED):
super().__init__(name="ordinal", kind="classify")
def __init__(self, cardinal: GraphFst, input_case: str = INPUT_LOWER_CASED, project_input: bool = False):
super().__init__(name="ordinal", kind="classify", project_input=project_input)

cardinal_graph = cardinal.graph_no_exception
graph_digit = pynini.string_file(get_abs_path("data/ordinals/digit.tsv"))
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -25,8 +25,8 @@ class PunctuationFst(GraphFst):
e.g. a, -> tokens { name: "a" } tokens { name: "," }
"""

def __init__(self):
super().__init__(name="punctuation", kind="classify")
def __init__(self, project_input: bool = False):
super().__init__(name="punctuation", kind="classify", project_input=project_input)

s = "!#$%&\'()*+,-./:;<=>?@^_`{|}~"
punct = pynini.union(*s)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -75,8 +75,8 @@ class TelephoneFst(GraphFst):
input_case: accepting either "lower_cased" or "cased" input.
"""

def __init__(self, cardinal: GraphFst, input_case: str = INPUT_LOWER_CASED):
super().__init__(name="telephone", kind="classify")
def __init__(self, cardinal: GraphFst, input_case: str = INPUT_LOWER_CASED, project_input: bool = False):
super().__init__(name="telephone", kind="classify", project_input=project_input)
# country code, number_part, extension
digit_to_str = (
pynini.invert(pynini.string_file(get_abs_path("data/numbers/digit.tsv")).optimize())
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -42,8 +42,8 @@ class TimeFst(GraphFst):
e.g. half past two -> time { hours: "2" minutes: "30" }
"""

def __init__(self, input_case: str = INPUT_LOWER_CASED):
super().__init__(name="time", kind="classify")
def __init__(self, input_case: str = INPUT_LOWER_CASED, project_input: bool = False):
super().__init__(name="time", kind="classify", project_input=project_input)
# hours, minutes, seconds, suffix, zone, style, speak_period

suffix_graph = pynini.string_file(get_abs_path("data/time/time_suffix.tsv"))
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@
GraphFst,
delete_extra_space,
delete_space,
generate_far_filename,
generator_main,
)
from nemo_text_processing.utils.logging import logger
Expand All @@ -49,6 +50,7 @@ class ClassifyFst(GraphFst):
Args:
input_case: accepting either "lower_cased" or "cased" input.
cache_dir: path to a dir with .far grammar file. Set to None to avoid using cache.
project_input: if True, adds input projection for mapping original text.
overwrite_cache: set to True to overwrite .far files
whitelist: path to a file with whitelist replacements
"""
Expand All @@ -57,6 +59,7 @@ def __init__(
self,
input_case: str = INPUT_LOWER_CASED,
cache_dir: str = None,
project_input: bool = False,
overwrite_cache: bool = False,
whitelist: str = None,
):
Expand All @@ -65,30 +68,44 @@ def __init__(
far_file = None
if cache_dir is not None and cache_dir != "None":
os.makedirs(cache_dir, exist_ok=True)
far_file = os.path.join(cache_dir, f"en_itn_{input_case}.far")
far_file = generate_far_filename(
language="en",
mode="itn",
cache_dir=cache_dir,
operation="tokenize_and_classify",
project_input=project_input,
input_case=input_case,
whitelist_file=whitelist,
)
if not overwrite_cache and far_file and os.path.exists(far_file):
self.fst = pynini.Far(far_file, mode="r")["tokenize_and_classify"]
logger.info(f"ClassifyFst.fst was restored from {far_file}.")
else:
logger.info(f"Creating ClassifyFst grammars.")
cardinal = CardinalFst(input_case=input_case)
cardinal = CardinalFst(input_case=input_case, project_input=project_input)
cardinal_graph = cardinal.fst

ordinal = OrdinalFst(cardinal, input_case=input_case)
ordinal = OrdinalFst(cardinal, input_case=input_case, project_input=project_input)
ordinal_graph = ordinal.fst

decimal = DecimalFst(cardinal, input_case=input_case)
decimal = DecimalFst(cardinal, input_case=input_case, project_input=project_input)
decimal_graph = decimal.fst

measure_graph = MeasureFst(cardinal=cardinal, decimal=decimal, input_case=input_case).fst
date_graph = DateFst(ordinal=ordinal, input_case=input_case).fst
word_graph = WordFst().fst
time_graph = TimeFst(input_case=input_case).fst
money_graph = MoneyFst(cardinal=cardinal, decimal=decimal, input_case=input_case).fst
whitelist_graph = WhiteListFst(input_file=whitelist, input_case=input_case).fst
punct_graph = PunctuationFst().fst
electronic_graph = ElectronicFst(input_case=input_case).fst
telephone_graph = TelephoneFst(cardinal, input_case=input_case).fst
measure_graph = MeasureFst(
cardinal=cardinal, decimal=decimal, input_case=input_case, project_input=project_input
).fst
date_graph = DateFst(ordinal=ordinal, input_case=input_case, project_input=project_input).fst
word_graph = WordFst(project_input=project_input).fst
time_graph = TimeFst(input_case=input_case, project_input=project_input).fst
money_graph = MoneyFst(
cardinal=cardinal, decimal=decimal, input_case=input_case, project_input=project_input
).fst
whitelist_graph = WhiteListFst(
input_file=whitelist, input_case=input_case, project_input=project_input
).fst
punct_graph = PunctuationFst(project_input=project_input).fst
electronic_graph = ElectronicFst(input_case=input_case, project_input=project_input).fst
telephone_graph = TelephoneFst(cardinal, input_case=input_case, project_input=project_input).fst

classify = (
pynutil.add_weight(whitelist_graph, 1.01)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,18 +16,15 @@

import os

import pynini
from pynini.lib import pynutil

from nemo_text_processing.inverse_text_normalization.en.utils import get_abs_path
from nemo_text_processing.text_normalization.en.graph_utils import (
INPUT_CASED,
INPUT_LOWER_CASED,
GraphFst,
convert_space,
string_map_cased,
)
from nemo_text_processing.text_normalization.en.utils import load_labels


class WhiteListFst(GraphFst):
Expand All @@ -43,8 +40,8 @@ class WhiteListFst(GraphFst):
input_case: accepting either "lower_cased" or "cased" input.
"""

def __init__(self, input_case: str = INPUT_LOWER_CASED, input_file: str = None):
super().__init__(name="whitelist", kind="classify")
def __init__(self, input_case: str = INPUT_LOWER_CASED, input_file: str = None, project_input: bool = False):
super().__init__(name="whitelist", kind="classify", project_input=project_input)

if input_file is None:
input_file = get_abs_path("data/whitelist.tsv")
Expand All @@ -54,4 +51,5 @@ def __init__(self, input_case: str = INPUT_LOWER_CASED, input_file: str = None):

whitelist = string_map_cased(input_file, input_case)
graph = pynutil.insert("name: \"") + convert_space(whitelist) + pynutil.insert("\"")
self.fst = graph.optimize()
final_graph = self.add_tokens(graph)
self.fst = final_graph.optimize()
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ class WordFst(GraphFst):
e.g. sleep -> tokens { name: "sleep" }
"""

def __init__(self):
super().__init__(name="word", kind="classify")
def __init__(self, project_input: bool = False):
super().__init__(name="word", kind="classify", project_input=project_input)
word = pynutil.insert("name: \"") + pynini.closure(NEMO_NOT_SPACE, 1) + pynutil.insert("\"")
self.fst = word.optimize()
Original file line number Diff line number Diff line change
Expand Up @@ -22,10 +22,13 @@ class CardinalFst(GraphFst):
"""
Finite state transducer for verbalizing cardinal
e.g. cardinal { integer: "23" negative: "-" } -> -23

Args:
project_input: if True, adds input projection for mapping original text.
"""

def __init__(self):
super().__init__(name="cardinal", kind="verbalize")
def __init__(self, project_input: bool = False):
super().__init__(name="cardinal", kind="verbalize", project_input=project_input)
optional_sign = pynini.closure(
pynutil.delete("negative:")
+ delete_space
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -30,8 +30,8 @@ class DateFst(GraphFst):
date { day: "5" month: "january" year: "2012" preserve_order: true } -> 5 january 2012
"""

def __init__(self):
super().__init__(name="date", kind="verbalize")
def __init__(self, project_input: bool = False):
super().__init__(name="date", kind="verbalize", project_input=project_input)
month = (
pynutil.delete("month:")
+ delete_space
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -24,8 +24,8 @@ class DecimalFst(GraphFst):
decimal { negative: "true" integer_part: "12" fractional_part: "5006" quantity: "billion" } -> -12.5006 billion
"""

def __init__(self):
super().__init__(name="decimal", kind="verbalize")
def __init__(self, project_input: bool = False):
super().__init__(name="decimal", kind="verbalize", project_input=project_input)
optionl_sign = pynini.closure(pynini.cross("negative: \"true\"", "-") + delete_space, 0, 1)
integer = (
pynutil.delete("integer_part:")
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -24,8 +24,8 @@ class ElectronicFst(GraphFst):
e.g. tokens { electronic { username: "cdf1" domain: "abc.edu" } } -> cdf1@abc.edu
"""

def __init__(self):
super().__init__(name="electronic", kind="verbalize")
def __init__(self, project_input: bool = False):
super().__init__(name="electronic", kind="verbalize", project_input=project_input)
user_name = (
pynutil.delete("username:")
+ delete_space
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,5 +21,5 @@ class FractionFst(GraphFst):
Finite state transducer for verbalizing fraction,
"""

def __init__(self):
super().__init__(name="fraction", kind="verbalize")
def __init__(self, project_input: bool = False):
super().__init__(name="fraction", kind="verbalize", project_input=project_input)
Original file line number Diff line number Diff line change
Expand Up @@ -29,8 +29,8 @@ class MeasureFst(GraphFst):
cardinal: CardinalFst
"""

def __init__(self, decimal: GraphFst, cardinal: GraphFst):
super().__init__(name="measure", kind="verbalize")
def __init__(self, decimal: GraphFst, cardinal: GraphFst, project_input: bool = False):
super().__init__(name="measure", kind="verbalize", project_input=project_input)
optional_sign = pynini.closure(pynini.cross("negative: \"true\"", "-"), 0, 1)
unit = (
pynutil.delete("units:")
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -28,8 +28,8 @@ class MoneyFst(GraphFst):
decimal: DecimalFst
"""

def __init__(self, decimal: GraphFst):
super().__init__(name="money", kind="verbalize")
def __init__(self, decimal: GraphFst, project_input: bool = False):
super().__init__(name="money", kind="verbalize", project_input=project_input)
unit = (
pynutil.delete("currency:")
+ delete_space
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -25,8 +25,8 @@ class OrdinalFst(GraphFst):
ordinal { integer: "13" } -> 13th
"""

def __init__(self):
super().__init__(name="ordinal", kind="verbalize")
def __init__(self, project_input: bool = False):
super().__init__(name="ordinal", kind="verbalize", project_input=project_input)
graph = (
pynutil.delete("integer:")
+ delete_space
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -26,8 +26,8 @@ class TelephoneFst(GraphFst):
-> 123-123-5678
"""

def __init__(self):
super().__init__(name="telephone", kind="verbalize")
def __init__(self, project_input: bool = False):
super().__init__(name="telephone", kind="verbalize", project_input=project_input)

number_part = pynutil.delete("number_part: \"") + pynini.closure(NEMO_NOT_QUOTE, 1) + pynutil.delete("\"")
optional_country_code = pynini.closure(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -33,8 +33,8 @@ class TimeFst(GraphFst):
time { hours: "2" suffix: "a.m." } -> 02:00 a.m.
"""

def __init__(self):
super().__init__(name="time", kind="verbalize")
def __init__(self, project_input: bool = False):
super().__init__(name="time", kind="verbalize", project_input=project_input)
add_leading_zero_to_double_digit = (NEMO_DIGIT + NEMO_DIGIT) | (pynutil.insert("0") + NEMO_DIGIT)
hour = (
pynutil.delete("hours:")
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -33,20 +33,20 @@ class VerbalizeFst(GraphFst):
More details to deployment at NeMo/tools/text_processing_deployment.
"""

def __init__(self):
def __init__(self, project_input: bool = False):
super().__init__(name="verbalize", kind="verbalize")
cardinal = CardinalFst()
cardinal = CardinalFst(project_input=project_input)
cardinal_graph = cardinal.fst
ordinal_graph = OrdinalFst().fst
decimal = DecimalFst()
ordinal_graph = OrdinalFst(project_input=project_input).fst
decimal = DecimalFst(project_input=project_input)
decimal_graph = decimal.fst
measure_graph = MeasureFst(decimal=decimal, cardinal=cardinal).fst
money_graph = MoneyFst(decimal=decimal).fst
time_graph = TimeFst().fst
date_graph = DateFst().fst
whitelist_graph = WhiteListFst().fst
telephone_graph = TelephoneFst().fst
electronic_graph = ElectronicFst().fst
measure_graph = MeasureFst(decimal=decimal, cardinal=cardinal, project_input=project_input).fst
money_graph = MoneyFst(decimal=decimal, project_input=project_input).fst
time_graph = TimeFst(project_input=project_input).fst
date_graph = DateFst(project_input=project_input).fst
whitelist_graph = WhiteListFst(project_input=project_input).fst
telephone_graph = TelephoneFst(project_input=project_input).fst
electronic_graph = ElectronicFst(project_input=project_input).fst
graph = (
time_graph
| date_graph
Expand Down
Loading