Skip to content

fastspeech models #4337

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 11 commits into
base: develop
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 11 additions & 0 deletions paddlex/configs/modules/text_to_pinyin/G2PWModel.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
Global:
model: G2PWModel
mode: predict # only support predict
device: gpu:0
output: "output"

Predict:
batch_size: 1
input: "欢迎使用飞桨"
kernel_option:
run_mode: paddle
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
Global:
model: fastspeech2_csmsc
mode: predict # only support predict
device: gpu:0
use_trt: False
use_mkldnn: False
cpu_threads: 1
precision: "fp32"
output: "output"
model_name: "fastspeech2_csmsc"
speaker_dict: None
lang: zh
speaker_id: 0

Predict:
batch_size: 1
model_dir: "fastspeech2csmsc"
input: "今天天气真不错"
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

input 应该是phone?

lang: zh
speaker_id: 0
kernel_option:
run_mode: paddle
22 changes: 22 additions & 0 deletions paddlex/configs/modules/text_to_speech_vocoder/pwgan_csmsc.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
Global:
model: pwgan_csmsc
mode: predict # only support predict
device: gpu:0
use_trt: False
use_mkldnn: False
cpu_threads: 1
precision: "fp32"
output: "output"
model_name: "pwgan_csmsc"
speaker_dict: None
lang: zh
speaker_id: 0

Predict:
batch_size: 1
model_dir: "pwgan_csmsc"
input: "今天天气真不错"
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

input应该是npy或者tensor?

lang: zh
speaker_id: 0
kernel_option:
run_mode: paddle
1 change: 1 addition & 0 deletions paddlex/inference/common/batch_sampler/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,3 +20,4 @@
from .markdown_batch_sampler import MarkDownBatchSampler
from .ts_batch_sampler import TSBatchSampler
from .video_batch_sampler import VideoBatchSampler
from .text_batch_sampler import TextBatchSampler
61 changes: 61 additions & 0 deletions paddlex/inference/common/batch_sampler/text_batch_sampler.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,61 @@
# Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from ....utils import logging
from .base_batch_sampler import BaseBatchSampler


class TextBatchSampler(BaseBatchSampler):
    """Batch sampler for plain-text inputs; only a batch size of 1 is supported."""

    def __init__(self):
        """Initializes the TextBatchSampler.

        The batch size is fixed to 1 because text inputs are processed
        one string at a time.
        """
        super().__init__()
        self.batch_size = 1

    def sample(self, inputs):
        """Generate batches from the input text.

        Args:
            inputs (str): The input text.

        Yields:
            list: A single-element list containing the input text.
        """
        if isinstance(inputs, str):
            yield [inputs]
        else:
            # Bug fix: the original interpolated the builtin `input` instead of
            # the `inputs` argument, producing a useless message.
            logging.warning(
                f"Not supported input data type! Only `str` is supported, but got: {inputs}."
            )

    @BaseBatchSampler.batch_size.setter
    def batch_size(self, batch_size):
        """Sets the batch size.

        Args:
            batch_size (int): The batch size to set.

        Raises:
            Warning: If the batch size is not equal to 1.
        """
        # only support batch size 1
        if batch_size != 1:
            # Message corrected: this is the text sampler, not the audio sampler.
            logging.warning(
                f"text batch sampler only supports batch size 1, but got {batch_size}."
            )
        else:
            self._batch_size = batch_size
1 change: 1 addition & 0 deletions paddlex/inference/common/result/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
from .base_result import BaseResult
from .base_ts_result import BaseTSResult
from .base_video_result import BaseVideoResult
from .base_audio_result import BaseAudioResult
from .mixin import (
Base64Mixin,
CSVMixin,
Expand Down
36 changes: 36 additions & 0 deletions paddlex/inference/common/result/base_audio_result.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,36 @@
# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from .base_result import BaseResult
from .mixin import AudioMixin


class BaseAudioResult(BaseResult, AudioMixin):
    """Base class for audio results."""

    # Key under which the input audio is expected in the result data.
    INPUT_AUDIO_KEY = "input_audio"

    def __init__(self, data: dict) -> None:
        """
        Initialize the BaseAudioResult.

        Args:
            data (dict): The initial data.

        Raises:
            AssertionError: If the required key (`BaseAudioResult.INPUT_AUDIO_KEY`) is not found in the data.
        """
        super().__init__(data)
        # Save audio in WAV format by default.
        AudioMixin.__init__(self, "wav")
61 changes: 61 additions & 0 deletions paddlex/inference/common/result/mixin.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@
TextWriter,
VideoWriter,
XlsxWriter,
AudioWriter,
)


Expand Down Expand Up @@ -592,7 +593,67 @@ def _is_video_file(file_path):
f"The result has multiple video files need to be saved. But the `save_path` has been specified as `{save_path}`!"
)
video_writer.write(save_path, video[list(video.keys())[0]], *args, **kwargs)

class AudioMixin:
    """Mixin class for adding audio handling capabilities."""

    def __init__(self, backend, *args: List, **kwargs: Dict) -> None:
        """Initializes AudioMixin.

        Args:
            backend: The backend passed to the AudioWriter used for saving.
            *args: Additional positional arguments to pass to the AudioWriter.
            **kwargs: Additional keyword arguments to pass to the AudioWriter.
        """
        self._backend = backend
        self._save_funcs.append(self.save_to_audio)

    @abstractmethod
    def _to_audio(self) -> Dict[str, np.ndarray]:
        """Abstract method to convert the result to an audio representation.

        Returns:
            Dict[str, np.ndarray]: The audio representation of the result.
        """
        raise NotImplementedError

    @property
    def audio(self) -> Dict[str, np.ndarray]:
        """Property to get the audio representation of the result.

        Returns:
            Dict[str, np.ndarray]: The audio representation of the result.
        """
        return self._to_audio()

    def save_to_audio(self, save_path: str, *args: List, **kwargs: Dict) -> None:
        """Saves the audio representation of the result to the specified path.

        Args:
            save_path (str): The path to save the audio. If the save path is not an
                audio file path, the input path's stem and an audio suffix are
                appended to it for each audio in the result.
            *args: Additional positional arguments that will be passed to the audio writer.
            **kwargs: Additional keyword arguments that will be passed to the audio writer.
        """

        def _is_audio_file(file_path):
            mime_type, _ = mimetypes.guess_type(file_path)
            return mime_type is not None and mime_type.startswith("audio/")

        audio_writer = AudioWriter(backend=self._backend, *args, **kwargs)
        audio = self._to_audio()
        if not _is_audio_file(save_path):
            fn = Path(self._get_input_fn())
            stem = fn.stem
            # Bug fix: the original fell back to ".mp4" (copied from the video
            # mixin); an audio result should default to ".wav".
            suffix = fn.suffix if _is_audio_file(fn) else ".wav"
            base_save_path = Path(save_path)
            for key in audio:
                save_path = base_save_path / f"{stem}_{key}{suffix}"
                audio_writer.write(save_path.as_posix(), audio[key], *args, **kwargs)
        else:
            if len(audio) > 1:
                logging.warning(
                    f"The result has multiple audio files need to be saved. But the `save_path` has been specified as `{save_path}`!"
                )
            audio_writer.write(save_path, audio[list(audio.keys())[0]], *args, **kwargs)

class MarkdownMixin:
"""Mixin class for adding Markdown handling capabilities."""
Expand Down
4 changes: 3 additions & 1 deletion paddlex/inference/models/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,9 @@
from .ts_forecasting import TSFcPredictor
from .video_classification import VideoClasPredictor
from .video_detection import VideoDetPredictor

from .text_to_speech_acoustic import Fastspeech2Predictor
from .text_to_speech_vocoder import PwganPredictor
from .text_to_pinyin import TextToPinyinPredictor

def create_predictor(
model_name: str,
Expand Down
2 changes: 0 additions & 2 deletions paddlex/inference/models/common/tokenizer/tokenizer_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -927,8 +927,6 @@ def init_chat_template(self, chat_template: Union[str, dict]):
raise ValueError("Receive error chat_template data: ", chat_template)

def save_resources(self, save_directory):
super().save_resources(save_directory)

if isinstance(
self.chat_template, ChatTemplate
): # Future remove if ChatTemplate is deprecated
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,8 @@
import numpy as np

from .....utils import logging
from .....utils.download import download
from .....utils.cache import CACHE_DIR

__all__ = [
"AddedToken",
Expand Down Expand Up @@ -1661,7 +1663,15 @@ def from_pretrained(cls, pretrained_model_name_or_path, *args, **kwargs):
resolved_vocab_files = {}
for file_id, file_path in vocab_files.items():
# adapt to PaddleX
resolved_vocab_files[file_id] = file_path
if file_path is None or os.path.isfile(file_path):
resolved_vocab_files[file_id] = file_path
continue
else:
download_path = os.path.join(
CACHE_DIR, "official_models", pretrained_model_name_or_path, file_id
)
download(file_path, download_path)
resolved_vocab_files[file_id] = download_path

for file_id, file_path in resolved_vocab_files.items():
if resolved_vocab_files[file_id] is not None:
Expand Down
15 changes: 15 additions & 0 deletions paddlex/inference/models/text_to_pinyin/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from .predictor import TextToPinyinPredictor
84 changes: 84 additions & 0 deletions paddlex/inference/models/text_to_pinyin/predictor.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,84 @@
# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import numpy as np

from ....utils.func_register import FuncRegister
from ...common.batch_sampler import TextBatchSampler

from ..base import BasePredictor
from .result import TextToPinyinResult
from ....modules.text_to_pinyin.model_list import MODELS


class TextToPinyinPredictor(BasePredictor):
    """Predictor that converts Chinese text to pinyin via the G2PW model."""

    # Model names this predictor supports, registered for predictor lookup.
    entities = MODELS

    def __init__(self, *args, **kwargs):
        """Initializes TextToPinyinPredictor.

        Args:
            *args: Arbitrary positional arguments passed to the superclass.
            **kwargs: Arbitrary keyword arguments passed to the superclass.
        """
        super().__init__(*args, **kwargs)
        self.model = self._build()

    def _build_batch_sampler(self):
        """Builds and returns a TextBatchSampler instance.

        Returns:
            TextBatchSampler: An instance of TextBatchSampler (batch size fixed to 1).
        """
        return TextBatchSampler()

    def _get_result_class(self):
        """Returns the result class, TextToPinyinResult.

        Returns:
            type: The TextToPinyinResult class.
        """
        return TextToPinyinResult

    def _build(self):
        """Build the model.

        Returns:
            G2PWOnnxConverter: An instance of G2PWOnnxConverter.
        """
        # Imported lazily so the module can be imported without the
        # processor's heavier dependencies being installed.
        from .processors import (
            G2PWOnnxConverter,
        )

        # build model
        model = G2PWOnnxConverter(
            model_dir=self.model_dir, style="pinyin", enable_non_tradional_chinese=True
        )
        return model

    def process(self, batch_data):
        """
        Process a batch of data through the preprocessing, inference, and postprocessing.

        Args:
            batch_data (List[Union[str], ...]): A batch of input text data. Only the
                first element is used, since the batch size is fixed to 1.

        Returns:
            dict: A dictionary containing the result. The result include the output pinyin dict.
        """
        result = self.model(batch_data[0])
        return {
            "result": [result]
        }
Loading