diff --git a/.gitignore b/.gitignore index d15377a62..839333dc3 100644 --- a/.gitignore +++ b/.gitignore @@ -3,6 +3,7 @@ # IDE .idea +.vscode # Jupyter .ipynb_checkpoints/ @@ -26,4 +27,7 @@ tests/data_for_tests/generated/ # coverage / pytest-cov .coverage -coverage.xml \ No newline at end of file +coverage.xml + +#nox +.nox/ \ No newline at end of file diff --git a/.vscode/settings.json b/.vscode/settings.json new file mode 100644 index 000000000..9ee86e71a --- /dev/null +++ b/.vscode/settings.json @@ -0,0 +1,6 @@ +{ + "[python]": { + "editor.defaultFormatter": "ms-python.autopep8" + }, + "python.formatting.provider": "none" +} \ No newline at end of file diff --git a/src/scripts/test_vae.ipynb b/src/scripts/test_vae.ipynb new file mode 100644 index 000000000..9084cb9d8 --- /dev/null +++ b/src/scripts/test_vae.ipynb @@ -0,0 +1,49 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": 2, + "metadata": {}, + "outputs": [], + "source": [ + "from src.vak.nets.ava import Ava\n", + "import torch" + ] + }, + { + "cell_type": "code", + "execution_count": 8, + "metadata": {}, + "outputs": [], + "source": [ + "x_shape = (3, 128, 512)\n", + "input = torch.zeros(x_shape)\n", + "net = Ava(x_shape=(x_shape[1], x_shape[2]))\n", + "output, _ = net.forward(input)\n", + "assert output.shape == x_shape, 'Error'" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.10.13" + }, + "orig_nbformat": 4 + }, + "nbformat": 4, + "nbformat_minor": 2 +} diff --git a/src/vak/cli/prep.py b/src/vak/cli/prep.py index 3a8ee6b8d..92e36104b 100644 --- a/src/vak/cli/prep.py +++ b/src/vak/cli/prep.py @@ -139,6 +139,9 @@ def prep(toml_path): test_dur=cfg.prep.test_dur, train_set_durs=cfg.prep.train_set_durs, num_replicates=cfg.prep.num_replicates, + context_s=cfg.prep.context_s, + max_dur=cfg.prep.max_dur, + target_shape=cfg.prep.target_shape, ) # use config and section from above to add dataset_path to config.toml file diff --git a/src/vak/config/prep.py b/src/vak/config/prep.py index 7481d8cc2..4281981d1 100644 --- a/src/vak/config/prep.py +++ b/src/vak/config/prep.py @@ -25,7 +25,7 @@ def is_valid_duration(instance, attribute, value): """validator for dataset split durations""" if type(value) not in {int, float}: raise TypeError( - f"invalid type for {attribute} of {instance}: {type(value)}. Type should be float or int." + f"invalid type for {attribute.name} of {instance}: {type(value)}. Type should be float or int." ) if value == -1: # specifies "use the remainder of the dataset" @@ -34,7 +34,7 @@ def is_valid_duration(instance, attribute, value): if not value >= 0: raise ValueError( - f"value specified for {attribute} of {instance} must be greater than or equal to zero, was {value}" + f"value specified for {attribute.name} of {instance} must be greater than or equal to zero, was {value}" ) @@ -60,6 +60,25 @@ def are_valid_dask_bag_kwargs(instance, attribute, value): ) +def is_valid_target_shape(instance, attribute, value): + """validator for target shape""" + if not isinstance(value, (tuple, list)): + raise TypeError( + f"invalid type for {attribute.name} of {instance}: {type(value)}. Type should be tuple or list." 
+ ) + + if not all([isinstance(val, int) for val in value]): + raise ValueError( + f"All values in {attribute.name} of {instance} should be integers" + ) + + if not len(value) == 2: + raise ValueError( + f"{attribute.name} of {instance} should have length 2: " + f"(number of frequency bins, number of time bins). " + f"Length was: {len(value)}" + ) + @attr.s class PrepConfig: """class to represent [PREP] section of config.toml file @@ -125,6 +144,31 @@ class PrepConfig: in a learning curve. Each replicate uses a different randomly drawn subset of the training data (but of the same duration). Default is None. Required if config file has a learncurve section. + context_s : float + Number of seconds of "context" around a segment to + add, i.e., time before and after the onset + and offset respectively. Default is 0.005s, + 5 milliseconds. This parameter is only used for + Parametric UMAP and segment-VAE datasets. + max_dur : float + Maximum duration for segments. + If a float value is specified, + any segment with a duration larger than + that value (in seconds) will be omitted + from the dataset. Default is None. + This parameter is only used for + vae-segment datasets. + target_shape : tuple + Of ints, (target number of frequency bins, + target number of time bins). + Spectrograms of units will be reshaped + by interpolation to have the specified + number of frequency and time bins. + The transformation is only applied if both this + parameter and ``max_dur`` are specified. + Default is None. + This parameter is only used for + vae-segment datasets. """ data_dir = attr.ib(converter=expanded_user_path) @@ -195,6 +239,30 @@ def is_valid_input_type(self, attribute, value): validator=validators.optional(instance_of(int)), default=None ) + context_s = attr.ib( + default=0.005 + ) + @context_s.validator + def is_valid_context_s(self, attribute, value): + if not isinstance(value, float): + raise TypeError( + f"Value for {attribute.name} should be float but type was: {type(value)}" + ) + if not value >= 0.: + raise ValueError( + f"Value for {attribute.name} should be greater than or equal to 0., " + f"but was: {value}" + ) + + max_dur = attr.ib( + validator=validators.optional(instance_of(float)), default=None + ) + target_shape = attr.ib( + converter=converters.optional(tuple), + validator=validators.optional(is_valid_target_shape), + default=None + ) + def __attrs_post_init__(self): if self.audio_format is not None and self.spect_format is not None: raise ValueError("cannot specify audio_format and spect_format") diff --git a/src/vak/config/spect_params.py b/src/vak/config/spect_params.py index 4a61942a6..cc17aabed 100644 --- a/src/vak/config/spect_params.py +++ b/src/vak/config/spect_params.py @@ -15,7 +15,9 @@ def freq_cutoffs_validator(instance, attribute, value): ) -VALID_TRANSFORM_TYPES = {"log_spect", "log_spect_plus_one"} +VALID_TRANSFORM_TYPES = { + "log", "log_spect", "log_spect_plus_one" +} def is_valid_transform_type(instance, attribute, value): @@ -57,6 +59,24 @@ class SpectParamsConfig: audio_path_key : str key for accessing path to source audio file for spectogram in files. Default is 'audio_path'. + min_val : float, optional + Minimum value to allow in spectrogram. + All values less than this will be set to this value. + This operation is applied *after* the transform + specified by ``transform_type``. + Default is None. + max_val : float, optional + Maximum value to allow in spectrogram. + All values greater than this will be set to this value. 
+ This operation is applied *after* the transform + specified by ``transform_type``. + Default is None. + normalize : bool + If True, min-max normalize the spectrogram. + Normalization is done *after* the transform + specified by ``transform_type``, and *after* + the ``min_val`` and ``max_val`` operations. + Default is False. """ fft_size = attr.ib(converter=int, validator=instance_of(int), default=512) @@ -79,3 +99,16 @@ class SpectParamsConfig: freqbins_key = attr.ib(validator=instance_of(str), default="f") timebins_key = attr.ib(validator=instance_of(str), default="t") audio_path_key = attr.ib(validator=instance_of(str), default="audio_path") + min_val = attr.ib( + validator=validators.optional(instance_of(float)), + default=None + ) + max_val = attr.ib( + validator=validators.optional(instance_of(float)), + default=None + ) + normalize = attr.ib( + validator=instance_of(bool), + default=False, + ) + diff --git a/src/vak/config/valid.toml b/src/vak/config/valid.toml index 11cd535f5..d7da02f18 100644 --- a/src/vak/config/valid.toml +++ b/src/vak/config/valid.toml @@ -21,6 +21,9 @@ val_dur = 15 test_dur = 30 train_set_durs = [ 4.5, 6.0 ] num_replicates = 2 +context_s = 0.005 +max_dur = 0.2 +target_shape = [128, 128] [SPECT_PARAMS] fft_size = 512 @@ -32,6 +35,9 @@ spect_key = 's' freqbins_key = 'f' timebins_key = 't' audio_path_key = 'audio_path' +min_val = -6.0 +max_val = 0.0 +normalize = true [TRAIN] model = 'TweetyNet' diff --git a/src/vak/datasets/__init__.py b/src/vak/datasets/__init__.py index 0a8cc3764..8b73bdbd5 100644 --- a/src/vak/datasets/__init__.py +++ b/src/vak/datasets/__init__.py @@ -1,3 +1,3 @@ -from . import frame_classification, parametric_umap +from . import frame_classification, parametric_umap, vae -__all__ = ["frame_classification", "parametric_umap"] +__all__ = ["frame_classification", "parametric_umap", "vae"] diff --git a/src/vak/datasets/frame_classification/metadata.py b/src/vak/datasets/frame_classification/metadata.py index 61c7cb918..711bf556c 100644 --- a/src/vak/datasets/frame_classification/metadata.py +++ b/src/vak/datasets/frame_classification/metadata.py @@ -1,7 +1,8 @@ """A dataclass that represents metadata -associated with a frame classification dataset, -as generated by -:func:`vak.core.prep.frame_classification.prep_frame_classification_dataset`""" +associated with a frame classification dataset. + +Metadata is generated by +:func:`vak.core.prep.frame_classification.prep_frame_classification_dataset`.""" from __future__ import annotations import json @@ -10,43 +11,16 @@ import attr - -def is_valid_dataset_csv_filename(instance, attribute, value): - valid = "_prep_" in value and value.endswith(".csv") - if not valid: - raise ValueError( - f"Invalid dataset csv filename: {value}." - f'Filename should contain the string "_prep_" ' - f"and end with the extension .csv." - f"Valid filenames are generated by " - f"vak.core.prep.generate_dataset_csv_filename" - ) - - -def is_valid_audio_format(instance, attribute, value): - import vak.common.constants - - if value not in vak.common.constants.VALID_AUDIO_FORMATS: - raise ValueError( - f"Not a valid audio format: {value}. Valid audio formats are: {vak.common.constants.VALID_AUDIO_FORMATS}" - ) - - -def is_valid_spect_format(instance, attribute, value): - import vak.common.constants - - if value not in vak.common.constants.VALID_SPECT_FORMATS: - raise ValueError( - f"Not a valid spectrogram format: {value}. " - f"Valid spectrogram formats are: {vak.common.constants.VALID_SPECT_FORMATS}" - ) +from .. 
import validators @attr.define class Metadata: """A dataclass that represents metadata - associated with a dataset that was - generated by :func:`vak.core.prep.prep`. + associated with a frame classification dataset. + + Metadata is generated by + :func:`vak.core.prep.frame_classification.prep_frame_classification_dataset` Attributes ---------- @@ -67,7 +41,7 @@ class Metadata: METADATA_JSON_FILENAME: ClassVar = "metadata.json" dataset_csv_filename: str = attr.field( - converter=str, validator=is_valid_dataset_csv_filename + converter=str, validator=validators.is_valid_dataset_csv_filename ) input_type: str = attr.field() @@ -97,13 +71,13 @@ def is_valid_frame_dur(self, attribute, value): audio_format: str = attr.field( converter=attr.converters.optional(str), - validator=attr.validators.optional(is_valid_audio_format), + validator=attr.validators.optional(validators.is_valid_audio_format), default=None, ) spect_format: str = attr.field( converter=attr.converters.optional(str), - validator=attr.validators.optional(is_valid_spect_format), + validator=attr.validators.optional(validators.is_valid_spect_format), default=None, ) @@ -112,7 +86,7 @@ def from_path(cls, json_path: str | pathlib.Path) -> Metadata: """Load dataset metadata from a json file. Class method that returns an instance of - :class:`~vak.datasets.frame_classification.FrameClassificationDatatsetMetadata`. + :class:`~vak.datasets.frame_classification.Metadata`. Parameters ---------- @@ -123,8 +97,8 @@ def from_path(cls, json_path: str | pathlib.Path) -> Metadata: Returns ------- - metadata : vak.datasets.frame_classification.FrameClassificationDatatsetMetadata - Instance of :class:`~vak.datasets.frame_classification.FrameClassificationDatatsetMetadata` + metadata : vak.datasets.frame_classification.Metadata + Instance of :class:`~vak.datasets.frame_classification.Metadata` with metadata loaded from json file. """ json_path = pathlib.Path(json_path) @@ -153,16 +127,17 @@ def to_json(self, dataset_path: str | pathlib.Path) -> None: This method is called by :func:`vak.core.prep.prep` after it generates a dataset and then creates an - instance of :class:`~vak.datasets.frame_classification.FrameClassificationDatatsetMetadata` + instance of :class:`~vak.datasets.frame_classification.Metadata` with metadata about that dataset. Parameters ---------- dataset_path : string, pathlib.Path - Path to root of a directory representing a dataset - generated by :func:`vak.core.prep.prep`. - where 'metadata.json' file - should be saved. + Path where 'metadata.json' file + should be saved. Typically, + the root of a directory representing a dataset + generated by + :func:`vak.core.prep.frame_classification.prep_frame_classification_dataset`. """ dataset_path = pathlib.Path(dataset_path) if not dataset_path.exists() or not dataset_path.is_dir(): diff --git a/src/vak/datasets/parametric_umap/metadata.py b/src/vak/datasets/parametric_umap/metadata.py index ac0b8a137..cfcc2483e 100644 --- a/src/vak/datasets/parametric_umap/metadata.py +++ b/src/vak/datasets/parametric_umap/metadata.py @@ -1,7 +1,8 @@ """A dataclass that represents metadata -associated with a dimensionality reduction dataset, -as generated by -:func:`vak.core.prep.frame_classification.prep_dimensionality_reduction_dataset`""" +associated with a parametric UMAP dataset. 
+ +The metadata is generated by +:func:`vak.core.prep.parametric_umap.prep_parametric_umap_dataset`.""" from __future__ import annotations import json @@ -10,43 +11,16 @@ import attr - -def is_valid_dataset_csv_filename(instance, attribute, value): - valid = "_prep_" in value and value.endswith(".csv") - if not valid: - raise ValueError( - f"Invalid dataset csv filename: {value}." - f'Filename should contain the string "_prep_" ' - f"and end with the extension .csv." - f"Valid filenames are generated by " - f"vak.core.prep.generate_dataset_csv_filename" - ) - - -def is_valid_audio_format(instance, attribute, value): - import vak.common.constants - - if value not in vak.common.constants.VALID_AUDIO_FORMATS: - raise ValueError( - f"Not a valid audio format: {value}. Valid audio formats are: {vak.common.constants.VALID_AUDIO_FORMATS}" - ) - - -def is_valid_spect_format(instance, attribute, value): - import vak.common.constants - - if value not in vak.common.constants.VALID_SPECT_FORMATS: - raise ValueError( - f"Not a valid spectrogram format: {value}. " - f"Valid spectrogram formats are: {vak.common.constants.VALID_SPECT_FORMATS}" - ) +from .. import validators @attr.define class Metadata: """A dataclass that represents metadata - associated with a dataset that was - generated by :func:`vak.core.prep.prep`. + associated with a parametric UMAP dataset. + + The metadata is generated by + :func:`vak.core.prep.parametric_umap.prep_parametric_umap_dataset`. Attributes ---------- @@ -54,7 +28,10 @@ class Metadata: Name of csv file representing the source files in the dataset. Csv file will be located in root of directory representing dataset, so only the filename is given. - audio_format + shape : tuple + Of ints, the shape of the samples. + audio_format : str + The format of the source audio files used to generate the dataset. """ # declare this as a constant to avoid @@ -62,7 +39,7 @@ class Metadata: METADATA_JSON_FILENAME: ClassVar = "metadata.json" dataset_csv_filename: str = attr.field( - converter=str, validator=is_valid_dataset_csv_filename + converter=str, validator=validators.is_valid_dataset_csv_filename ) shape: tuple = attr.field(converter=tuple) @@ -80,7 +57,7 @@ def is_valid_shape(self, attribute, value): audio_format: str = attr.field( converter=attr.converters.optional(str), - validator=attr.validators.optional(is_valid_audio_format), + validator=attr.validators.optional(validators.is_valid_audio_format), default=None, ) @@ -89,7 +66,7 @@ def from_path(cls, json_path: str | pathlib.Path): """Load dataset metadata from a json file. Class method that returns an instance of - :class:`~vak.datasets.frame_classification.FrameClassificationDatatsetMetadata`. + :class:`~vak.datasets.parametric_umap.Metadata`. Parameters ---------- @@ -100,8 +77,8 @@ def from_path(cls, json_path: str | pathlib.Path): Returns ------- - metadata : vak.datasets.frame_classification.FrameClassificationDatatsetMetadata - Instance of :class:`~vak.datasets.frame_classification.FrameClassificationDatatsetMetadata` + metadata : vak.datasets.parametric_umap.Metadata + Instance of :class:`~vak.datasets.parametric_umap.Metadata` with metadata loaded from json file. 
""" json_path = pathlib.Path(json_path) @@ -130,16 +107,17 @@ def to_json(self, dataset_path: str | pathlib.Path) -> None: This method is called by :func:`vak.core.prep.prep` after it generates a dataset and then creates an - instance of :class:`~vak.datasets.frame_classification.FrameClassificationDatatsetMetadata` + instance of :class:`~vak.datasets.parametric_umap.Metadata` with metadata about that dataset. Parameters ---------- dataset_path : string, pathlib.Path - Path to root of a directory representing a dataset - generated by :func:`vak.core.prep.prep`. - where 'metadata.json' file + Path where 'metadata.json' file should be saved. + Typically the root of a directory representing a dataset + generated by + :func:`vak.core.prep.parametric_umap.prep_parametric_umap_dataset` """ dataset_path = pathlib.Path(dataset_path) if not dataset_path.exists() or not dataset_path.is_dir(): diff --git a/src/vak/datasets/vae/__init__.py b/src/vak/datasets/vae/__init__.py new file mode 100644 index 000000000..340bda9ae --- /dev/null +++ b/src/vak/datasets/vae/__init__.py @@ -0,0 +1,10 @@ +from .metadata import Metadata +from .segment_dataset import SegmentDataset +from .window_dataset import WindowDataset + + +__all__ = [ + "Metadata", + "SegmentDataset", + "WindowDataset", +] diff --git a/src/vak/datasets/vae/metadata.py b/src/vak/datasets/vae/metadata.py new file mode 100644 index 000000000..f6bdc2be9 --- /dev/null +++ b/src/vak/datasets/vae/metadata.py @@ -0,0 +1,138 @@ +"""A dataclass that represents metadata +associated with a VAE dataset. + +The metadata is generated by +:func:`vak.core.prep.vae.prep_vae_dataset`.""" +from __future__ import annotations + +import json +import pathlib +from typing import ClassVar + +import attr + +from .. import validators +from ...prep.vae.vae import VAE_DATASET_TYPES + + +def is_valid_vae_dataset_type(instance, attribute, value): + if value not in VAE_DATASET_TYPES: + raise ValueError( + f"`dataset_type` must be one of '{VAE_DATASET_TYPES}', but was: {value}" + ) + + +@attr.define +class Metadata: + """A dataclass that represents metadata + associated with a dataset that was + generated by :func:`vak.core.prep.prep`. + + Attributes + ---------- + dataset_csv_filename : str + Name of csv file representing the source files in the dataset. + Csv file will be located in root of directory representing dataset, + so only the filename is given. + dataset_type : str + One of: {'vae-segment', 'vae-window'} + audio_format : str + Format of audio files. One of {'wav', 'cbin'}. + Default is ``None``, but either ``audio_format`` or ``spect_format`` + must be specified. + shape : tuple, optional + Shape of dataset. + Only used for 'segment-vae' dataset. + """ + + # declare this as a constant to avoid + # needing to remember this in multiple places, and to use in unit tests + METADATA_JSON_FILENAME: ClassVar = "metadata.json" + + dataset_csv_filename: str = attr.field( + converter=str, validator=validators.is_valid_dataset_csv_filename + ) + dataset_type: str = attr.field( + converter=str, validator=is_valid_vae_dataset_type + ) + + shape: tuple = attr.field( + converter=attr.converters.optional(tuple), + validator=attr.validators.optional(validators.is_valid_shape), + default=None + ) + + audio_format: str = attr.field( + converter=attr.converters.optional(str), + validator=attr.validators.optional(validators.is_valid_audio_format), + default=None, + ) + + @classmethod + def from_path(cls, json_path: str | pathlib.Path): + """Load dataset metadata from a json file. 
+ + Class method that returns an instance of + :class:`~vak.datasets.vae.Metadata`. + + Parameters + ---------- + json_path : string, pathlib.Path + Path to a 'metadata.json' file created by + :func:`vak.core.prep.prep` when generating + a dataset. + + Returns + ------- + metadata : vak.datasets.vae.Metadata + Instance of :class:`~vak.datasets.vae.Metadata` + with metadata loaded from json file. + """ + json_path = pathlib.Path(json_path) + with json_path.open("r") as fp: + metadata_json = json.load(fp) + return cls(**metadata_json) + + @classmethod + def from_dataset_path(cls, dataset_path: str | pathlib.Path): + dataset_path = pathlib.Path(dataset_path) + if not dataset_path.exists() or not dataset_path.is_dir(): + raise NotADirectoryError( + f"`dataset_path` not found or not recognized as a directory: {dataset_path}" + ) + + metadata_json_path = dataset_path / cls.METADATA_JSON_FILENAME + if not metadata_json_path.exists(): + raise FileNotFoundError( + f"Metadata file not found: {metadata_json_path}" + ) + + return cls.from_path(metadata_json_path) + + def to_json(self, dataset_path: str | pathlib.Path) -> None: + """Dump dataset metadata to a json file. + + This method is called by :func:`vak.core.prep.prep` + after it generates a dataset and then creates an + instance of :class:`~vak.datasets.vae.Metadata` + with metadata about that dataset. + + Parameters + ---------- + dataset_path : string, pathlib.Path + Path where 'metadata.json' file + should be saved. Typically, the root + of a directory representing a dataset + generated by + :func:`vak.core.prep.vae.prep_vae_dataset`. + """ + dataset_path = pathlib.Path(dataset_path) + if not dataset_path.exists() or not dataset_path.is_dir(): + raise NotADirectoryError( + f"dataset_path not recognized as a directory: {dataset_path}" + ) + + json_dict = attr.asdict(self) + json_path = dataset_path / self.METADATA_JSON_FILENAME + with json_path.open("w") as fp: + json.dump(json_dict, fp, indent=4) diff --git a/src/vak/datasets/vae/segment_dataset.py b/src/vak/datasets/vae/segment_dataset.py new file mode 100644 index 000000000..687988b94 --- /dev/null +++ b/src/vak/datasets/vae/segment_dataset.py @@ -0,0 +1,112 @@ +"""Dataset class for VAE models that operate on segments. + +Segments are typically found with a segmenting algorithm +that thresholds audio signal energy, +e.g., syllables from birdsong or mouse USVs.""" +from __future__ import annotations + +import pathlib +from typing import Callable + +import numpy as np +import numpy.typing as npt +import pandas as pd +import torch.utils.data + + +class SegmentDataset(torch.utils.data.Dataset): + """Dataset class for VAE models that operate on segments. 
+
+    Segments are typically found with a segmenting algorithm
+    that thresholds audio signal energy,
+    e.g., syllables from birdsong or mouse USVs."""
+
+    def __init__(
+        self,
+        data: npt.NDArray,
+        dataset_df: pd.DataFrame,
+        transform: Callable | None = None,
+    ):
+        self.data = data
+        self.dataset_df = dataset_df
+        self.transform = transform
+
+    @property
+    def duration(self):
+        return self.dataset_df["duration"].sum()
+
+    def __len__(self):
+        return self.data.shape[0]
+
+    @property
+    def shape(self):
+        tmp_x_ind = 0
+        tmp_item = self.__getitem__(tmp_x_ind)
+        return tmp_item["x"].shape
+
+    def __getitem__(self, index):
+        x = self.data[index]
+        df_index = self.dataset_df.index[index]
+        if self.transform:
+            x = self.transform(x)
+        return {"x": x, "df_index": df_index}
+
+    @classmethod
+    def from_dataset_path(
+        cls,
+        dataset_path: str | pathlib.Path,
+        split: str,
+        subset: str | None = None,
+        transform: Callable | None = None,
+    ):
+        """Make a :class:`SegmentDataset` instance,
+        given the path to a VAE segment dataset.
+
+        Parameters
+        ----------
+        dataset_path : pathlib.Path
+            Path to directory that represents a
+            VAE segment dataset,
+            as created by
+            :func:`vak.prep.vae.prep_vae_dataset`.
+        split : str
+            The name of a split from the dataset,
+            one of {'train', 'val', 'test'}.
+        subset : str, optional
+            Name of subset to use.
+            If specified, this takes precedence over split.
+            Subsets are typically taken from the training data
+            for use when generating a learning curve.
+        transform : callable
+            The transform applied to the input to the neural network :math:`x`.
+
+        Returns
+        -------
+        dataset : vak.datasets.vae.SegmentDataset
+        """
+        import vak.datasets  # import here just to make classmethod more explicit
+
+        dataset_path = pathlib.Path(dataset_path)
+        metadata = vak.datasets.vae.Metadata.from_dataset_path(
+            dataset_path
+        )
+
+        dataset_csv_path = dataset_path / metadata.dataset_csv_filename
+        dataset_df = pd.read_csv(dataset_csv_path)
+        # subset takes precedence over split, if specified
+        if subset:
+            dataset_df = dataset_df[dataset_df.subset == subset].copy()
+        else:
+            dataset_df = dataset_df[dataset_df.split == split].copy()
+
+        data = np.stack(
+            [
+                np.load(dataset_path / spect_path)
+                for spect_path in dataset_df.spect_path.values
+            ]
+        )
+        return cls(
+            data,
+            dataset_df,
+            transform=transform,
+        )
diff --git a/src/vak/datasets/vae/window_dataset.py b/src/vak/datasets/vae/window_dataset.py
new file mode 100644
index 000000000..34e55293d
--- /dev/null
+++ b/src/vak/datasets/vae/window_dataset.py
@@ -0,0 +1,343 @@
+"""Dataset class used for VAE models that operate on fixed-sized windows,
+such as a "shotgun VAE" [1]_.
+
+.. [1] Goffinet, J., Brudner, S., Mooney, R., & Pearson, J. (2021).
+   Low-dimensional learned feature spaces quantify individual and group differences in vocal repertoires.
+   eLife, 10:e67855. https://doi.org/10.7554/eLife.67855"""
+from __future__ import annotations
+
+import pathlib
+from typing import Callable
+
+import numpy as np
+import numpy.typing as npt
+import pandas as pd
+
+from ..frame_classification import constants, helper
+# ``get_window_inds`` is called in ``__init__`` below;
+# it is defined in the frame classification window dataset module
+from ..frame_classification.window_dataset import get_window_inds
+from .metadata import Metadata
+
+
+class WindowDataset:
+    """Dataset class used for VAE models that operate on fixed-sized windows,
+    such as a "shotgun VAE" [1]_.
+
+    Attributes
+    ----------
+    dataset_path : pathlib.Path
+        Path to directory that represents a
+        VAE dataset,
+        as created by
+        :func:`vak.prep.vae.prep_vae_dataset`.
+    split : str
+        The name of a split from the dataset,
+        one of {'train', 'val', 'test'}.
+    subset : str, optional
+        Name of subset to use.
+        If specified, this takes precedence over split.
+        Subsets are typically taken from the training data
+        for use when generating a learning curve.
+    dataset_df : pandas.DataFrame
+        A VAE dataset,
+        represented as a :class:`pandas.DataFrame`.
+        This will be only the rows that correspond
+        to either ``subset`` or ``split`` from the
+        ``dataset_df`` that was passed in when
+        instantiating the class.
+    input_type : str
+        The type of input to the neural network model.
+        One of {'audio', 'spect'}.
+    frames_paths : numpy.ndarray
+        Paths to npy files containing frames,
+        either spectrograms or audio signals
+        that are input to the model.
+    sample_ids : numpy.ndarray
+        Indexing vector representing which sample
+        from the dataset every frame belongs to.
+    inds_in_sample : numpy.ndarray
+        Indexing vector representing which index
+        within each sample from the dataset
+        that every frame belongs to.
+    window_size : int
+        Size of windows to return;
+        number of frames.
+    frame_dur: float
+        Duration of a frame, i.e., a single sample in audio
+        or a single timebin in a spectrogram.
+    stride : int
+        The size of the stride used to determine which windows
+        are included in the dataset. The default is 1.
+        Used to compute ``window_inds``,
+        with the function
+        :func:`vak.datasets.frame_classification.window_dataset.get_window_inds`.
+    window_inds : numpy.ndarray, optional
+        A vector of valid window indices for the dataset.
+        If specified, this takes precedence over ``stride``.
+    transform : callable
+        The transform applied to the frames,
+        the input to the neural network :math:`x`.
+
+    References
+    ----------
+    .. [1] Goffinet, J., Brudner, S., Mooney, R., & Pearson, J. (2021).
+       Low-dimensional learned feature spaces quantify individual and group differences in vocal repertoires.
+       eLife, 10:e67855. https://doi.org/10.7554/eLife.67855
+    """
+
+    def __init__(
+        self,
+        dataset_path: str | pathlib.Path,
+        dataset_df: pd.DataFrame,
+        input_type: str,
+        split: str,
+        sample_ids: npt.NDArray,
+        inds_in_sample: npt.NDArray,
+        window_size: int,
+        frame_dur: float,
+        stride: int = 1,
+        subset: str | None = None,
+        window_inds: npt.NDArray | None = None,
+        transform: Callable | None = None,
+    ):
+        """Initialize a new instance of a WindowDataset.
+
+        Parameters
+        ----------
+        dataset_path : pathlib.Path
+            Path to directory that represents a
+            VAE dataset, as created by :func:`vak.prep.vae.prep_vae_dataset`.
+        dataset_df : pandas.DataFrame
+            A VAE dataset,
+            represented as a :class:`pandas.DataFrame`.
+        input_type : str
+            The type of input to the neural network model.
+            One of {'audio', 'spect'}.
+        split : str
+            The name of a split from the dataset,
+            one of {'train', 'val', 'test'}.
+        sample_ids : numpy.ndarray
+            Indexing vector representing which sample
+            from the dataset every frame belongs to.
+        inds_in_sample : numpy.ndarray
+            Indexing vector representing which index
+            within each sample from the dataset
+            that every frame belongs to.
+        window_size : int
+            Size of windows to return;
+            number of frames.
+        frame_dur: float
+            Duration of a frame, i.e., a single sample in audio
+            or a single timebin in a spectrogram.
+        stride : int
+            The size of the stride used to determine which windows
+            are included in the dataset. The default is 1.
+            Used to compute ``window_inds``,
+            with the function
+            :func:`vak.datasets.frame_classification.window_dataset.get_window_inds`.
+        subset : str, optional
+            Name of subset to use.
+            If specified, this takes precedence over split.
+            Subsets are typically taken from the training data
+            for use when generating a learning curve.
+        window_inds : numpy.ndarray, optional
+            A vector of valid window indices for the dataset.
+            If specified, this takes precedence over ``stride``.
+        transform : callable
+            The transform applied to the input to the neural network :math:`x`.
+        """
+        from ... import (
+            prep,
+        )  # avoid circular import, use for constants.INPUT_TYPES
+
+        if input_type not in prep.constants.INPUT_TYPES:
+            raise ValueError(
+                f"``input_type`` must be one of: {prep.constants.INPUT_TYPES}\n"
+                f"Value for ``input_type`` was: {input_type}"
+            )
+
+        self.dataset_path = pathlib.Path(dataset_path)
+        self.split = split
+        self.subset = subset
+        # subset takes precedence over split, if specified
+        if subset:
+            dataset_df = dataset_df[dataset_df.subset == subset].copy()
+        else:
+            dataset_df = dataset_df[dataset_df.split == split].copy()
+        self.dataset_df = dataset_df
+        self.input_type = input_type
+        self.frames_paths = self.dataset_df[
+            constants.FRAMES_PATH_COL_NAME
+        ].values
+        self.sample_ids = sample_ids
+        self.inds_in_sample = inds_in_sample
+        self.window_size = window_size
+        self.frame_dur = float(frame_dur)
+        self.stride = stride
+        if window_inds is None:
+            window_inds = get_window_inds(
+                sample_ids.shape[-1], window_size, stride
+            )
+        self.window_inds = window_inds
+        self.transform = transform
+
+    @property
+    def duration(self):
+        return self.sample_ids.shape[-1] * self.frame_dur
+
+    @property
+    def shape(self):
+        tmp_x_ind = 0
+        # ``__getitem__`` returns a single array of frames
+        one_x = self.__getitem__(tmp_x_ind)
+        # used by vak functions that need to determine size of window,
+        # e.g. when initializing a neural network model
+        return one_x.shape
+
+    def _load_frames(self, frames_path):
+        """Helper function that loads "frames",
+        the input to the model.
+        Loads audio or spectrogram, depending on
+        :attr:`self.input_type`.
+        This function assumes that audio is in wav format
+        and spectrograms are in npz files.
+        """
+        return helper.load_frames(frames_path, self.input_type)
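+    # How ``__getitem__`` below assembles a window (a sketch for orientation,
+    # assuming two spectrograms of 100 time bins each and ``window_size=64``):
+    #
+    #   sample_ids     = [0, 0, ..., 0, 1, 1, ..., 1]    # length 200
+    #   inds_in_sample = [0, 1, ..., 99, 0, 1, ..., 99]
+    #
+    # A window starting at index 80 runs past the end of sample 0,
+    # so both samples are loaded and concatenated along the time axis,
+    # then the 64-frame window is sliced out starting at frame 80.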
+ """ + return helper.load_frames(frames_path, self.input_type) + + def __getitem__(self, idx): + window_idx = self.window_inds[idx] + sample_ids = self.sample_ids[ + window_idx : window_idx + self.window_size # noqa: E203 + ] + uniq_sample_ids = np.unique(sample_ids) + if len(uniq_sample_ids) == 1: + # we repeat ourselves here to avoid running a loop on one item + sample_id = uniq_sample_ids[0] + frames_path = self.dataset_path / self.frames_paths[sample_id] + frames = self._load_frames(frames_path) + + elif len(uniq_sample_ids) > 1: + frames = [] + frame_labels = [] + for sample_id in sorted(uniq_sample_ids): + frames_path = self.dataset_path / self.frames_paths[sample_id] + frames.append(self._load_frames(frames_path)) + + if all([frames_.ndim == 1 for frames_ in frames]): + # --> all 1-d audio vectors; if we specify `axis=1` here we'd get error + frames = np.concatenate(frames) + else: + frames = np.concatenate(frames, axis=1) + frame_labels = np.concatenate(frame_labels) + else: + raise ValueError( + f"Unexpected number of ``uniq_sample_ids``: {uniq_sample_ids}" + ) + + inds_in_sample = self.inds_in_sample[window_idx] + frames = frames[ + ..., + inds_in_sample : inds_in_sample + self.window_size, # noqa: E203 + ] + if self.transform: + frames = self.transform(frames) + + return frames + + def __len__(self): + """number of batches""" + return len(self.window_inds) + + @classmethod + def from_dataset_path( + cls, + dataset_path: str | pathlib.Path, + window_size: int, + stride: int = 1, + split: str = "train", + subset: str | None = None, + transform: Callable | None = None, + ): + """Make a :class:`WindowDataset` instance, + given the path to a VAE window dataset. + + Parameters + ---------- + dataset_path : pathlib.Path + Path to directory that represents a + frame classification dataset, + as created by + :func:`vak.prep.prep_frame_classification_dataset`. + window_size : int + Size of windows to return; + number of frames. + stride : int + The size of the stride used to determine which windows + are included in the dataset. The default is 1. + Used to compute ``window_inds``, + with the function + :func:`vak.datasets.frame_classification.window_dataset.get_window_inds`. + split : str + The name of a split from the dataset, + one of {'train', 'val', 'test'}. + subset : str, optional + Name of subset to use. + If specified, this takes precedence over split. + Subsets are typically taken from the training data + for use when generating a learning curve. + transform : callable + The transform applied to the input to the neural network :math:`x`. 
+ + Returns + ------- + dataset : vak.datasets.vae.WindowDataset + """ + dataset_path = pathlib.Path(dataset_path) + metadata = Metadata.from_dataset_path(dataset_path) + frame_dur = metadata.frame_dur + input_type = metadata.input_type + + dataset_csv_path = dataset_path / metadata.dataset_csv_filename + dataset_df = pd.read_csv(dataset_csv_path) + + split_path = dataset_path / split + if subset: + sample_ids_path = ( + split_path + / helper.sample_ids_array_filename_for_subset(subset) + ) + else: + sample_ids_path = split_path / constants.SAMPLE_IDS_ARRAY_FILENAME + sample_ids = np.load(sample_ids_path) + + if subset: + inds_in_sample_path = ( + split_path + / helper.inds_in_sample_array_filename_for_subset(subset) + ) + else: + inds_in_sample_path = ( + split_path / constants.INDS_IN_SAMPLE_ARRAY_FILENAME + ) + inds_in_sample = np.load(inds_in_sample_path) + + window_inds_path = split_path / constants.WINDOW_INDS_ARRAY_FILENAME + if window_inds_path.exists(): + window_inds = np.load(window_inds_path) + else: + window_inds = None + + return cls( + dataset_path, + dataset_df, + input_type, + split, + sample_ids, + inds_in_sample, + window_size, + frame_dur, + stride, + subset, + window_inds, + transform, + ) diff --git a/src/vak/datasets/validators.py b/src/vak/datasets/validators.py new file mode 100644 index 000000000..237d63771 --- /dev/null +++ b/src/vak/datasets/validators.py @@ -0,0 +1,41 @@ +"""Validators used with metadata""" +def is_valid_dataset_csv_filename(instance, attribute, value): + valid = "_prep_" in value and value.endswith(".csv") + if not valid: + raise ValueError( + f"Invalid dataset csv filename: {value}." + f'Filename should contain the string "_prep_" ' + f"and end with the extension .csv." + f"Valid filenames are generated by " + f"vak.core.prep.generate_dataset_csv_filename" + ) + + +def is_valid_audio_format(instance, attribute, value): + import vak.common.constants + + if value not in vak.common.constants.VALID_AUDIO_FORMATS: + raise ValueError( + f"Not a valid audio format: {value}. Valid audio formats are: {vak.common.constants.VALID_AUDIO_FORMATS}" + ) + + +def is_valid_spect_format(instance, attribute, value): + import vak.common.constants + + if value not in vak.common.constants.VALID_SPECT_FORMATS: + raise ValueError( + f"Not a valid spectrogram format: {value}. " + f"Valid spectrogram formats are: {vak.common.constants.VALID_SPECT_FORMATS}" + ) + + +def is_valid_shape(instance, attribute, value): + if not isinstance(value, tuple): + raise TypeError( + f"`shape` should be a tuple but type was: {type(value)}" + ) + if not all([isinstance(val, int) and val > 0 for val in value]): + raise ValueError( + f"All values of `shape` should be positive integers but values were: {value}" + ) diff --git a/src/vak/models/__init__.py b/src/vak/models/__init__.py index 604fa408e..05db4cb5a 100644 --- a/src/vak/models/__init__.py +++ b/src/vak/models/__init__.py @@ -8,6 +8,8 @@ from .parametric_umap_model import ParametricUMAPModel from .registry import model_family from .tweetynet import TweetyNet +from .vae_model import VAEModel +from .ava import AVA __all__ = [ "base", @@ -23,4 +25,7 @@ "ParametricUMAPModel", "registry", "TweetyNet", + "VAEModel", + "AVA", + ] diff --git a/src/vak/models/ava.py b/src/vak/models/ava.py new file mode 100644 index 000000000..70f21537d --- /dev/null +++ b/src/vak/models/ava.py @@ -0,0 +1,35 @@ +"""Autoencoded Vocal Analysis (AVA) model [1]_. +Code is adapted from [2]_. + +.. [1] Goffinet, J., Brudner, S., Mooney, R., & Pearson, J. (2021). 
+   Low-dimensional learned feature spaces quantify individual and group differences in vocal repertoires.
+   eLife, 10:e67855. https://doi.org/10.7554/eLife.67855
+
+.. [2] https://github.com/pearsonlab/autoencoded-vocal-analysis
+"""
+from __future__ import annotations
+
+import torch
+from torchmetrics import KLDivergence
+
+from .. import nets
+from ..nn.loss import VaeElboLoss
+from .decorator import model
+from .vae_model import VAEModel
+
+
+@model(family=VAEModel)
+class AVA:
+    """Autoencoded Vocal Analysis (AVA) model [1]_.
+
+    .. [1] Goffinet, J., Brudner, S., Mooney, R., & Pearson, J. (2021).
+       Low-dimensional learned feature spaces quantify individual and group differences in vocal repertoires.
+       eLife, 10:e67855. https://doi.org/10.7554/eLife.67855
+    """
+    network = nets.AVA
+    loss = VaeElboLoss
+    optimizer = torch.optim.Adam
+    metrics = {
+        "loss": VaeElboLoss,
+        "kl": KLDivergence
+    }
+    default_config = {"optimizer": {"lr": 1e-3}}
diff --git a/src/vak/models/get.py b/src/vak/models/get.py
index e4cc3ab06..8d1ba9480 100644
--- a/src/vak/models/get.py
+++ b/src/vak/models/get.py
@@ -96,6 +96,16 @@ def get(
         else:
             config["network"]["encoder"] = dict(input_shape=input_shape)
 
+        model = model_class.from_config(config=config)
+    elif model_family == "VAEModel":
+        net_init_params = list(
+            inspect.signature(
+                model_class.definition.network.__init__
+            ).parameters.keys()
+        )
+        if "input_shape" in net_init_params:
+            config["network"]["input_shape"] = input_shape
+
         model = model_class.from_config(config=config)
     else:
         raise ValueError(
diff --git a/src/vak/models/vae_model.py b/src/vak/models/vae_model.py
new file mode 100644
index 000000000..b3b348b93
--- /dev/null
+++ b/src/vak/models/vae_model.py
@@ -0,0 +1,86 @@
+from __future__ import annotations
+
+from typing import Callable, ClassVar, Type
+
+import torch
+
+from . import base
+from .definition import ModelDefinition
+from .registry import model_family
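+# The contract this model family assumes (a sketch in comments, not executed):
+# datasets for this family yield dict batches, and networks registered with it
+# return a 3-tuple from ``forward``:
+#
+#   batch = {"x": x}                    # e.g., from vak.datasets.vae.SegmentDataset
+#   x_rec, z, latent_dist = network(x)  # reconstruction, latent sample, distribution
+#   loss = loss_fn(x, z, x_rec, latent_dist)
+#
+# ``VAEModel.training_step`` and ``validation_step`` below rely on this shape.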
+
+
+@model_family
+class VAEModel(base.Model):
+    """Class that represents a family of
+    variational autoencoder (VAE) models."""
+
+    definition: ClassVar[ModelDefinition]
+
+    def __init__(
+        self,
+        network: dict | None = None,
+        loss: torch.nn.Module | Callable | None = None,
+        optimizer: torch.optim.Optimizer | None = None,
+        metrics: dict[str, Type] | None = None,
+    ):
+        super().__init__(
+            network=network, loss=loss, optimizer=optimizer, metrics=metrics
+        )
+
+    def forward(self, x):
+        # the network returns (x_rec, z, latent_dist);
+        # here we return just the reconstruction
+        x_rec, _, _ = self.network(x)
+        return x_rec
+
+    def encode(self, x):
+        # delegate to the network's ``encode`` method,
+        # which maps ``x`` all the way to a latent embedding
+        return self.network.encode(x)
+
+    def decode(self, x):
+        return self.network.decode(x)
+
+    def configure_optimizers(self):
+        return self.optimizer
+
+    def training_step(self, batch: dict, batch_idx: int):
+        """Compute and log the training loss for one batch."""
+        x = batch["x"]
+        x_rec, z, latent_dist = self.network(x)
+        loss = self.loss(x, z, x_rec, latent_dist)
+        self.log("train_loss", loss)
+        return loss
+
+    def validation_step(self, batch: dict, batch_idx: int):
+        x = batch["x"]
+        x_rec, z, latent_dist = self.network(x)
+        for metric_name, metric_callable in self.metrics.items():
+            if metric_name == "loss":
+                self.log(
+                    f"val_{metric_name}",
+                    metric_callable(x, z, x_rec, latent_dist),
+                    on_step=True,
+                )
+            elif metric_name == "acc":
+                self.log(
+                    f"val_{metric_name}",
+                    metric_callable(x_rec, x),
+                    on_step=True,
+                )
+
+    @classmethod
+    def from_config(
+        cls, config: dict
+    ):
+        network, loss, optimizer, metrics = cls.attributes_from_config(config)
+        return cls(
+            network=network,
+            optimizer=optimizer,
+            loss=loss,
+            metrics=metrics,
+        )
diff --git a/src/vak/nets/__init__.py b/src/vak/nets/__init__.py
index e31b90bff..22f91d3ea 100644
--- a/src/vak/nets/__init__.py
+++ b/src/vak/nets/__init__.py
@@ -2,6 +2,7 @@
 from .conv_encoder import ConvEncoder
 from .ed_tcn import ED_TCN
 from .tweetynet import TweetyNet
+from .ava import AVA
 
 __all__ = [
     "conv_encoder",
@@ -10,4 +11,5 @@
     "ED_TCN",
     "tweetynet",
     "TweetyNet",
+    "AVA",
 ]
diff --git a/src/vak/nets/ava.py b/src/vak/nets/ava.py
new file mode 100644
index 000000000..b6bbfbee1
--- /dev/null
+++ b/src/vak/nets/ava.py
@@ -0,0 +1,274 @@
+"""AVA variational autoencoder, as described in [1]_.
+Code is adapted from [2]_.
+
+.. [1] Goffinet, J., Brudner, S., Mooney, R., & Pearson, J. (2021).
+   Low-dimensional learned feature spaces quantify individual and group differences in vocal repertoires.
+   eLife, 10:e67855. https://doi.org/10.7554/eLife.67855
+
+.. [2] https://github.com/pearsonlab/autoencoded-vocal-analysis
+"""
+from __future__ import annotations
+
+from typing import Sequence
+
+import numpy as np
+import torch
+from torch import nn
+from torch.distributions import LowRankMultivariateNormal
+
+
+class FullyConnectedLayers(nn.Module):
+    r"""Module containing two fully-connected layers.
+
+    This module is used to parametrize :math:`\mu`
+    and :math:`\Sigma` in AVA.
+    """
+    def __init__(self, n_features: Sequence[int]):
+        super().__init__()
+        self.layer = nn.Sequential(
+            nn.Linear(n_features[0], n_features[1]),
+            nn.ReLU(),
+            nn.Linear(n_features[1], n_features[2]))
+
+    def forward(self, x):
+        return self.layer(x)
+
+
+class AVA(nn.Module):
+    """AVA variational autoencoder, as described in [1]_.
+    Code is adapted from [2]_.
+
+    Attributes
+    ----------
+    input_shape
+    in_channels
+    x_shape
+    x_dim
+    encoder
+    fc_view
+    in_fc_dims
+    shared_encoder_fc
+    mu_fc
+    cov_factor_fc
+    cov_diag_fc
+    decoder_fc
+    decoder
+
+    References
+    ----------
+    .. [1] Goffinet, J., Brudner, S., Mooney, R., & Pearson, J. (2021).
+       Low-dimensional learned feature spaces quantify individual and group differences in vocal repertoires.
+       eLife, 10:e67855. https://doi.org/10.7554/eLife.67855
+
+    .. [2] https://github.com/pearsonlab/autoencoded-vocal-analysis
+    """
+    def __init__(
+        self,
+        input_shape: Sequence[int] = (1, 128, 128),
+        encoder_channels: Sequence[int] = (8, 8, 16, 16, 24, 24, 32),
+        fc_dims: Sequence[int] = (1024, 256, 64),
+        z_dim: int = 32,
+    ):
+        r"""Initialize a new instance of
+        an AVA variational autoencoder.
+
+        Parameters
+        ----------
+        input_shape : Sequence
+            Shape of input to network, a fixed size
+            for all spectrograms.
+            Tuple/list of integers, with dimensions
+            (channels, frequency bins, time bins).
+            Default is ``(1, 128, 128)``.
+        encoder_channels : Sequence
+            Number of channels in convolutional layers
+            of encoder. Tuple/list of integers.
+            Default is ``(8, 8, 16, 16, 24, 24, 32)``.
+        fc_dims : Sequence
+            Dimensionality of fully-connected layers.
+            Tuple/list of integers.
+            These values are used for the linear layers
+            in the encoder (``self.shared_encoder_fc``)
+            after passing through the convolutional layers,
+            as well as the linear layers
+            that are used to parametrize :math:`\mu` and
+            :math:`\Sigma`.
+            Default is (1024, 256, 64).
+        z_dim : int
+            Dimensionality of latent space.
+            Default is 32.
+        """
+        super().__init__()
+
+        self.input_shape = input_shape
+        self.in_channels = int(input_shape[0])
+        self.x_shape = input_shape[1:]  # (height, width), i.e., (frequency bins, time bins)
+        self.x_dim = int(np.prod(self.x_shape))
+
+        # ---- build encoder
+        modules = []
+        in_channels = self.in_channels
+        for out_channels in encoder_channels:
+            # AVA uses stride=2 when out_channels == in_channels
+            stride = 2 if out_channels == in_channels else 1
+            modules.append(
+                nn.Sequential(
+                    nn.BatchNorm2d(in_channels),
+                    nn.Conv2d(
+                        in_channels, out_channels,
+                        kernel_size=3, stride=stride, padding=1
+                    ),
+                    nn.ReLU()
+                )
+            )
+            in_channels = out_channels
+        self.encoder = nn.Sequential(*modules)
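+        # A worked example of the strides chosen above (a sketch, for the
+        # default arguments): with ``encoder_channels=(8, 8, 16, 16, 24, 24, 32)``
+        # and one input channel, stride 2 is used at the 8->8, 16->16, and
+        # 24->24 layers, so a (1, 128, 128) input is halved three times,
+        # to feature maps of shape (32, 16, 16), i.e., in_fc_dims = 8192.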
+        # we compute shapes dynamically to make the code more general;
+        # we could instead compute them with the equations for conv output shapes,
+        # to avoid running a tensor through the encoder
+        dummy_inp = torch.rand(1, *input_shape)
+        out = self.encoder(dummy_inp)
+        self.fc_view = tuple(out.shape[1:])
+        out = torch.flatten(out, start_dim=1)
+        self.in_fc_dims = out.shape[1]
+
+        # ---- build shared fully-connected layers of encoder
+        modules = []
+        in_features = self.in_fc_dims
+        for out_features in fc_dims[:-1]:
+            modules.append(
+                nn.Sequential(
+                    nn.Linear(in_features, out_features),
+                    nn.ReLU()
+                )
+            )
+            in_features = out_features
+        self.shared_encoder_fc = nn.Sequential(*modules)
+
+        fc_features = (*fc_dims[-2:], z_dim)
+        self.mu_fc = FullyConnectedLayers(fc_features)
+        self.cov_factor_fc = FullyConnectedLayers(fc_features)
+        self.cov_diag_fc = FullyConnectedLayers(fc_features)
+
+        # ---- build fully-connected layers of decoder
+        modules = []
+        decoder_dims = (*reversed(fc_dims), self.in_fc_dims)
+        in_features = z_dim
+        for i, out_features in enumerate(decoder_dims):
+            modules.append(
+                nn.Sequential(
+                    nn.Linear(in_features, out_features),
+                    nn.ReLU()
+                )
+            )
+            in_features = out_features
+        self.decoder_fc = nn.Sequential(*modules)
+
+        # ---- build decoder
+        modules = []
+        decoder_channels = (*reversed(encoder_channels[:-1]), self.in_channels)
+        in_channels = encoder_channels[-1]
+        for i, out_channels in enumerate(decoder_channels):
+            stride = 2 if out_channels == in_channels else 1
+            output_padding = 1 if out_channels == in_channels else 0
+            layers = [nn.BatchNorm2d(in_channels),
+                      nn.ConvTranspose2d(
+                          in_channels, out_channels,
+                          kernel_size=3, stride=stride, padding=1, output_padding=output_padding
+                      )]
+            if i < len(decoder_channels) - 1:
+                layers.append(nn.ReLU())
+            modules.append(nn.Sequential(*layers))
+            in_channels = out_channels
+        self.decoder = nn.Sequential(*modules)
+
+    def encode(self, x):
+        """Encode a spectrogram ``x``
+        by mapping it to a vector :math:`z`
+        in latent space.
+
+        Parameters
+        ----------
+        x : torch.Tensor
+
+        Returns
+        -------
+        z : torch.Tensor
+        latent_dist : LowRankMultivariateNormal
+        """
+        x = self.encoder(x)
+        x = torch.flatten(x, start_dim=1)
+        x = self.shared_encoder_fc(x)
+        mu = self.mu_fc(x)
+        cov_factor = self.cov_factor_fc(x).unsqueeze(-1)  # last dimension is the rank of Sigma, here 1
+        cov_diag = torch.exp(self.cov_diag_fc(x))  # cov_diag must be positive
+        z, latent_dist = self.reparametrize(mu, cov_factor, cov_diag)
+        return z, latent_dist
+
+    def decode(self, z):
+        r"""Decode a latent space vector ``z``,
+        mapping it back to a spectrogram :math:`x`
+        in the space of spectrograms :math:`\mathcal{X}`.
+
+        Parameters
+        ----------
+        z : torch.Tensor
+            Output of encoder, with dimensions
+            (batch size, latent space size).
+
+        Returns
+        -------
+        x : torch.Tensor
+            Output of decoder, with shape
+            (batch, channel, frequency bins, time bins).
+        """
+        x = self.decoder_fc(z).view(-1, *self.fc_view)
+        x = self.decoder(x).view(-1, *self.input_shape)
+        return x
+
+    @staticmethod
+    def reparametrize(mu, cov_factor, cov_diag):
+        """Sample a latent distribution
+        to get the latent embedding :math:`z`.
+
+        Method that encapsulates the reparametrization trick.
+
+        Parameters
+        ----------
+        mu : torch.Tensor
+        cov_factor : torch.Tensor
+        cov_diag : torch.Tensor
+
+        Returns
+        -------
+        z : torch.Tensor
+        latent_dist : LowRankMultivariateNormal
+        """
+        latent_dist = LowRankMultivariateNormal(mu, cov_factor, cov_diag)
+        z = latent_dist.rsample()
+        return z, latent_dist
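+    # For orientation (a sketch, not executed): ``LowRankMultivariateNormal``
+    # with rank-1 ``cov_factor`` U and diagonal ``cov_diag`` D represents
+    # N(mu, U @ U.T + diag(D)), so ``rsample`` draws
+    #
+    #   z = mu + U @ eps1 + sqrt(D) * eps2,    eps1, eps2 ~ N(0, I)
+    #
+    # which keeps sampling differentiable with respect to the encoder outputs.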
+    def forward(self, x):
+        """Pass a spectrogram ``x``
+        through the variational autoencoder:
+        encode, then decode.
+
+        Parameters
+        ----------
+        x : torch.Tensor
+
+        Returns
+        -------
+        x_rec : torch.Tensor
+            Reconstruction of ``x``,
+            output of the decoder.
+        z : torch.Tensor
+            Latent space embedding of ``x``.
+        latent_dist : LowRankMultivariateNormal
+            Distribution parametrized
+            by the output of the encoder.
+        """
+        z, latent_dist = self.encode(x)
+        x_rec = self.decode(z)
+        return x_rec, z, latent_dist
diff --git a/src/vak/nn/loss/__init__.py b/src/vak/nn/loss/__init__.py
index 18f4e6d2f..5435ae048 100644
--- a/src/vak/nn/loss/__init__.py
+++ b/src/vak/nn/loss/__init__.py
@@ -1,9 +1,13 @@
 from .dice import DiceLoss, dice_loss
 from .umap import UmapLoss, umap_loss
+from .vae import VaeElboLoss, vae_elbo_loss
+
 
 __all__ = [
     "DiceLoss",
     "dice_loss",
     "UmapLoss",
     "umap_loss",
+    "VaeElboLoss",
+    "vae_elbo_loss",
 ]
diff --git a/src/vak/nn/loss/vae.py b/src/vak/nn/loss/vae.py
new file mode 100644
index 000000000..7ad7c826c
--- /dev/null
+++ b/src/vak/nn/loss/vae.py
@@ -0,0 +1,134 @@
+"""Evidence Lower Bound (ELBO) loss for a Variational Auto-Encoder,
+as used with the Autoencoded Vocal Analysis (AVA) model [1]_.
+Code is adapted from [2]_.
+
+.. [1] Goffinet, J., Brudner, S., Mooney, R., & Pearson, J. (2021).
+   Low-dimensional learned feature spaces quantify individual and group differences in vocal repertoires.
+   eLife, 10:e67855. https://doi.org/10.7554/eLife.67855
+
+.. [2] https://github.com/pearsonlab/autoencoded-vocal-analysis
+"""
+from __future__ import annotations
+
+import math
+
+import numpy as np
+import torch
+
+
+PI = torch.tensor(math.pi)
+
+
+def vae_elbo_loss(
+    x: torch.Tensor,
+    z: torch.Tensor,
+    x_rec: torch.Tensor,
+    latent_dist: torch.distributions.LowRankMultivariateNormal,
+    model_precision: float,
+    z_dim: int
+) -> torch.Tensor:
+    """Evidence Lower Bound (ELBO) loss for a Variational Auto-Encoder,
+    as used with the Autoencoded Vocal Analysis (AVA) model [1]_.
+
+    Parameters
+    ----------
+    x : torch.Tensor
+        Input to the network, a batch of spectrograms.
+    z : torch.Tensor
+        Latent space embeddings of ``x``.
+    x_rec : torch.Tensor
+        Reconstruction of ``x``, the output of the decoder.
+    latent_dist : torch.distributions.LowRankMultivariateNormal
+        Latent distribution parametrized by the encoder.
+    model_precision : float
+        Precision of the observation model :math:`p(x|z)`.
+    z_dim : int
+        Dimensionality of latent space.
+
+    Returns
+    -------
+    loss : torch.Tensor
+        The negative ELBO, summed over the batch,
+        so that minimizing this loss maximizes the ELBO.
+
+    Notes
+    -----
+    Code is adapted from [2]_.
+
+    References
+    ----------
+    .. [1] Goffinet, J., Brudner, S., Mooney, R., & Pearson, J. (2021).
+       Low-dimensional learned feature spaces quantify individual and group differences in vocal repertoires.
+       eLife, 10:e67855. https://doi.org/10.7554/eLife.67855
+
+    .. [2] https://github.com/pearsonlab/autoencoded-vocal-analysis
+    """
+    # E_{q(z|x)} p(z)
+    elbo = -0.5 * (torch.sum(torch.pow(z, 2)) + z_dim * torch.log(2 * PI))
+
+    # E_{q(z|x)} p(x|z)
+    x_dim = np.prod(x.shape[1:])
+    pxz_term = -0.5 * x_dim * (torch.log(2 * PI / model_precision))
+    l2s = torch.sum(
+        torch.pow(
+            x.view(x.shape[0], -1) - x_rec.view(x_rec.shape[0], -1),
+            2),
+        dim=1
+    )
+    pxz_term = pxz_term - 0.5 * model_precision * torch.sum(l2s)
+    elbo = elbo + pxz_term
+
+    # H[q(z|x)]
+    elbo = elbo + torch.sum(latent_dist.entropy())
+    return -elbo
+
+
+class VaeElboLoss(torch.nn.Module):
+    r"""Evidence Lower Bound (ELBO) loss for a Variational Auto-Encoder,
+    as used with the Autoencoded Vocal Analysis (AVA) model [1]_.
+
+    ELBO can be written as
+    :math:`L(\phi, \theta; x) = \text{ln} p_{\theta}(x) - D_{KL}(q_{\phi}(z|x) || p_{\theta}(z|x))`
+    where the first term is the *evidence* for :math:`x`
+    and the second is the Kullback-Leibler divergence between
+    :math:`q_{\phi}` and :math:`p_{\theta}`.
+
+    Notes
+    -----
+    Code is adapted from [2]_.
+
+    References
+    ----------
+    ..
[1] Goffinet, J., Brudner, S., Mooney, R., & Pearson, J. (2021). + Low-dimensional learned feature spaces quantify individual and group differences in vocal repertoires. + eLife, 10:e67855. https://doi.org/10.7554/eLife.67855 + .. [2] https://github.com/pearsonlab/autoencoded-vocal-analysis + """ + def __init__( + self, + model_precision: float = 10.0, + z_dim: int = 32 + ): + super().__init__() + self.model_precision = model_precision + self.z_dim = z_dim + + def forward( + self, + x: torch.Tensor, + z: torch.Tensor, + x_rec: torch.Tensor, + latent_dist: torch.distributions.LowRankMultivariateNormal, + ): + """Compute ELBO loss + + Parameters + ---------- + x + z + x_rec + latent_dist + + Returns + ------- + + """ + return vae_elbo_loss( + x=x, z=z, x_rec=x_rec, + latent_dist=latent_dist, model_precision=self.model_precision, + z_dim=self.z_dim + ) + diff --git a/src/vak/prep/__init__.py b/src/vak/prep/__init__.py index 47c378daf..503f030f3 100644 --- a/src/vak/prep/__init__.py +++ b/src/vak/prep/__init__.py @@ -7,7 +7,8 @@ prep_, sequence_dataset, spectrogram_dataset, - unit_dataset, + segment_dataset, + vae, ) from .prep_ import prep @@ -21,5 +22,6 @@ "prep_", "sequence_dataset", "spectrogram_dataset", - "unit_dataset", + "segment_dataset", + "vae", ] diff --git a/src/vak/prep/constants.py b/src/vak/prep/constants.py index 68399dd4c..bc72686ee 100644 --- a/src/vak/prep/constants.py +++ b/src/vak/prep/constants.py @@ -2,7 +2,7 @@ Defined in a separate module to minimize circular imports. """ -from . import frame_classification, parametric_umap +from . import frame_classification, parametric_umap, vae VALID_PURPOSES = frozenset( [ @@ -18,6 +18,8 @@ DATASET_TYPE_FUNCTION_MAP = { "frame classification": frame_classification.prep_frame_classification_dataset, "parametric umap": parametric_umap.prep_parametric_umap_dataset, + "vae-window": vae.prep_vae_dataset, + "vae-segment": vae.prep_vae_dataset, } DATASET_TYPES = tuple(DATASET_TYPE_FUNCTION_MAP.keys()) diff --git a/src/vak/prep/frame_classification/make_splits.py b/src/vak/prep/frame_classification/make_splits.py index e4fd01564..cd5e1035c 100644 --- a/src/vak/prep/frame_classification/make_splits.py +++ b/src/vak/prep/frame_classification/make_splits.py @@ -127,6 +127,7 @@ def make_splits( spect_key: str = "s", timebins_key: str = "t", freqbins_key: str = "f", + prep_frame_label_vecs=True, ) -> pd.DataFrame: r"""Make each split of a frame classification dataset. @@ -234,6 +235,11 @@ def make_splits( Key for accessing vector of time bins in files. Default is 't'. freqbins_key : str key for accessing vector of frequency bins in files. Default is 'f'. + prep_frame_label_vecs : bool + If True, prepare vectors of labels for each frame. Default is True. + This option is used by + :func:`vak.prep.vae.prep_window_vae_dataset` + since those datasets do not require frame labels. Returns ------- @@ -353,7 +359,7 @@ def _save_dataset_arrays_and_return_index_arrays( inds_in_sample_vec = np.arange(n_frames) # add to frame labels - if annot: + if prep_frame_label_vecs and annot: lbls_int = [labelmap[lbl] for lbl in annot.seq.labels] frame_labels = transforms.frame_labels.from_segments( lbls_int, diff --git a/src/vak/prep/parametric_umap/__init__.py b/src/vak/prep/parametric_umap/__init__.py index fb80f20ef..30427dabb 100644 --- a/src/vak/prep/parametric_umap/__init__.py +++ b/src/vak/prep/parametric_umap/__init__.py @@ -1,7 +1,5 @@ -from . 
diff --git a/src/vak/prep/__init__.py b/src/vak/prep/__init__.py
index 47c378daf..503f030f3 100644
--- a/src/vak/prep/__init__.py
+++ b/src/vak/prep/__init__.py
@@ -7,7 +7,8 @@
     prep_,
     sequence_dataset,
     spectrogram_dataset,
-    unit_dataset,
+    segment_dataset,
+    vae,
 )
 from .prep_ import prep
 
@@ -21,5 +22,6 @@
     "prep_",
     "sequence_dataset",
     "spectrogram_dataset",
-    "unit_dataset",
+    "segment_dataset",
+    "vae",
 ]
diff --git a/src/vak/prep/constants.py b/src/vak/prep/constants.py
index 68399dd4c..bc72686ee 100644
--- a/src/vak/prep/constants.py
+++ b/src/vak/prep/constants.py
@@ -2,7 +2,7 @@
 Defined in a separate module to minimize circular imports.
 """
-from . import frame_classification, parametric_umap
+from . import frame_classification, parametric_umap, vae
 
 VALID_PURPOSES = frozenset(
     [
@@ -18,6 +18,8 @@
 DATASET_TYPE_FUNCTION_MAP = {
    "frame classification": frame_classification.prep_frame_classification_dataset,
    "parametric umap": parametric_umap.prep_parametric_umap_dataset,
+    "vae-window": vae.prep_vae_dataset,
+    "vae-segment": vae.prep_vae_dataset,
 }
 
 DATASET_TYPES = tuple(DATASET_TYPE_FUNCTION_MAP.keys())
diff --git a/src/vak/prep/frame_classification/make_splits.py b/src/vak/prep/frame_classification/make_splits.py
index e4fd01564..cd5e1035c 100644
--- a/src/vak/prep/frame_classification/make_splits.py
+++ b/src/vak/prep/frame_classification/make_splits.py
@@ -127,6 +127,7 @@ def make_splits(
     spect_key: str = "s",
     timebins_key: str = "t",
     freqbins_key: str = "f",
+    prep_frame_label_vecs=True,
 ) -> pd.DataFrame:
     r"""Make each split of a frame classification dataset.
 
@@ -234,6 +235,11 @@
         Key for accessing vector of time bins in files. Default is 't'.
     freqbins_key : str
         key for accessing vector of frequency bins in files. Default is 'f'.
+    prep_frame_label_vecs : bool
+        If True, prepare a vector of labels for each frame. Default is True.
+        :func:`vak.prep.vae.prep_window_vae_dataset` sets this to False,
+        since window-VAE datasets do not require frame labels.
 
     Returns
     -------
@@ -353,7 +359,7 @@ def _save_dataset_arrays_and_return_index_arrays(
         inds_in_sample_vec = np.arange(n_frames)
 
         # add to frame labels
-        if annot:
+        if prep_frame_label_vecs and annot:
             lbls_int = [labelmap[lbl] for lbl in annot.seq.labels]
             frame_labels = transforms.frame_labels.from_segments(
                 lbls_int,
diff --git a/src/vak/prep/parametric_umap/__init__.py b/src/vak/prep/parametric_umap/__init__.py
index fb80f20ef..30427dabb 100644
--- a/src/vak/prep/parametric_umap/__init__.py
+++ b/src/vak/prep/parametric_umap/__init__.py
@@ -1,7 +1,5 @@
-from . import dataset_arrays
 from .parametric_umap import prep_parametric_umap_dataset
 
 __all__ = [
-    "dataset_arrays",
     "prep_parametric_umap_dataset",
 ]
diff --git a/src/vak/prep/parametric_umap/parametric_umap.py b/src/vak/prep/parametric_umap/parametric_umap.py
index 560b5a699..bec2835d6 100644
--- a/src/vak/prep/parametric_umap/parametric_umap.py
+++ b/src/vak/prep/parametric_umap/parametric_umap.py
@@ -1,3 +1,4 @@
+"""Prepare datasets for parametric UMAP models."""
 from __future__ import annotations
 
 import json
@@ -13,8 +14,8 @@
 from ...common.logging import config_logging_for_cli, log_version
 from ...common.timenow import get_timenow_as_str
 from .. import dataset_df_helper, split
-from ..unit_dataset import prep_unit_dataset
-from . import dataset_arrays
+from ..segment_dataset import learncurve, make_splits, prep_segment_dataset
+
 
 logger = logging.getLogger(__name__)
 
@@ -37,8 +38,7 @@ def prep_parametric_umap_dataset(
     spect_key: str = "s",
     timebins_key: str = "t",
 ):
-    """Prepare datasets for neural network models
-    that perform a dimensionality reduction task.
+    """Prepare datasets for parametric UMAP models.
 
     For general information on dataset preparation,
     see the docstring for :func:`vak.prep.prep`.
@@ -158,13 +158,13 @@ def prep_parametric_umap_dataset(
             f"with ``purpose='{purpose}'."
         )
 
-    logger.info(f"Purpose for frame classification dataset: {purpose}")
+    logger.info(f"Purpose for parametric UMAP dataset: {purpose}")
     # ---- set up directory that will contain dataset, and csv file name -----------------------------------------------
     data_dir_name = data_dir.name
     timenow = get_timenow_as_str()
     dataset_path = (
         output_dir
-        / f"{data_dir_name}-vak-dimensionality-reduction-dataset-generated-{timenow}"
+        / f"{data_dir_name}-vak-parametric-UMAP-dataset-generated-{timenow}"
     )
     dataset_path.mkdir()
 
@@ -211,7 +211,7 @@ def prep_parametric_umap_dataset(
     logger.info(f"Will prepare dataset as directory: {dataset_path}")
 
     # ---- actually make the dataset ------------------------------------------------------------------------------------
-    dataset_df, shape = prep_unit_dataset(
+    dataset_df, shape = prep_segment_dataset(
         audio_format=audio_format,
         output_dir=dataset_path,
         spect_params=spect_params,
@@ -224,8 +224,8 @@ def prep_parametric_umap_dataset(
 
     if dataset_df.empty:
         raise ValueError(
-            "Calling `vak.prep.unit_dataset.prep_unit_dataset` "
-            "with arguments passed to `vak.core.prep.prep_dimensionality_reduction_dataset` "
+            "Calling `vak.prep.segment_dataset.prep_segment_dataset` "
+            "with arguments passed to `vak.core.prep.prep_parametric_umap_dataset` "
             "returned an empty dataframe.\n"
             "Please double-check arguments to `vak.core.prep` function."
         )
@@ -242,7 +242,8 @@ def prep_parametric_umap_dataset(
         and (test_dur is None or val_dur == 0)
     ):
         raise ValueError(
-            "A duration specified for just training set, but prep function does not currently support creating a "
+            "A duration was specified for just the training set, "
+            "but the prep function does not currently support creating a "
             "single split of a specified duration. Either remove the train_dur option from the prep section and "
             "rerun, in which case all data will be included in the training set, or specify values greater than "
             "zero for test_dur (and val_dur, if a validation set will be used)"
         )
@@ -267,7 +268,7 @@ def prep_parametric_umap_dataset(
             do_split = True
 
     if do_split:
-        dataset_df = split.unit_dataframe(
+        dataset_df = split.segment_dataframe(
             dataset_df,
             dataset_path,
             labelset=labelset,
@@ -307,24 +308,20 @@ def prep_parametric_umap_dataset(
         labelmap = None
 
     # ---- make arrays that represent final dataset ---------------------------------------------------------------------
-    dataset_arrays.move_files_into_split_subdirs(
+    make_splits(
         dataset_df,
         dataset_path,
-        purpose,
     )
-    #
-    # ---- if purpose is learncurve, additionally prep splits for that ------------------------------------------------
-    # if purpose == 'learncurve':
-    #     dataset_df = make_learncurve_splits_from_dataset_df(
-    #         dataset_df,
-    #         train_set_durs,
-    #         num_replicates,
-    #         dataset_path,
-    #         labelmap,
-    #         audio_format,
-    #         spect_key,
-    #         timebins_key,
-    #     )
+
+    # ---- if purpose is learncurve, additionally prep splits for that --------------------------------------------------
+    if purpose == 'learncurve':
+        dataset_df = learncurve.make_subsets_from_dataset_df(
+            dataset_df,
+            train_set_durs,
+            num_replicates,
+            dataset_path,
+            labelmap,
+        )
 
     # ---- save csv file that captures provenance of source data --------------------------------------------------------
     logger.info(f"Saving dataset csv file: {dataset_csv_path}")
diff --git a/src/vak/prep/prep_.py b/src/vak/prep/prep_.py
index a99e287d4..383401e7a 100644
--- a/src/vak/prep/prep_.py
+++ b/src/vak/prep/prep_.py
@@ -6,6 +6,7 @@
 from . import constants
 from .frame_classification import prep_frame_classification_dataset
 from .parametric_umap import prep_parametric_umap_dataset
+from .vae import prep_vae_dataset
 
 logger = logging.getLogger(__name__)
 
@@ -31,6 +32,9 @@ def prep(
     spect_key: str = "s",
     timebins_key: str = "t",
     context_s: float = 0.015,
+    max_dur: float | None = None,
+    target_shape: tuple[int, int] | None = None,
+    normalize: bool = True,
 ):
     """Prepare datasets for use with neural network models.
 
@@ -143,6 +147,31 @@ def prep(
         key for accessing spectrogram in files. Default is 's'.
     timebins_key : str
         key for accessing vector of time bins in files. Default is 't'.
+    context_s : float
+        Number of seconds of "context" around a segment to
+        add, i.e., time before and after the onset
+        and offset respectively. Default is 0.015s,
+        15 milliseconds. This parameter is only used for
+        Parametric UMAP and segment-VAE datasets.
+    max_dur : float
+        Maximum duration for segments.
+        If a float value is specified,
+        any segment with a duration larger than
+        that value (in seconds) will be omitted
+        from the dataset. Default is None.
+        This parameter is only used for
+        segment-VAE datasets.
+    target_shape : tuple
+        Of ints, (target number of frequency bins,
+        target number of time bins).
+        Spectrograms of segments will be reshaped
+        by interpolation to have the specified
+        number of frequency and time bins.
+        The transformation is only applied if both this
+        parameter and ``max_dur`` are specified.
+        Default is None.
+        This parameter is only used for
+        segment-VAE datasets.
+    normalize : bool
+        If True, min-max normalize the spectrogram
+        made for each segment. Default is True.
+        This parameter is only used for
+        segment-VAE datasets.
 
     Returns
     -------
@@ -232,6 +261,31 @@ def prep(
             timebins_key=timebins_key,
         )
         return dataset_df, dataset_path
+    elif dataset_type in {"vae-segment", "vae-window"}:
+        dataset_df, dataset_path = prep_vae_dataset(
+            data_dir,
+            purpose,
+            dataset_type,
+            output_dir,
+            audio_format,
+            spect_format,
+            spect_params,
+            annot_format,
+            annot_file,
+            labelset,
+            audio_dask_bag_kwargs,
+            context_s,
+            max_dur,
+            target_shape,
+            train_dur,
+            val_dur,
+            test_dur,
+            train_set_durs,
+            num_replicates,
+            spect_key=spect_key,
+            timebins_key=timebins_key,
+        )
+        return dataset_df, dataset_path
     else:
         # this is in case a dataset type is written wrong
         # in the if-else statements above, we want to error loudly
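For orientation, here is a minimal sketch (not part of the patch) of calling :func:`vak.prep.prep` with the new segment-VAE parameters; all paths, formats, durations, and the labelset below are hypothetical values:

import vak

dataset_df, dataset_path = vak.prep.prep(
    data_dir="~/data/bird1/audio",
    purpose="train",
    dataset_type="vae-segment",
    output_dir="~/data/bird1/prepped",
    audio_format="wav",
    annot_format="generic-seq",
    labelset="abcde",
    train_dur=50., val_dur=15., test_dur=30.,
    context_s=0.015,
    max_dur=0.4,              # omit segments longer than 400 ms
    target_shape=(128, 128),  # interpolate each segment's spectrogram to this shape
)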
 
     Returns
     -------
@@ -232,6 +261,31 @@ def prep(
diff --git a/src/vak/prep/segment_dataset/__init__.py b/src/vak/prep/segment_dataset/__init__.py
new file mode 100644
index 000000000..22fc64d54
--- /dev/null
+++ b/src/vak/prep/segment_dataset/__init__.py
@@ -0,0 +1,5 @@
+from . import learncurve, segment_dataset
+from .make_splits import make_splits
+from .segment_dataset import prep_segment_dataset
+
+__all__ = ["learncurve", "make_splits", "prep_segment_dataset", "segment_dataset"]
diff --git a/src/vak/prep/segment_dataset/learncurve.py b/src/vak/prep/segment_dataset/learncurve.py
new file mode 100644
index 000000000..be530b8c0
--- /dev/null
+++ b/src/vak/prep/segment_dataset/learncurve.py
@@ -0,0 +1,121 @@
+"""Functionality to prepare subsets of the 'train' split of segment datasets,
+for generating a learning curve."""
+from __future__ import annotations
+
+import logging
+import pathlib
+from typing import Sequence
+
+import pandas as pd
+
+from ... import common
+from .. import split
+
+
+logger = logging.getLogger(__name__)
+
+
+def make_subsets_from_dataset_df(
+    dataset_df: pd.DataFrame,
+    train_set_durs: Sequence[float],
+    num_replicates: int,
+    dataset_path: pathlib.Path,
+    labelmap: dict,
+) -> pd.DataFrame:
+    """Make subsets of the training data split for a learning curve.
+
+    Makes subsets given a dataframe representing the entire dataset,
+    with one subset for each combination of (training set duration,
+    replicate number). Each subset is randomly drawn
+    from the total training split.
+
+    Uses :func:`vak.prep.split.segment_dataframe` to make
+    subsets of the training data from ``dataset_df``.
+
+    A new column will be added to the dataframe, `'subset'`,
+    and additional rows for each subset.
+    The dataframe is returned with these subsets added.
+    (The `'split'` for these rows will still be `'train'`.)
+
+    Parameters
+    ----------
+    dataset_df : pandas.DataFrame
+        Dataframe representing a segment dataset.
+        It is returned by
+        :func:`vak.prep.segment_dataset.prep_segment_dataset`,
+        and has a ``'split'`` column added.
+    train_set_durs : list
+        Durations in seconds of subsets taken from training data
+        to create a learning curve, e.g., `[5., 10., 15., 20.]`.
+    num_replicates : int
+        Number of times to replicate training for each training set duration
+        to better estimate metrics for a training set of that size.
+        Each replicate uses a different randomly drawn subset of the training
+        data (but of the same duration).
+    dataset_path : str, pathlib.Path
+        Directory where splits will be saved.
+    labelmap : dict
+        Mapping of labels in the labelset to a set of consecutive integers.
+
+    Returns
+    -------
+    dataset_df_out : pandas.DataFrame
+        A pandas.DataFrame that has the original splits
+        from ``dataset_df``, as well as the additional subsets
+        of the training data added, along with additional
+        columns, ``'subset', 'train_dur', 'replicate_num'``,
+        that are used by :mod:`vak`.
+        Other functions like :func:`vak.learncurve.learncurve`
+        specify a specific subset of the training data
+        by getting the subset name with the function
+        :func:`vak.common.learncurve.get_train_dur_replicate_subset_name`,
+        and then filtering ``dataset_df_out`` with that name
+        using the 'subset' column.
+    """
+    dataset_path = pathlib.Path(dataset_path)
+
+    # get just train split, to pass to split.segment_dataframe,
+    # so we don't end up with other splits in the training set
+    train_split_df = dataset_df[dataset_df["split"] == "train"].copy()
+    labelset = set([k for k in labelmap.keys() if k != "unlabeled"])
+
+    # will concat after loop, then use ``csv_path`` to replace
+    # original dataset df with this one
+    subsets_df = []
+    for train_dur in train_set_durs:
+        logger.info(
+            f"Subsetting training set for training set of duration: {train_dur}",
+        )
+        for replicate_num in range(1, num_replicates + 1):
+            train_dur_replicate_subset_name = (
+                common.learncurve.get_train_dur_replicate_subset_name(
+                    train_dur, replicate_num
+                )
+            )
+
+            train_dur_replicate_df = split.segment_dataframe(
+                # copy to avoid mutating original train_split_df
+                train_split_df.copy(),
+                dataset_path,
+                train_dur=train_dur,
+                labelset=labelset,
+            )
+            # remove rows where split was set to 'None'
+            train_dur_replicate_df = train_dur_replicate_df[
+                train_dur_replicate_df.split == "train"
+            ]
+            # next line: record the subset name so rows in the csv can be filtered by it
+            train_dur_replicate_df["subset"] = train_dur_replicate_subset_name
+            train_dur_replicate_df["train_dur"] = train_dur
+            train_dur_replicate_df["replicate_num"] = replicate_num
+            subsets_df.append(train_dur_replicate_df)
+
+    subsets_df = pd.concat(subsets_df)
+
+    # keep the same validation, test, and total train sets by concatenating them with the train subsets
+    dataset_df["subset"] = None  # add column but have it be empty
+    dataset_df = pd.concat((subsets_df, dataset_df))
+    # We reset the entire index across all splits, instead of repeating indices.
+    # We use drop=True because we don't want the old index values added
+    # as a new column 'index' or 'level_0'; we need to reset the index
+    # since we just did `pd.concat` with the original dataframe.
+    dataset_df = dataset_df.reset_index(drop=True)
+    return dataset_df
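To illustrate how the subsets made by ``make_subsets_from_dataset_df`` are consumed downstream, a short sketch (not part of the patch), assuming ``dataset_df`` is the dataframe returned by that function:

from vak.common.learncurve import get_train_dur_replicate_subset_name

# name of the subset for a 30-second training set, second replicate
subset_name = get_train_dur_replicate_subset_name(30., 2)
train_subset_df = dataset_df[dataset_df["subset"] == subset_name]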
diff --git a/src/vak/prep/parametric_umap/dataset_arrays.py b/src/vak/prep/segment_dataset/make_splits.py
similarity index 77%
rename from src/vak/prep/parametric_umap/dataset_arrays.py
rename to src/vak/prep/segment_dataset/make_splits.py
index 67e224ae7..a88858e43 100644
--- a/src/vak/prep/parametric_umap/dataset_arrays.py
+++ b/src/vak/prep/segment_dataset/make_splits.py
@@ -12,31 +12,27 @@
 logger = logging.getLogger(__name__)
 
 
-def move_files_into_split_subdirs(
-    dataset_df: pd.DataFrame, dataset_path: pathlib.Path, purpose: str
+def make_splits(
+    dataset_df: pd.DataFrame, dataset_path: pathlib.Path
 ) -> None:
     """Move npy files in dataset
     into sub-directories, one for each split in the dataset.
 
-    This is run *after* calling :func:`vak.prep.unit_dataset.prep_unit_dataset`
+    This is run *after* calling :func:`vak.prep.segment_dataset.prep_segment_dataset`
     to generate ``dataset_df``.
 
     Parameters
     ----------
     dataset_df : pandas.DataFrame
         A ``pandas.DataFrame`` returned by
-        :func:`vak.prep.unit_dataset.prep_unit_dataset`
-        with a ``'split'`` column added, as a result of calling
-        :func:`vak.prep.split.unit_dataframe` or because it was added "manually"
-        by calling :func:`vak.core.prep.prep_helper.add_split_col` (as is done
-        for 'predict' when the entire ``DataFrame`` belongs to this
-        "split").
+        :func:`vak.prep.segment_dataset.prep_segment_dataset`
+        with a ``'split'`` column added. The ``'split'`` column is added
+        as a result of calling :func:`vak.prep.split.segment_dataframe`,
+        or because it was added "manually"
+        by calling :func:`vak.core.prep.prep_helper.add_split_col`
+        (as is done for 'predict' when the entire ``DataFrame``
+        belongs to this "split").
     dataset_path : pathlib.Path
         Path to directory that represents dataset.
-    purpose: str
-        A string indicating what the dataset will be used for.
-        One of {'train', 'eval', 'predict', 'learncurve'}.
-        Determined by :func:`vak.core.prep.prep`
-        using the TOML configuration file.
 
     Returns
     -------
@@ -104,11 +100,10 @@ def make_splits(
         dataset_df.loc[split_df.index, "spect_path"] = new_spect_paths
 
     # ---- clean up after moving/copying -------------------------------------------------------------------------------
-    # remove any directories that we just emptied
-    if moved_spect_paths:
-        unique_parents = set(
-            [moved_spect.parent for moved_spect in moved_spect_paths]
-        )
-        for parent in unique_parents:
-            if len(list(parent.iterdir())) < 1:
-                shutil.rmtree(parent)
+    # Remove any npy files that were *not* added to a split
+    npy_files_not_in_split = sorted(
+        dataset_path.glob("*npy")
+    )
+    if len(npy_files_not_in_split) > 0:
+        for npy_file in npy_files_not_in_split:
+            npy_file.unlink()
diff --git a/src/vak/prep/segment_dataset/segment_dataset.py b/src/vak/prep/segment_dataset/segment_dataset.py
new file mode 100644
index 000000000..98d18e857
--- /dev/null
+++ b/src/vak/prep/segment_dataset/segment_dataset.py
@@ -0,0 +1,620 @@
+"""Functions for making a dataset of segments,
+as used to train parametric UMAP and AVA models."""
+from __future__ import annotations
+
+import logging
+import os
+import pathlib
+
+import attrs
+import crowsetta
+import dask
+import dask.delayed
+import numpy as np
+import numpy.typing as npt
+import pandas as pd
+from dask.diagnostics import ProgressBar
+from scipy.interpolate import RegularGridInterpolator
+
+from ...common import annotation, constants
+from ...common.converters import expanded_user_path, labelset_to_set
+from ...config.spect_params import SpectParamsConfig
+from ..spectrogram_dataset.audio_helper import files_from_dir
+from ..spectrogram_dataset.spect import spectrogram
+
+logger = logging.getLogger(__name__)
+
+
+@attrs.define
+class Segment:
+    """Dataclass that represents a segment
+    from segmented audio or spectrogram.
+
+    The attributes are metadata used to track
+    the origin of this segment in a dataset
+    of such segments.
+
+    The dataset including metadata is saved as a csv file
+    where these attributes become the columns.
+    """
+
+    data: npt.NDArray
+    samplerate: int
+    onset_s: float
+    offset_s: float
+    label: str
+    sample_dur: float
+    segment_dur: float
+    audio_path: str
+    annot_path: str
+
+
+@dask.delayed
+def get_segment_list(
+    audio_path: str,
+    annot: crowsetta.Annotation,
+    audio_format: str,
+    context_s: float = 0.005,
+    max_dur: float | None = None
+) -> list[Segment]:
+    """Get a list of :class:`Segment` instances, given
+    the path to an audio file and an annotation that indicates
+    where segments occur in that audio file.
+
+    Parameters
+    ----------
+    audio_path : str
+        Path to an audio file.
+    annot : crowsetta.Annotation
+        Annotation for audio file.
+    audio_format : str
+        String representing audio file format, e.g. 'wav'.
+    context_s : float
+        Number of seconds of "context" around segment to
+        add, i.e., time before and after the onset
+        and offset respectively. Default is 0.005s,
+        5 milliseconds.
+    max_dur : float
+        Maximum duration for segments.
+        If a float value is specified,
+        any segment with a duration larger than
+        that value (in seconds) will be omitted
+        from the returned list of segments.
+        Default is None.
+
+    Returns
+    -------
+    segments : list
+        A :class:`list` of :class:`Segment` instances.
+
+    Notes
+    -----
+    Function used by
+    :func:`vak.prep.segment_dataset.prep_segment_dataset`.
+    """
+    data, samplerate = constants.AUDIO_FORMAT_FUNC_MAP[audio_format](
+        audio_path
+    )
+    sample_dur = 1.0 / samplerate
+
+    segments = []
+    for segment_num, (onset_s, offset_s, label) in enumerate(zip(
+        annot.seq.onsets_s, annot.seq.offsets_s, annot.seq.labels
+    )):
+        if max_dur is not None:
+            segment_dur = offset_s - onset_s
+            if segment_dur > max_dur:
+                logger.info(
+                    f"Segment {segment_num} in {pathlib.Path(audio_path).name}, "
+                    f"with onset at {onset_s}s, offset at {offset_s}s, and label '{label}', "
+                    f"has duration ({segment_dur}s) that is greater than "
+                    f"the maximum allowed duration ({max_dur}s). "
+                    "Omitting segment from dataset."
+                )
+                continue
+        onset_s -= context_s
+        offset_s += context_s
+        onset_ind = int(np.floor(onset_s * samplerate))
+        offset_ind = int(np.ceil(offset_s * samplerate))
+        segment_data = data[onset_ind : offset_ind + 1]  # noqa: E203
+        segment_dur = segment_data.shape[-1] * sample_dur
+        segment = Segment(
+            segment_data,
+            samplerate,
+            onset_s,
+            offset_s,
+            label,
+            sample_dur,
+            segment_dur,
+            audio_path,
+            annot.annot_path,
+        )
+        segments.append(segment)
+
+    return segments
+
+
+def spectrogram_from_segment(
+    segment: Segment,
+    spect_params: SpectParamsConfig,
+) -> tuple[npt.NDArray, npt.NDArray, npt.NDArray]:
+    """Compute a spectrogram given a :class:`Segment` instance.
+
+    Parameters
+    ----------
+    segment : Segment
+        A :class:`Segment` instance with audio data.
+    spect_params : SpectParamsConfig
+        Parameters for computing the spectrogram.
+
+    Returns
+    -------
+    s : numpy.ndarray
+        The spectrogram.
+    f : numpy.ndarray
+        Vector of frequency bins.
+    t : numpy.ndarray
+        Vector of time bins.
+
+    Notes
+    -----
+    Function used by
+    :func:`vak.prep.segment_dataset.prep_segment_dataset`.
+    """
+    data, samplerate = np.array(segment.data), segment.samplerate
+    s, f, t = spectrogram(
+        data,
+        samplerate,
+        spect_params.fft_size,
+        spect_params.step_size,
+        spect_params.thresh,
+        spect_params.transform_type,
+        spect_params.freq_cutoffs,
+        spect_params.min_val,
+        spect_params.max_val,
+        spect_params.normalize,
+    )
+
+    return s, f, t
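# --- aside, not part of the patch: a sketch of using the two helpers above directly
# --- (normally ``prep_segment_dataset`` below does this for you); the annotation
# --- file, audio path, and ``spect_params`` object are hypothetical assumptions
import crowsetta
import dask

annots = crowsetta.Transcriber(format="generic-seq").from_file("annot.csv").to_annot()
# ``to_annot`` may return a single Annotation or a list; assume a list here
delayed = get_segment_list("bird1_001.wav", annots[0], "wav", context_s=0.005, max_dur=0.4)
(segments,) = dask.compute(delayed)  # get_segment_list is dask.delayed, so compute it
s, f, t = spectrogram_from_segment(segments[0], spect_params)  # spect_params: a SpectParamsConfig
# --- end aside ---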
+
+
+@attrs.define
+class SpectToSave:
+    """A spectrogram to be saved.
+
+    Used by :func:`save_spect`.
+    """
+
+    spect: npt.NDArray
+    f: npt.NDArray
+    t: npt.NDArray
+    ind: int
+    audio_path: str
+
+
+def save_spect(
+    spect_to_save: SpectToSave, output_dir: str | pathlib.Path
+) -> str:
+    """Save a spectrogram array to an npz file.
+
+    The filename is built from the attributes of ``spect_to_save``,
+    the file is saved in ``output_dir``, and the full path
+    is returned as a string.
+
+    Parameters
+    ----------
+    spect_to_save : SpectToSave
+        The spectrogram to save, with metadata.
+    output_dir : str, pathlib.Path
+        Directory where the file should be saved.
+
+    Returns
+    -------
+    npz_path : str
+        Path to npz file containing spectrogram inside ``output_dir``
+    """
+    spect_dict = {
+        "s": spect_to_save.spect,
+        "f": spect_to_save.f,
+        "t": spect_to_save.t,
+    }
+
+    basename = (
+        os.path.basename(spect_to_save.audio_path)
+        + f"-segment-{spect_to_save.ind}"
+    )
+    npz_path = os.path.join(
+        os.path.normpath(output_dir), basename + ".spect.npz"
+    )
+    np.savez(npz_path, **spect_dict)
+    return npz_path
+
+
+def abspath(a_path):
+    """Convert a path to an absolute path"""
+    if isinstance(a_path, str) or isinstance(a_path, pathlib.Path):
+        return str(pathlib.Path(a_path).absolute())
+    elif np.isnan(a_path):
+        return a_path
+
+
+# ---- make spectrograms + records for dataframe -----------------------------------------------------------------------
+@dask.delayed
+def make_spect_return_record(
+    segment: Segment,
+    ind: int,
+    spect_params: SpectParamsConfig,
+    output_dir: pathlib.Path,
+) -> tuple[tuple, int, float]:
+    """Helper function that enables parallelized creation of "records",
+    i.e. rows for the dataset dataframe, from a :class:`Segment`.
+
+    Computes the spectrogram for ``segment`` and saves it with
+    :func:`save_spect`, then returns a record (row) for the dataframe,
+    the number of time bins in the spectrogram,
+    and the spectrogram's mean value."""
+
+    s, f, t = spectrogram_from_segment(
+        segment,
+        spect_params,
+    )
+    n_timebins = s.shape[-1]
+
+    spect_to_save = SpectToSave(s, f, t, ind, segment.audio_path)
+    spect_path = save_spect(spect_to_save, output_dir)
+    record = tuple(
+        [
+            abspath(spect_path),
+            abspath(segment.audio_path),
+            abspath(segment.annot_path),
+            segment.onset_s,
+            segment.offset_s,
+            segment.label,
+            segment.samplerate,
+            segment.sample_dur,
+            segment.segment_dur,
+        ]
+    )
+
+    return record, n_timebins, s.mean()
+
+
+@dask.delayed
+def pad_spectrogram(record: tuple, pad_length: int, padval: float = 0.) -> tuple[str, tuple[int, int]]:
+    """Pads a spectrogram to a specified length on the left and right sides.
+
+    Spectrogram is saved again after padding.
+
+    Parameters
+    ----------
+    record : tuple
+        Returned by :func:`make_spect_return_record`,
+        has path to spectrogram file.
+    pad_length : int
+        Length to which spectrogram should be padded.
+    padval : float
+        Value to pad with. Default is 0.
+
+    Returns
+    -------
+    new_spect_path : str
+        Path to npy file the padded spectrogram is saved in.
+    shape : tuple
+        Shape of spectrogram after padding.
+    """
+    spect_path = record[0]  # 'spect_path'
+    spect_dict = np.load(spect_path)
+    spect = spect_dict["s"]
+
+    excess_needed = pad_length - spect.shape[-1]
+    pad_left = np.floor(float(excess_needed) / 2).astype("int")
+    pad_right = np.ceil(float(excess_needed) / 2).astype("int")
+    spect_padded = np.pad(
+        spect, [(0, 0), (pad_left, pad_right)], "constant", constant_values=padval
+    )
+    new_spect_path = str(spect_path).replace(".npz", ".npy")
+    np.save(new_spect_path, spect_padded)
+    return new_spect_path, spect_padded.shape
+
+
+@dask.delayed
+def interp_spectrogram(
+    record: tuple,
+    max_dur: float,
+    target_shape: tuple[int, int],
+    normalize: bool = True,
+    fill_value: float = 0.
+) -> tuple[str, tuple[int, int]]:
+    """Linearly interpolate a spectrogram to a target shape.
+
+    Spectrogram is saved again after interpolation.
+
+    Uses :class:`scipy.interpolate.RegularGridInterpolator`
+    to treat the spectrogram as if it were a function of the
+    frequencies vector :math:`f` and the times vector :math:`t`,
+    then interpolates given new frequencies and times
+    with the same range but with the number of values
+    specified by the argument ``target_shape``.
+
+    Parameters
+    ----------
+    record : tuple
+        Returned by :func:`make_spect_return_record`,
+        has path to spectrogram file.
+    max_dur : float
+        Maximum duration for segments.
+        Used with ``target_shape`` when reshaping
+        the spectrogram via interpolation.
+    target_shape : tuple
+        Of ints, (target number of frequency bins,
+        target number of time bins).
+        Spectrograms of segments will be reshaped
+        by interpolation to have the specified
+        number of frequency and time bins.
+    normalize : bool
+        If True, min-max normalize the spectrogram.
+        Default is True.
+    fill_value : float
+        Value used for points outside the spectrogram's
+        time and frequency range. Default is 0.
+
+    Returns
+    -------
+    new_spect_path : str
+        Path to npy file the interpolated spectrogram is saved in.
+    shape : tuple
+        Shape of spectrogram after interpolation.
+    """
+    spect_path = record[0]  # 'spect_path'
+    spect_dict = np.load(spect_path)
+    s = spect_dict["s"]
+    f = spect_dict["f"]
+    t = spect_dict["t"]
+
+    # if max_dur and target_shape are specified we interpolate spectrogram to target shape, like AVA
+    target_freqs = np.linspace(f.min(), f.max(), target_shape[0])
+    duration = t.max() - t.min()
+    new_duration = np.sqrt(duration * max_dur)  # stretched duration
+    shoulder = 0.5 * (max_dur - new_duration)
+    target_times = np.linspace(t.min() - shoulder, t.max() + shoulder, target_shape[1])
+    ttnew, ffnew = np.meshgrid(target_times, target_freqs, indexing='ij', sparse=True)
+    r = RegularGridInterpolator((t, f), s.T, bounds_error=False, fill_value=fill_value)
+    s = r((ttnew, ffnew)).T
+    if normalize:
+        s_max, s_min = s.max(), s.min()
+        s = (s - s_min) / (s_max - s_min)
+        s = np.clip(s, 0.0, 1.0)
+    new_spect_path = str(spect_path).replace(".npz", ".npy")
+    np.save(new_spect_path, s)
+    return new_spect_path, s.shape
+
+
+# constant, used for names of columns in DataFrame below
+DF_COLUMNS = [
+    "spect_path",
+    "audio_path",
+    "annot_path",
+    "onset_s",
+    "offset_s",
+    "label",
+    "samplerate",
+    "sample_dur",
+    "duration",
+]
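# --- aside, not part of the patch: a self-contained sketch of the interpolation
# --- done by ``interp_spectrogram`` above, on a toy spectrogram with hypothetical
# --- values; the min-max normalization and clipping mirror ``normalize=True``
import numpy as np
from scipy.interpolate import RegularGridInterpolator

f = np.linspace(0., 10_000., 64)   # 64 frequency bins
t = np.linspace(0., 0.25, 50)      # 50 time bins
s = np.random.rand(64, 50)         # toy spectrogram, shape (freqs, times)

target_shape = (128, 128)
target_freqs = np.linspace(f.min(), f.max(), target_shape[0])
target_times = np.linspace(t.min(), t.max(), target_shape[1])
tt, ff = np.meshgrid(target_times, target_freqs, indexing="ij", sparse=True)
r = RegularGridInterpolator((t, f), s.T, bounds_error=False, fill_value=0.)
s_interp = r((tt, ff)).T           # transpose back to (freqs, times)

# min-max normalize, then clip, as when ``normalize=True``
s_interp = (s_interp - s_interp.min()) / (s_interp.max() - s_interp.min())
s_interp = np.clip(s_interp, 0.0, 1.0)
assert s_interp.shape == target_shape
# --- end aside ---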
+
+
+def prep_segment_dataset(
+    audio_format: str,
+    output_dir: str | pathlib.Path,
+    spect_params: SpectParamsConfig,
+    data_dir: str | pathlib.Path,
+    annot_format: str | None = None,
+    annot_file: str | pathlib.Path | None = None,
+    labelset: set | None = None,
+    context_s: float = 0.005,
+    max_dur: float | None = None,
+    target_shape: tuple[int, int] | None = None,
+) -> tuple[pd.DataFrame, tuple[int]]:
+    """Prepare a dataset of segments.
+
+    Finds segments in audio files, using the annotations
+    for those files, then computes a spectrogram for each segment
+    and saves the spectrograms in npy files.
+    The functions that call this one then assign each npy file
+    to a split and move the files into split directories
+    inside the directory representing the dataset.
+
+    Parameters
+    ----------
+    audio_format : str
+        Format of audio files. One of {'wav', 'cbin'}.
+        Default is ``None``, but either ``audio_format`` or ``spect_format``
+        must be specified.
+    output_dir : str
+        Path to location where data sets should be saved.
+        Default is ``None``, in which case it defaults to ``data_dir``.
+    spect_params : dict, vak.config.SpectParams
+        Parameters for creating spectrograms. Default is ``None``.
+    data_dir : str, pathlib.Path
+        Path to directory with files from which to make dataset.
+    annot_format : str
+        Format of annotations. Any format that can be used with the
+        :mod:`crowsetta` library is valid. Default is ``None``.
+    annot_file : str
+        Path to a single annotation file. Default is ``None``.
+        Used when a single file contains annotations
+        for multiple audio or spectrogram files.
+    labelset : str, list, set
+        Set of unique labels for vocalizations. Strings or integers.
+        Default is ``None``. If not ``None``, then files will be skipped
+        where the associated annotation
+        contains labels not found in ``labelset``.
+        ``labelset`` is converted to a Python ``set`` using
+        :func:`vak.converters.labelset_to_set`.
+        See help for that function for details on how to specify ``labelset``.
+    context_s : float
+        Number of seconds of "context" around segment to
+        add, i.e., time before and after the onset
+        and offset respectively. Default is 0.005s,
+        5 milliseconds.
+    max_dur : float
+        Maximum duration for segments.
+        If a float value is specified,
+        any segment with a duration larger than
+        that value (in seconds) will be omitted
+        from the dataset. Default is None.
+    target_shape : tuple
+        Of ints, (target number of frequency bins,
+        target number of time bins).
+        Spectrograms of segments will be reshaped
+        by interpolation to have the specified
+        number of frequency and time bins.
+        The transformation is only applied if both this
+        parameter and ``max_dur`` are specified.
+        Default is None.
+
+    Returns
+    -------
+    segment_df : pandas.DataFrame
+        A DataFrame representing all the segments in the dataset.
+    shape : tuple
+        A tuple representing the shape of all spectrograms in the dataset.
+        If ``max_dur`` and ``target_shape`` are specified, this is
+        ``target_shape``; otherwise, the spectrograms of all segments
+        are padded so that they are all as wide as the widest segment
+        (i.e., the one with the longest duration).
+    """
+    # pre-conditions ----------------------------------------------------------------------------------------------
+    if audio_format not in constants.VALID_AUDIO_FORMATS:
+        raise ValueError(
+            f"audio format must be one of '{constants.VALID_AUDIO_FORMATS}'; "
+            f"format '{audio_format}' not recognized."
+        )
+
+    if labelset is not None:
+        labelset = labelset_to_set(labelset)
+
+    data_dir = expanded_user_path(data_dir)
+    if not data_dir.is_dir():
+        raise NotADirectoryError(f"data_dir not found: {data_dir}")
+
+    audio_files = files_from_dir(data_dir, audio_format)
+
+    if annot_format is not None:
+        if annot_file is None:
+            annot_files = annotation.files_from_dir(
+                annot_dir=data_dir, annot_format=annot_format
+            )
+            scribe = crowsetta.Transcriber(format=annot_format)
+            annot_list = [
+                scribe.from_file(annot_file).to_annot()
+                for annot_file in annot_files
+            ]
+        else:
+            scribe = crowsetta.Transcriber(format=annot_format)
+            annot_list = scribe.from_file(annot_file).to_annot()
+        if isinstance(annot_list, crowsetta.Annotation):
+            # if e.g. only one annotated audio file in directory, wrap in a list to make iterable
+            # fixes https://github.com/NickleDave/vak/issues/467
+            annot_list = [annot_list]
+    else:  # if annot_format not specified
+        annot_list = None
+
+    if annot_list:
+        audio_annot_map = annotation.map_annotated_to_annot(
+            audio_files, annot_list, annot_format
+        )
+    else:
+        # no annotation, so map spectrogram files to None
+        audio_annot_map = dict(
+            (audio_path, None) for audio_path in audio_files
+        )
+
+    # use labelset, if supplied, with annotations, if any, to filter;
+    if (
+        labelset and annot_list
+    ):  # then remove annotations with labels not in labelset
+        for audio_file, annot in list(audio_annot_map.items()):
+            # loop in a verbose way (i.e., not a comprehension)
+            # so we can give the user a warning when we skip files
+            annot_labelset = set(annot.seq.labels)
+            if not annot_labelset.issubset(set(labelset)):
+                # because there's some label in the annotation that's not in labelset
+                audio_annot_map.pop(audio_file)
+                extra_labels = annot_labelset - labelset
+                logger.info(
+                    f"Found labels, {extra_labels}, in {pathlib.Path(audio_file).name}, "
+                    "that are not in labelset. Skipping file.",
+                )
+
+    segments = []
+    for audio_path, annot in audio_annot_map.items():
+        segment_list = dask.delayed(get_segment_list)(
+            audio_path, annot, audio_format, context_s, max_dur
+        )
+        segments.append(segment_list)
+
+    logger.info(
+        "Loading audio for all segments in all files",
+    )
+    with ProgressBar():
+        segments: list[list[Segment]] = dask.compute(*segments)
+    segments: list[Segment] = [
+        segment for segment_list in segments for segment in segment_list
+    ]
+
+    # ---- make and save all spectrograms *before* interpolating or padding ----------------------------------------
+    # This is a design choice to avoid keeping all the spectrograms in memory,
+    # but since we want all spectrograms to have the same width,
+    # it requires us to go back, load each one, and pad or interpolate it.
+    # Might be worth looking at how large typical datasets are in memory
+    # and whether this is really necessary.
+    records_n_timebins_tuples = []
+    for ind, segment in enumerate(segments):
+        records_n_timebins_tuple = make_spect_return_record(
+            segment, ind, spect_params, output_dir,
+        )
+        records_n_timebins_tuples.append(records_n_timebins_tuple)
+    with ProgressBar():
+        records_n_timebins_tuples: list[tuple[tuple, int]] = dask.compute(
+            *records_n_timebins_tuples
+        )
+
+    # We use the n_timebins values to pad all spectrograms to the same width.
+    # (make_spect_return_record also returns each spectrogram's mean,
+    # but the mean is not currently used below.)
+    records, n_timebins_list = [], []
+    for records_n_timebins_tuple in records_n_timebins_tuples:
+        record, n_timebins, spect_mean = records_n_timebins_tuple
+        records.append(record)
+        n_timebins_list.append(n_timebins)
+
+    # ---- either interpolate or pad spectrograms so they are all the same size ------------------------------------
+    fill_value = spect_params.min_val if spect_params.min_val is not None else 0.
+
+    if max_dur is not None and target_shape is not None:
+        interpolated = []
+        for record in records:
+            interpolated.append(
+                interp_spectrogram(
+                    record, max_dur, target_shape, spect_params.normalize, fill_value
+                ))
+        with ProgressBar():
+            path_shape_tuples = dask.compute(*interpolated)
+    else:
+        # then we pad
+        pad_length = max(n_timebins_list)
+
+        padded = []
+        for record in records:
+            padded.append(pad_spectrogram(record, pad_length, padval=fill_value))
+        with ProgressBar():
+            path_shape_tuples = dask.compute(*padded)
+
+    # ---- clean up npz files with spectrograms, don't need them anymore -------------------------------------------
+    npz_files = sorted(output_dir.glob('*npz'))
+    for npz_file in npz_files:
+        npz_file.unlink()
+
+    paths, shapes = [], []
+    for path, shape in path_shape_tuples:
+        paths.append(path)
+        shapes.append(shape)
+    shape = set(shapes)
+    assert (
+        len(shape) == 1
+    ), f"Did not find a single unique shape for all spectrograms. Instead found: {shape}"
+    shape = shape.pop()
+
+    new_records = []
+    for record, path in zip(records, paths):
+        new_records.append(
+            tuple(
+                [path, *record[1:]]
+            )
+        )
+    segment_df = pd.DataFrame.from_records(new_records, columns=DF_COLUMNS)
+
+    return segment_df, shape
diff --git a/src/vak/prep/spectrogram_dataset/spect.py b/src/vak/prep/spectrogram_dataset/spect.py
index d4d84ada0..910344eb4 100644
--- a/src/vak/prep/spectrogram_dataset/spect.py
+++ b/src/vak/prep/spectrogram_dataset/spect.py
@@ -5,7 +5,10 @@
 spectrogram adapted from code by Kyle Kastner and Tim Sainburg
 https://github.com/timsainb/python_spectrograms_and_inversion
 """
+from __future__ import annotations
+
 import numpy as np
+import numpy.typing as npt
 from matplotlib.mlab import specgram
 from scipy.signal import butter, lfilter
 
@@ -25,14 +28,17 @@ def butter_bandpass_filter(data, lowcut, highcut, fs, order=5):
 
 
 def spectrogram(
-    dat,
-    samp_freq,
-    fft_size=512,
-    step_size=64,
-    thresh=None,
-    transform_type=None,
-    freq_cutoffs=None,
-):
+    dat: npt.NDArray,
+    samp_freq: int,
+    fft_size: int = 512,
+    step_size: int = 64,
+    thresh: float | None = None,
+    transform_type: str | None = None,
+    freq_cutoffs: list[int] | None = None,
+    min_val: float | None = None,
+    max_val: float | None = None,
+    normalize: bool = False,
+) -> tuple[npt.NDArray, npt.NDArray, npt.NDArray]:
     """creates a spectrogram
 
     Parameters
@@ -54,6 +60,24 @@ def spectrogram(
         threshold minimum power for log spectrogram
     freq_cutoffs : tuple
         of two elements, lower and higher frequencies.
+    min_val : float, optional
+        Minimum value to allow in spectrogram.
+        All values less than this will be set to this value.
+        This operation is applied *after* the transform
+        specified by ``transform_type``.
+        Default is None.
+    max_val : float, optional
+        Maximum value to allow in spectrogram.
+        All values greater than this will be set to this value.
+        This operation is applied *after* the transform
+        specified by ``transform_type``.
+        Default is None.
+    normalize : bool
+        If True, min-max normalize the spectrogram.
+        Normalization is done *after* the transform
+        specified by ``transform_type``, and *after*
+        the ``min_val`` and ``max_val`` operations.
+        Default is False.
 
     Return
     ------
@@ -77,7 +101,9 @@ def spectrogram(
     )[:3]
 
     if transform_type:
-        if transform_type == "log_spect":
+        if transform_type == "log":
+            spect = np.log(np.abs(spect) + np.finfo(spect.dtype).eps)
+        elif transform_type == "log_spect":
             spect /= spect.max()  # volume normalize to max 1
             spect = np.log10(spect)  # take log
             if thresh:
@@ -93,6 +119,16 @@ def spectrogram(
                 spect < thresh
             ] = thresh  # set anything less than the threshold as the threshold
 
+    if min_val is not None:
+        spect[spect < min_val] = min_val
+    if max_val is not None:
+        spect[spect > max_val] = max_val
+
+    if normalize:
+        s_max, s_min = spect.max(), spect.min()
+        spect = (spect - s_min) / (s_max - s_min)
+        spect = np.clip(spect, 0.0, 1.0)
+
     if freq_cutoffs:
         f_inds = np.nonzero(
             (freqbins >= freq_cutoffs[0]) & (freqbins < freq_cutoffs[1])
diff --git a/src/vak/prep/split/__init__.py b/src/vak/prep/split/__init__.py
index e8c9a001e..c3114279b 100644
--- a/src/vak/prep/split/__init__.py
+++ b/src/vak/prep/split/__init__.py
@@ -1,8 +1,8 @@
 from .
import algorithms -from .split import frame_classification_dataframe, unit_dataframe +from .split import frame_classification_dataframe, segment_dataframe __all__ = [ "algorithms", "frame_classification_dataframe", - "unit_dataframe", + "segment_dataframe", ] diff --git a/src/vak/prep/split/split.py b/src/vak/prep/split/split.py index 23d37dd49..932cfe5bc 100644 --- a/src/vak/prep/split/split.py +++ b/src/vak/prep/split/split.py @@ -178,7 +178,7 @@ def frame_classification_dataframe( return dataset_df -def unit_dataframe( +def segment_dataframe( dataset_df: pd.DataFrame, dataset_path: str | pathlib.Path, labelset: set, @@ -187,7 +187,7 @@ def unit_dataframe( val_dur: float | None = None, ): """Create datasets splits from a dataframe - representing a unit dataset. + representing a segment dataset. Splits dataset into training, test, and (optionally) validation subsets, specified by their duration. diff --git a/src/vak/prep/unit_dataset/__init__.py b/src/vak/prep/unit_dataset/__init__.py deleted file mode 100644 index bf68aa74b..000000000 --- a/src/vak/prep/unit_dataset/__init__.py +++ /dev/null @@ -1,4 +0,0 @@ -from . import unit_dataset -from .unit_dataset import prep_unit_dataset - -__all__ = ["prep_unit_dataset", "unit_dataset"] diff --git a/src/vak/prep/unit_dataset/unit_dataset.py b/src/vak/prep/unit_dataset/unit_dataset.py deleted file mode 100644 index 76a0e29b0..000000000 --- a/src/vak/prep/unit_dataset/unit_dataset.py +++ /dev/null @@ -1,410 +0,0 @@ -"""Functions for making a dataset of units from sequences, -as used to train dimensionality reduction models.""" -from __future__ import annotations - -import logging -import os -import pathlib - -import attrs -import crowsetta -import dask -import dask.delayed -import numpy as np -import numpy.typing as npt -import pandas as pd -from dask.diagnostics import ProgressBar - -from ...common import annotation, constants -from ...common.converters import expanded_user_path, labelset_to_set -from ..spectrogram_dataset.audio_helper import files_from_dir -from ..spectrogram_dataset.spect import spectrogram - -logger = logging.getLogger(__name__) - - -@attrs.define -class Segment: - """Dataclass that represents a segment - from segmented audio or spectrogram. - - The attributes are metadata used to track - the origin of this segment in a dataset - of such segments. - - The dataset including metadata is saved as a csv file - where these attributes become the columns. - """ - - data: npt.NDArray - samplerate: int - onset_s: float - offset_s: float - label: str - sample_dur: float - segment_dur: float - audio_path: str - annot_path: str - - -@dask.delayed -def get_segment_list( - audio_path: str, - annot: crowsetta.Annotation, - audio_format: str, - context_s: float = 0.005, -) -> list[Segment]: - """Get a list of :class:`Segment` instances, given - the path to an audio file and an annotation that indicates - where segments occur in that audio file. - - Function used by - :func:`vak.prep.dimensionality_reduction.unit_dataset.prep_unit_dataset`. - - Parameters - ---------- - audio_path : str - Path to an audio file. - annot : crowsetta.Annotation - Annotation for audio file. - audio_format : str - String representing audio file format, e.g. 'wav'. - context_s : float - Number of seconds of "context" around unit to - add, i.e., time before and after the onset - and offset respectively. Default is 0.005s, - 5 milliseconds. - - Returns - ------- - segments : list - A :class:`list` of :class:`Segment` instances. 
- """ - data, samplerate = constants.AUDIO_FORMAT_FUNC_MAP[audio_format]( - audio_path - ) - sample_dur = 1.0 / samplerate - - segments = [] - for onset_s, offset_s, label in zip( - annot.seq.onsets_s, annot.seq.offsets_s, annot.seq.labels - ): - onset_s -= context_s - offset_s += context_s - onset_ind = int(np.floor(onset_s * samplerate)) - offset_ind = int(np.ceil(offset_s * samplerate)) - segment_data = data[onset_ind : offset_ind + 1] # noqa: E203 - segment_dur = segment_data.shape[-1] * sample_dur - segment = Segment( - segment_data, - samplerate, - onset_s, - offset_s, - label, - sample_dur, - segment_dur, - audio_path, - annot.annot_path, - ) - segments.append(segment) - - return segments - - -def spectrogram_from_segment( - segment: Segment, spect_params: dict -) -> npt.NDArray: - """Compute a spectrogram given a :class:`Segment` instance. - - Parameters - ---------- - segment : Segment - spect_params : dict - - Returns - ------- - spect : numpy.ndarray - """ - data, samplerate = np.array(segment.data), segment.samplerate - s, _, _ = spectrogram( - data, - samplerate, - spect_params.fft_size, - spect_params.step_size, - spect_params.thresh, - spect_params.transform_type, - spect_params.freq_cutoffs, - ) - return s - - -@attrs.define -class SpectToSave: - """A spectrogram to be saved. - - Used by :func:`save_spect`. - """ - - spect: npt.NDArray - ind: int - audio_path: str - - -def save_spect( - spect_to_save: SpectToSave, output_dir: str | pathlib.Path -) -> str: - """Save a spectrogram array to an npy file. - - The filename is build from the attributes of ``spect_to_save``, - saved in output dir, and the full path is returned as a string. - - Parameters - ---------- - spect_to_save : SpectToSave - output_dir : str, pathlib.Path - - Returns - ------- - npy_path : str - Path to npy file containing spectrogram inside ``output_dir`` - """ - basename = ( - os.path.basename(spect_to_save.audio_path) - + f"-segment-{spect_to_save.ind}" - ) - npy_path = os.path.join( - os.path.normpath(output_dir), basename + ".spect.npy" - ) - np.save(npy_path, spect_to_save.spect) - return npy_path - - -def abspath(a_path): - """Convert a path to an absolute path""" - if isinstance(a_path, str) or isinstance(a_path, pathlib.Path): - return str(pathlib.Path(a_path).absolute()) - elif np.isnan(a_path): - return a_path - - -# ---- make spectrograms + records for dataframe ----------------------------------------------------------------------- -@dask.delayed -def make_spect_return_record( - segment: Segment, ind: int, spect_params: dict, output_dir: pathlib.Path -) -> tuple: - """Helper function that enables parallelized creation of "records", - i.e. rows for dataframe, from . - Accepts a two-element tuple containing (1) a dictionary that represents a spectrogram - and (2) annotation for that file""" - - spect = spectrogram_from_segment(segment, spect_params) - n_timebins = spect.shape[-1] - - spect_to_save = SpectToSave(spect, ind, segment.audio_path) - spect_path = save_spect(spect_to_save, output_dir) - record = tuple( - [ - abspath(spect_path), - abspath(segment.audio_path), - abspath(segment.annot_path), - segment.onset_s, - segment.offset_s, - segment.label, - segment.samplerate, - segment.sample_dur, - segment.segment_dur, - ] - ) - - return record, n_timebins - - -@dask.delayed -def pad_spectrogram(record: tuple, pad_length: float) -> None: - """Pads a spectrogram to a specified length on the left and right sides. - Spectrogram is saved again after padding. 
- - Parameters - ---------- - record : tuple - pad_length : int - """ - spect_path = record[0] # 'spect_path' - spect = np.load(spect_path) - - excess_needed = pad_length - spect.shape[-1] - pad_left = np.floor(float(excess_needed) / 2).astype("int") - pad_right = np.ceil(float(excess_needed) / 2).astype("int") - spect_padded = np.pad( - spect, [(0, 0), (pad_left, pad_right)], "constant", constant_values=0 - ) - np.save(spect_path, spect_padded) - return spect_padded.shape - - -# constant, used for names of columns in DataFrame below -DF_COLUMNS = [ - "spect_path", - "audio_path", - "annot_path", - "onset_s", - "offset_s", - "label", - "samplerate", - "sample_dur", - "duration", -] - - -def prep_unit_dataset( - audio_format: str, - output_dir: str, - spect_params: dict, - data_dir: list | None = None, - annot_format: str | None = None, - annot_file: str | pathlib.Path | None = None, - labelset: set | None = None, - context_s: float = 0.005, -) -> pd.DataFrame: - """Prepare a dataset of units from sequences, - e.g., all syllables segmented out of a dataset of birdsong. - - Parameters - ---------- - audio_format - output_dir - spect_params - data_dir - annot_format - annot_file - labelset - context_s - - Returns - ------- - unit_df : pandas.DataFrame - A DataFrame representing all the units in the dataset. - shape: tuple - A tuple representing the shape of all spectograms in the dataset. - The spectrograms of all units are padded so that they are all - as wide as the widest unit (i.e, the one with the longest duration). - """ - # pre-conditions --------------------------------------------------------------------------------------------------- - if audio_format not in constants.VALID_AUDIO_FORMATS: - raise ValueError( - f"audio format must be one of '{constants.VALID_AUDIO_FORMATS}'; " - f"format '{audio_format}' not recognized." - ) - - if labelset is not None: - labelset = labelset_to_set(labelset) - - data_dir = expanded_user_path(data_dir) - if not data_dir.is_dir(): - raise NotADirectoryError(f"data_dir not found: {data_dir}") - - audio_files = files_from_dir(data_dir, audio_format) - - if annot_format is not None: - if annot_file is None: - annot_files = annotation.files_from_dir( - annot_dir=data_dir, annot_format=annot_format - ) - scribe = crowsetta.Transcriber(format=annot_format) - annot_list = [ - scribe.from_file(annot_file).to_annot() - for annot_file in annot_files - ] - else: - scribe = crowsetta.Transcriber(format=annot_format) - annot_list = scribe.from_file(annot_file).to_annot() - if isinstance(annot_list, crowsetta.Annotation): - # if e.g. only one annotated audio file in directory, wrap in a list to make iterable - # fixes https://github.com/NickleDave/vak/issues/467 - annot_list = [annot_list] - else: # if annot_format not specified - annot_list = None - - if annot_list: - audio_annot_map = annotation.map_annotated_to_annot( - audio_files, annot_list, annot_format - ) - else: - # no annotation, so map spectrogram files to None - audio_annot_map = dict( - (audio_path, None) for audio_path in audio_files - ) - - # use labelset, if supplied, with annotations, if any, to filter; - if ( - labelset and annot_list - ): # then remove annotations with labels not in labelset - for audio_file, annot in list(audio_annot_map.items()): - # loop in a verbose way (i.e. 
not a comprehension) - # so we can give user warning when we skip files - annot_labelset = set(annot.seq.labels) - # below, set(labels_mapping) is a set of that dict's keys - if not annot_labelset.issubset(set(labelset)): - # because there's some label in labels that's not in labelset - audio_annot_map.pop(audio_file) - extra_labels = annot_labelset - labelset - logger.info( - f"Found labels, {extra_labels}, in {pathlib.Path(audio_file).name}, " - "that are not in labels_mapping. Skipping file.", - ) - - segments = [] - for audio_path, annot in audio_annot_map.items(): - segment_list = dask.delayed(get_segment_list)( - audio_path, annot, audio_format, context_s - ) - segments.append(segment_list) - - logger.info( - "Loading audio for all segments in all files", - ) - with ProgressBar(): - segments: list[list[Segment]] = dask.compute(*segments) - segments: list[Segment] = [ - segment for segment_list in segments for segment in segment_list - ] - - # ---- make and save all spectrograms *before* padding - # This is a design choice to avoid keeping all the spectrograms in memory - # but since we want to pad all spectrograms to be the same width, - # it requires us to go back, load each one, and pad it. - # Might be worth looking at how often typical dataset sizes in memory and whether this is really necessary. - records_n_timebins_tuples = [] - for ind, segment in enumerate(segments): - records_n_timebins_tuple = make_spect_return_record( - segment, ind, spect_params, output_dir - ) - records_n_timebins_tuples.append(records_n_timebins_tuple) - with ProgressBar(): - records_n_timebins_tuples: list[tuple[tuple, int]] = dask.compute( - *records_n_timebins_tuples - ) - - records, n_timebins_list = [], [] - for records_n_timebins_tuple in records_n_timebins_tuples: - record, n_timebins = records_n_timebins_tuple - records.append(record) - n_timebins_list.append(n_timebins) - - pad_length = max(n_timebins_list) - - padded = [] - for record in records: - padded.append(pad_spectrogram(record, pad_length)) - with ProgressBar(): - shapes: list[tuple[int, int]] = dask.compute(*padded) - - shape = set(shapes) - assert ( - len(shape) == 1 - ), f"Did not find a single unique shape for all spectrograms. Instead found: {shape}" - shape = shape.pop() - - unit_df = pd.DataFrame.from_records(records, columns=DF_COLUMNS) - - return unit_df, shape diff --git a/src/vak/prep/vae/__init__.py b/src/vak/prep/vae/__init__.py new file mode 100644 index 000000000..3fcb3a6f6 --- /dev/null +++ b/src/vak/prep/vae/__init__.py @@ -0,0 +1 @@ +from .vae import prep_vae_dataset diff --git a/src/vak/prep/vae/segment_vae.py b/src/vak/prep/vae/segment_vae.py new file mode 100644 index 000000000..56064608d --- /dev/null +++ b/src/vak/prep/vae/segment_vae.py @@ -0,0 +1,232 @@ +"""Prepare a dataset of segments for a VAE model.""" +from __future__ import annotations + +import json +import logging +import pathlib + +import pandas as pd + +from ...common import labels +from ...config.spect_params import SpectParamsConfig +from .. 
import dataset_df_helper, split
+from ..segment_dataset import prep_segment_dataset, make_splits
+
+
+logger = logging.getLogger(__name__)
+
+
+def prep_segment_vae_dataset(
+    data_dir: str | pathlib.Path,
+    dataset_path: str | pathlib.Path,
+    dataset_csv_path: str | pathlib.Path,
+    purpose: str,
+    audio_format: str | None = None,
+    spect_params: SpectParamsConfig | None = None,
+    annot_format: str | None = None,
+    annot_file: str | pathlib.Path | None = None,
+    labelset: set | None = None,
+    context_s: float = 0.005,
+    max_dur: float | None = None,
+    target_shape: tuple[int, int] | None = None,
+    train_dur: int | None = None,
+    val_dur: int | None = None,
+    test_dur: int | None = None,
+    train_set_durs: list[float] | None = None,
+    num_replicates: int | None = None,
+    spect_key: str = "s",
+    timebins_key: str = "t",
+) -> tuple[pd.DataFrame, tuple[int]]:
+    """Prepare a dataset of segments for a VAE model.
+
+    Parameters
+    ----------
+    data_dir : str, Path
+        Path to directory with files from which to make dataset.
+    dataset_path : str, Path
+        Path to directory that will represent the dataset.
+    dataset_csv_path : str, Path
+        Path to csv file that will represent the dataset,
+        saved inside ``dataset_path``.
+    purpose : str
+        Purpose of the dataset.
+        One of {'train', 'eval', 'predict', 'learncurve'}.
+        These correspond to commands of the vak command-line interface.
+    audio_format : str
+        Format of audio files. One of {'wav', 'cbin'}.
+        Default is ``None``, but either ``audio_format`` or ``spect_format``
+        must be specified.
+    spect_params : dict, vak.config.SpectParams
+        Parameters for creating spectrograms. Default is ``None``.
+    annot_format : str
+        Format of annotations. Any format that can be used with the
+        :mod:`crowsetta` library is valid. Default is ``None``.
+    annot_file : str
+        Path to a single annotation file. Default is ``None``.
+        Used when a single file contains annotations
+        for multiple audio or spectrogram files.
+    labelset : str, list, set
+        Set of unique labels for vocalizations. Strings or integers.
+        Default is ``None``. If not ``None``, then files will be skipped
+        where the associated annotation
+        contains labels not found in ``labelset``.
+        ``labelset`` is converted to a Python ``set`` using
+        :func:`vak.converters.labelset_to_set`.
+        See help for that function for details on how to specify ``labelset``.
+    context_s : float
+        Number of seconds of "context" around segment to
+        add, i.e., time before and after the onset
+        and offset respectively. Default is 0.005s,
+        5 milliseconds.
+    max_dur : float
+        Maximum duration for segments.
+        If a float value is specified,
+        any segment with a duration larger than
+        that value (in seconds) will be omitted
+        from the dataset. Default is None.
+    target_shape : tuple
+        Of ints, (target number of frequency bins,
+        target number of time bins).
+        Spectrograms of segments will be reshaped
+        by interpolation to have the specified
+        number of frequency and time bins.
+        The transformation is only applied if both this
+        parameter and ``max_dur`` are specified.
+        Default is None.
+    train_dur : float
+        Total duration of training set, in seconds.
+        When creating a learning curve,
+        training subsets of shorter duration
+        will be drawn from this set. Default is None.
+    val_dur : float
+        Total duration of validation set, in seconds.
+        Default is None.
+    test_dur : float
+        Total duration of test set, in seconds.
+        Default is None.
+    train_set_durs : list
+        Of int, durations in seconds of subsets taken from training data
+        to create a learning curve, e.g. [5, 10, 15, 20].
+    num_replicates : int
+        Number of times to replicate training for each training set duration
+        to better estimate metrics for a training set of that size.
+        Each replicate uses a different randomly drawn subset of the training
+        data (but of the same duration).
+    spect_key : str
+        key for accessing spectrogram in files. Default is 's'.
+    timebins_key : str
+        key for accessing vector of time bins in files. Default is 't'.
+
+    Returns
+    -------
+    dataset_df : pandas.DataFrame
+        A DataFrame representing all the segments in the dataset,
+        with a ``'split'`` column added.
+    shape : tuple
+        A tuple representing the shape of all spectrograms in the dataset.
+    """
+    dataset_df, shape = prep_segment_dataset(
+        audio_format,
+        dataset_path,
+        spect_params,
+        data_dir,
+        annot_format,
+        annot_file,
+        labelset,
+        context_s,
+        max_dur,
+        target_shape,
+    )
+    if dataset_df.empty:
+        raise ValueError(
+            "Calling `vak.prep.segment_dataset.prep_segment_dataset` "
+            "with arguments passed to `vak.core.prep.vae.prep_segment_vae_dataset` "
+            "returned an empty dataframe.\n"
+            "Please double-check arguments to `vak.core.prep` function."
+        )
+
+    # save before (possibly) splitting, just in case duration args are not valid
+    # (we can't know until we make dataset)
+    dataset_df.to_csv(dataset_csv_path)
+
+    # ---- (possibly) split into train / val / test sets ---------------------------------------------
+    # catch case where user specified duration for just training set, raise a helpful error instead of failing silently
+    if (purpose == "train" or purpose == "learncurve") and (
+        (train_dur is not None and train_dur > 0)
+        and (val_dur is None or val_dur == 0)
+        and (test_dur is None or test_dur == 0)
+    ):
+        raise ValueError(
+            "A duration was specified for just the training set, "
+            "but the prep function does not currently support creating a "
+            "single split of a specified duration. Either remove the train_dur option from the prep section and "
+            "rerun, in which case all data will be included in the training set, or specify values greater than "
+            "zero for test_dur (and val_dur, if a validation set will be used)"
+        )
+
+    if all(
+        [dur is None for dur in (train_dur, val_dur, test_dur)]
+    ) or purpose in (
+        "eval",
+        "predict",
+    ):
+        # then we're not going to split
+        logger.info("Will not split dataset.")
+        do_split = False
+    else:
+        if val_dur is not None and train_dur is None and test_dur is None:
+            raise ValueError(
+                "cannot specify only val_dur, unclear how to split dataset into training and test sets"
+            )
+        else:
+            logger.info("Will split dataset.")
+            do_split = True
+
+    if do_split:
+        dataset_df = split.segment_dataframe(
+            dataset_df,
+            dataset_path,
+            labelset=labelset,
+            train_dur=train_dur,
+            val_dur=val_dur,
+            test_dur=test_dur,
+        )
+
+    elif (
+        do_split is False
+    ):  # add a split column, but assign everything to the same 'split'
+        # ideally we would just say split=purpose in call to add_split_col, but
+        # we have to special case, because "eval" looks for a 'test' split (not an "eval" split)
+        if purpose == "eval":
+            split_name = (
+                "test"  # 'split_name' to avoid name clash with split package
+            )
+        elif purpose == "predict":
+            split_name = "predict"
+
+        dataset_df = dataset_df_helper.add_split_col(
+            dataset_df, split=split_name
+        )
+
+    # ---- create and save labelmap -------------------------------------------------------------------------------
+    # we do this before creating array files since we need to load the labelmap to make frame label vectors
+    if purpose != "predict":
+        # TODO: add option to generate predict using existing dataset, so we can get labelmap from it
+        labelmap = labels.to_map(labelset, map_unlabeled=False)
+        logger.info(
+            f"Number of classes in labelmap: {len(labelmap)}",
+        )
+        # save labelmap in case we need it later
+        with (dataset_path / "labelmap.json").open("w") as fp:
+            json.dump(labelmap, fp)
+
+    # ---- make arrays that represent final dataset ----------------------------------------------------------------
+    make_splits(
+        dataset_df,
+        dataset_path,
+    )
+    #
+    # ---- if purpose is learncurve, additionally prep splits for that -----------------------------------------------
+    # if purpose == 'learncurve':
+    #     dataset_df = make_learncurve_splits_from_dataset_df(
+    #         dataset_df,
+    #         train_set_durs,
+    #         num_replicates,
+    #         dataset_path,
+    #         labelmap,
+    #         audio_format,
+    #         spect_key,
+    #         timebins_key,
+    #     )
+
+    return dataset_df, shape
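As a brief illustration (not part of the patch) of the labelmap that ``prep_segment_vae_dataset`` saves above, with hypothetical labels:

from vak.common import labels

labelmap = labels.to_map(set("abc"), map_unlabeled=False)
# a dict like {'a': 0, 'b': 1, 'c': 2}, saved as labelmap.json inside dataset_path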
+        See help for that function for details on how to specify ``labelset``.
+    audio_dask_bag_kwargs : dict
+        Keyword arguments used when calling :func:`dask.bag.from_sequence`
+        inside :func:`vak.io.audio`, where it is used to parallelize
+        the conversion of audio files into spectrograms.
+        Option should be specified in config.toml file as an inline table,
+        e.g., ``audio_dask_bag_kwargs = { npartitions = 20 }``.
+        Allows for finer-grained control
+        when needed to process files of different sizes.
+    context_s : float
+        Number of seconds of "context" around a segment to
+        add, i.e., time before the onset
+        and after the offset. Default is 0.005 s,
+        5 milliseconds. This parameter is only used for
+        Parametric UMAP and ``vae-segment`` datasets.
+    max_dur : float
+        Maximum duration for segments.
+        If a float value is specified,
+        any segment with a duration larger than
+        that value (in seconds) will be omitted
+        from the dataset. Default is None.
+        This parameter is only used for
+        ``vae-segment`` datasets.
+    target_shape : tuple
+        Of ints, (target number of frequency bins,
+        target number of time bins).
+        Spectrograms of segments will be reshaped
+        by interpolation to have the specified
+        number of frequency and time bins.
+        The transformation is only applied if both this
+        parameter and ``max_dur`` are specified.
+        Default is None.
+        This parameter is only used for
+        ``vae-segment`` datasets.
+    train_dur : float
+        Total duration of training set, in seconds.
+        When creating a learning curve,
+        training subsets of shorter duration
+        will be drawn from this set. Default is None.
+    val_dur : float
+        Total duration of validation set, in seconds.
+        Default is None.
+    test_dur : float
+        Total duration of test set, in seconds.
+        Default is None.
+    train_set_durs : list
+        of float, durations in seconds of subsets taken from training data
+        to create a learning curve, e.g. [5, 10, 15, 20].
+    num_replicates : int
+        number of times to replicate training for each training set duration
+        to better estimate metrics for a training set of that size.
+        Each replicate uses a different randomly drawn subset of the training
+        data (but of the same duration).
+    spect_key : str
+        key for accessing spectrogram in files. Default is 's'.
+    timebins_key : str
+        key for accessing vector of time bins in files. Default is 't'.
+
+    Returns
+    -------
+    dataset_df : pandas.DataFrame
+        That represents a dataset.
+    dataset_path : pathlib.Path
+        Path to directory that contains the prepared dataset.
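+
+    Examples
+    --------
+    A minimal sketch of typical usage;
+    the paths, formats, durations, and labelset shown here
+    are hypothetical, not defaults:
+
+    >>> dataset_df, dataset_path = prep_vae_dataset(
+    ...     data_dir="./data/gy6or6/032312",
+    ...     purpose="train",
+    ...     dataset_type="vae-segment",
+    ...     audio_format="cbin",
+    ...     annot_format="notmat",
+    ...     labelset="iabcdefghjk",
+    ...     train_dur=20,
+    ...     val_dur=5,
+    ...     test_dur=10,
+    ... )
+    """
+    from .. 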
import constants  # avoid circular import
+
+    # pre-conditions ---------------------------------------------------------------------------------------------------
+    if purpose not in constants.VALID_PURPOSES:
+        raise ValueError(
+            f"purpose must be one of: {constants.VALID_PURPOSES}\n"
+            f"Value for purpose was: {purpose}"
+        )
+
+    if dataset_type not in VAE_DATASET_TYPES:
+        raise ValueError(
+            f"`dataset_type` must be one of {VAE_DATASET_TYPES}, but was: {dataset_type}"
+        )
+    logger.info(f"Type of VAE dataset that will be prepared: {dataset_type}")
+
+    if labelset is not None:
+        labelset = labelset_to_set(labelset)
+
+    data_dir = expanded_user_path(data_dir)
+    if not data_dir.is_dir():
+        raise NotADirectoryError(
+            f"Path specified for ``data_dir`` not found: {data_dir}"
+        )
+
+    if output_dir:
+        output_dir = expanded_user_path(output_dir)
+    else:
+        output_dir = data_dir
+
+    if not output_dir.is_dir():
+        raise NotADirectoryError(
+            f"Path specified for ``output_dir`` not found: {output_dir}"
+        )
+
+    if annot_file is not None:
+        annot_file = expanded_user_path(annot_file)
+        if not annot_file.exists():
+            raise FileNotFoundError(
+                f"Path specified for ``annot_file`` not found: {annot_file}"
+            )
+
+    if purpose == "predict":
+        if labelset is not None:
+            warnings.warn(
+                "The ``purpose`` argument was set to 'predict', but a ``labelset`` was provided. "
+                "This would cause an error because the ``prep_spectrogram_dataset`` function will attempt to "
+                "check whether the files in the ``data_dir`` have labels in "
+                "``labelset``, even though those files don't have annotation.\n"
+                "Setting ``labelset`` to None."
+            )
+            labelset = None
+    else:  # if purpose is not predict
+        if labelset is None:
+            raise ValueError(
+                f"The ``purpose`` argument was set to '{purpose}', but no ``labelset`` was provided. "
+                "This will cause an error when trying to split the dataset, "
+                "e.g. into training and test splits, "
+                "or a silent error, e.g. when calculating metrics with an evaluation set. "
+                "Please specify a ``labelset`` when calling ``vak.prep.vae.prep_vae_dataset`` "
+                f"with ``purpose='{purpose}'``."
+            )
+
+    logger.info(f"Purpose for VAE dataset: {purpose}")
+    # ---- set up directory that will contain dataset, and csv file name -----------------------------------------------
+    data_dir_name = data_dir.name
+    timenow = get_timenow_as_str()
+    dataset_path = (
+        output_dir
+        / f"{data_dir_name}-vak-vae-dataset-generated-{timenow}"
+    )
+    dataset_path.mkdir()
+
+    if annot_file and annot_format == "birdsong-recognition-dataset":
+        # we do this normalization / canonicalization after we make dataset_path
+        # so that we can put the new annot_file inside of dataset_path, instead of
+        # making new files elsewhere on a user's system
+        logger.info(
+            "The ``annot_format`` argument was set to 'birdsong-recognition-dataset'; "
+            "this format requires the audio files to get their sampling rate, "
+            "which is needed to convert onset and offset times of birdsong syllables to seconds. "
+            "Converting this format to 'generic-seq' now with the times in seconds, "
+            "so that the dataset prepared by vak will not require the audio files."
+        )
+        birdsongrec = crowsetta.formats.seq.BirdsongRec.from_file(annot_file)
+        annots = birdsongrec.to_annot()
+        # note we point `annot_file` at a new file we're about to make
+        annot_file = (
+            dataset_path / f"{annot_file.stem}.converted-to-generic-seq.csv"
+        )
+        # and we remake the Annotations here so that annot_path points to this new file, not the birdsong-rec Annotation.xml
+        annots = [
+            crowsetta.Annotation(
+                seq=annot.seq,
+                annot_path=annot_file,
+                notated_path=annot.notated_path,
+            )
+            for annot in annots
+        ]
+        generic_seq = crowsetta.formats.seq.GenericSeq(annots=annots)
+        generic_seq.to_file(annot_file)
+        # and we now change `annot_format` as well. Both these will get passed to io.prep_spectrogram_dataset
+        annot_format = "generic-seq"
+
+    # NOTE we set up logging here (instead of cli) so the prep log is included in the dataset
+    config_logging_for_cli(
+        log_dst=dataset_path, log_stem="prep", level="INFO", force=True
+    )
+    log_version(logger)
+
+    dataset_csv_path = dataset_df_helper.get_dataset_csv_path(
+        dataset_path, data_dir_name, timenow
+    )
+    logger.info(f"Will prepare dataset as directory: {dataset_path}")
+
+    # ---- actually make the dataset -----------------------------------------------------------------------------------
+    logger.info(f"Preparing files for '{dataset_type}' dataset")
+    if dataset_type == 'vae-segment':
+        dataset_df, shape = prep_segment_vae_dataset(
+            data_dir,
+            dataset_path,
+            dataset_csv_path,
+            purpose,
+            audio_format,
+            spect_params,
+            annot_format,
+            annot_file,
+            labelset,
+            context_s,
+            max_dur,
+            target_shape,
+            train_dur,
+            val_dur,
+            test_dur,
+            train_set_durs,
+            num_replicates,
+            spect_key,
+            timebins_key,
+        )
+    elif dataset_type == 'vae-window':
+        dataset_df = prep_window_vae_dataset(
+            data_dir,
+            dataset_path,
+            dataset_csv_path,
+            purpose,
+            audio_format,
+            spect_format,
+            spect_params,
+            annot_format,
+            annot_file,
+            labelset,
+            audio_dask_bag_kwargs,
+            train_dur,
+            val_dur,
+            test_dur,
+            train_set_durs,
+            num_replicates,
+            spect_key,
+            timebins_key,
+        )
+        # only the 'vae-segment' dataset type has a shape; we set it to None here
+        # so we can pass it to the metadata below either way
+        shape = None
+
+    # ---- save csv file that captures provenance of source data -------------------------------------------------------
+    logger.info(f"Saving dataset csv file: {dataset_csv_path}")
+    dataset_df.to_csv(
+        dataset_csv_path, index=False
+    )  # index is False to avoid having "Unnamed: 0" column when loading
+
+    # ---- save metadata -----------------------------------------------------------------------------------------------
+    metadata = datasets.vae.Metadata(
+        dataset_csv_filename=str(dataset_csv_path.name),
+        dataset_type=dataset_type,
+        audio_format=audio_format,
+        shape=shape,
+    )
+    metadata.to_json(dataset_path)
+
+    return dataset_df, dataset_path
diff --git a/src/vak/prep/vae/window_vae.py b/src/vak/prep/vae/window_vae.py
new file mode 100644
index 000000000..90713cd2a
--- /dev/null
+++ b/src/vak/prep/vae/window_vae.py
@@ -0,0 +1,147 @@
+"""Prepare window datasets for VAE models."""
+from __future__ import annotations
+
+import json
+import logging
+import pathlib
+
+import pandas as pd
+
+from ...common import labels
+from .. 
import sequence_dataset
+from ..spectrogram_dataset import prep_spectrogram_dataset
+from ..frame_classification.assign_samples_to_splits import assign_samples_to_splits
+from ..frame_classification.learncurve import make_subsets_from_dataset_df
+from ..frame_classification.make_splits import make_splits
+
+
+logger = logging.getLogger(__name__)
+
+
+def prep_window_vae_dataset(
+    data_dir: str | pathlib.Path,
+    dataset_path: str | pathlib.Path,
+    dataset_csv_path: str | pathlib.Path,
+    purpose: str,
+    audio_format: str | None = None,
+    spect_format: str | None = None,
+    spect_params: dict | None = None,
+    annot_format: str | None = None,
+    annot_file: str | pathlib.Path | None = None,
+    labelset: set | None = None,
+    audio_dask_bag_kwargs: dict | None = None,
+    train_dur: float | None = None,
+    val_dur: float | None = None,
+    test_dur: float | None = None,
+    train_set_durs: list[float] | None = None,
+    num_replicates: int | None = None,
+    spect_key: str = "s",
+    timebins_key: str = "t",
+    freqbins_key: str = "f",
+) -> pd.DataFrame:
+    """Prepare a window dataset for VAE models.
+
+    Parameters
+    ----------
+    data_dir : str, pathlib.Path
+        Path to directory with files from which to make dataset.
+    dataset_path : str, pathlib.Path
+        Path to directory that will contain the dataset.
+    dataset_csv_path : str, pathlib.Path
+        Path to csv file that represents the dataset.
+    purpose : str
+        One of {'train', 'eval', 'predict', 'learncurve'}.
+    audio_format : str
+        Format of audio files. One of {'wav', 'cbin'}.
+    spect_format : str
+        Format of files containing spectrograms as 2-d matrices.
+        One of {'mat', 'npz'}.
+    spect_params : dict
+        Parameters for creating spectrograms.
+    annot_format : str
+        Format of annotations.
+    annot_file : str, pathlib.Path
+        Path to a single annotation file.
+    labelset : set
+        Set of unique labels for vocalizations.
+    audio_dask_bag_kwargs : dict
+        Keyword arguments used when calling :func:`dask.bag.from_sequence`,
+        to parallelize the conversion of audio files into spectrograms.
+    train_dur : float
+        Total duration of training set, in seconds.
+    val_dur : float
+        Total duration of validation set, in seconds.
+    test_dur : float
+        Total duration of test set, in seconds.
+    train_set_durs : list
+        of float, durations in seconds of subsets taken from training data
+        to create a learning curve.
+    num_replicates : int
+        number of times to replicate training for each training set duration
+        in a learning curve.
+    spect_key : str
+        key for accessing spectrogram in files. Default is 's'.
+    timebins_key : str
+        key for accessing vector of time bins in files. Default is 't'.
+    freqbins_key : str
+        key for accessing vector of frequency bins in files. Default is 'f'.
+
+    Returns
+    -------
+    dataset_df : pandas.DataFrame
+        That represents the dataset, after assigning samples to splits.
+    """
+    source_files_df = prep_spectrogram_dataset(
+        data_dir,
+        annot_format,
+        labelset,
+        annot_file,
+        audio_format,
+        spect_format,
+        spect_params,
+        audio_dask_bag_kwargs=audio_dask_bag_kwargs,
+    )
+
+    # save before (possibly) splitting, just in case duration args are not valid
+    # (we can't know until we make the dataset)
+    source_files_df.to_csv(dataset_csv_path)
+
+    # ---- assign samples to splits; adds a 'split' column to dataset_df, calling `vak.prep.split` if needed -----------
+    # once we assign a split, we consider this the ``dataset_df``
+    dataset_df: pd.DataFrame = assign_samples_to_splits(
+        purpose,
+        source_files_df,
+        dataset_path,
+        train_dur,
+        val_dur,
+        test_dur,
+        labelset,
+    )
+
+    # ---- create and save labelmap ------------------------------------------------------------------------------------
+    # we do this before creating array files since we need to load the labelmap to make frame label vectors
+    if purpose != "predict":
+        # TODO: add option to generate predict using existing dataset, so we can get labelmap from it
+        map_unlabeled_segments = sequence_dataset.has_unlabeled_segments(
+            dataset_df
+        )
+        labelmap = labels.to_map(
+            labelset, map_unlabeled=map_unlabeled_segments
+        )
+        logger.info(
+            f"Number of classes in labelmap: {len(labelmap)}",
+        )
+        # save labelmap in case we need it later
+        with (dataset_path / "labelmap.json").open("w") as fp:
+            json.dump(labelmap, fp)
+    else:
+        labelmap = None
+
+    # ---- actually move/copy/create files into directories representing splits ----------------------------------------
+    # now we're *remaking* the dataset_df (actually adding additional rows with the splits)
+    input_type = "spect"  # we only make spectrogram datasets for now
+    dataset_df: pd.DataFrame = make_splits(
+        dataset_df,
+        dataset_path,
+        input_type,
+        purpose,
+        labelmap,
+        audio_format,
+        spect_key,
+        timebins_key,
+        freqbins_key,
+    )
+
+    # ---- if purpose is learncurve, additionally prep training data subsets for the learning curve --------------------
+    if purpose == "learncurve":
+        dataset_df: pd.DataFrame = make_subsets_from_dataset_df(
+            dataset_df,
+            input_type,
+            train_set_durs,
+            num_replicates,
+            
dataset_path, + labelmap, + ) + + # ---- save csv file that captures provenance of source data ------------------------------------------------------- + logger.info(f"Saving dataset csv file: {dataset_csv_path}") + dataset_df.to_csv( + dataset_csv_path, index=False + ) # index is False to avoid having "Unnamed: 0" column when loading + + return dataset_df diff --git a/src/vak/train/train_.py b/src/vak/train/train_.py index 79ee2897f..ecae272aa 100644 --- a/src/vak/train/train_.py +++ b/src/vak/train/train_.py @@ -8,6 +8,8 @@ from ..common import validators from .frame_classification import train_frame_classification_model from .parametric_umap import train_parametric_umap_model +from .vae import train_vae_model + logger = logging.getLogger(__name__) @@ -207,5 +209,25 @@ def train( device=device, subset=subset, ) + elif model_family == "VAEModel": + train_vae_model( + model_name=model_name, + model_config=model_config, + dataset_path=dataset_path, + batch_size=batch_size, + num_epochs=num_epochs, + num_workers=num_workers, + train_transform_params=train_transform_params, + train_dataset_params=train_dataset_params, + val_transform_params=val_transform_params, + val_dataset_params=val_dataset_params, + checkpoint_path=checkpoint_path, + results_path=results_path, + shuffle=shuffle, + val_step=val_step, + ckpt_step=ckpt_step, + device=device, + subset=subset, + ) else: raise ValueError(f"Model family not recognized: {model_family}") diff --git a/src/vak/train/vae.py b/src/vak/train/vae.py new file mode 100644 index 000000000..bfcf7b170 --- /dev/null +++ b/src/vak/train/vae.py @@ -0,0 +1,285 @@ +"""Function that trains models in the Variational Autoencoder family.""" +from __future__ import annotations + +import datetime +import logging +import pathlib + +import pandas as pd +import pytorch_lightning as lightning +import torch.utils.data + +from .. import datasets, models, transforms +from ..common import validators +from ..common.device import get_default as get_default_device +from ..datasets.vae import SegmentDataset, WindowDataset +from .frame_classification import get_split_dur + + +logger = logging.getLogger(__name__) + + +def get_trainer( + max_epochs: int, + ckpt_root: str | pathlib.Path, + ckpt_step: int, + log_save_dir: str | pathlib.Path, + device: str = "cuda", +) -> lightning.Trainer: + """Returns an instance of ``lightning.Trainer`` + with a default set of callbacks. 
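+    The default callbacks save a checkpoint every
+    ``ckpt_step`` training steps, plus a separate copy of
+    the checkpoint with the lowest validation loss.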
+ Used by ``vak.core`` functions.""" + # TODO: use accelerator parameter, https://github.com/vocalpy/vak/issues/691 + if device == "cuda": + accelerator = "gpu" + else: + accelerator = "auto" + + ckpt_callback = lightning.callbacks.ModelCheckpoint( + dirpath=ckpt_root, + filename="checkpoint", + every_n_train_steps=ckpt_step, + save_last=True, + verbose=True, + ) + ckpt_callback.CHECKPOINT_NAME_LAST = "checkpoint" + ckpt_callback.FILE_EXTENSION = ".pt" + + val_ckpt_callback = lightning.callbacks.ModelCheckpoint( + monitor="val_loss", + dirpath=ckpt_root, + save_top_k=1, + mode="min", + filename="min-val-loss-checkpoint", + auto_insert_metric_name=False, + verbose=True, + ) + val_ckpt_callback.FILE_EXTENSION = ".pt" + + callbacks = [ + ckpt_callback, + val_ckpt_callback, + ] + + logger = lightning.loggers.TensorBoardLogger(save_dir=log_save_dir) + + trainer = lightning.Trainer( + max_epochs=max_epochs, + accelerator=accelerator, + logger=logger, + callbacks=callbacks, + ) + return trainer + + +def train_vae_model( + model_name: str, + model_config: dict, + dataset_path: str | pathlib.Path, + batch_size: int, + num_epochs: int, + num_workers: int, + train_transform_params: dict | None = None, + train_dataset_params: dict | None = None, + val_transform_params: dict | None = None, + val_dataset_params: dict | None = None, + checkpoint_path: str | pathlib.Path | None = None, + spect_scaler_path: str | pathlib.Path | None = None, + results_path: str | pathlib.Path | None = None, + shuffle: bool = True, + val_step: int | None = None, + ckpt_step: int | None = None, + device: str | None = None, + subset: str | None = None, +) -> None: + """Train a model from the Variational Autoencoder family + and save results. + + Parameters + ---------- + model_name : str + Model name, must be one of vak.models.registry.MODEL_NAMES. + model_config : dict + Model configuration in a ``dict``, + as loaded from a .toml file, + and used by the model method ``from_config``. + dataset_path : str + Path to dataset, a directory generated by running ``vak prep``. + batch_size : int + number of samples per batch presented to models during training. + num_epochs : int + number of training epochs. One epoch = one iteration through the entire + training set. + num_workers : int + Number of processes to use for parallel loading of data. + Argument to torch.DataLoader. 
train_transform_params : dict, optional
+        Parameters for the transform applied to training data.
+    train_dataset_params : dict, optional
+        Parameters passed to the class that loads the training dataset.
+    val_transform_params : dict, optional
+        Parameters for the transform applied to validation data.
+    val_dataset_params : dict, optional
+        Parameters passed to the class that loads the validation dataset.
+    checkpoint_path : str, pathlib.Path, optional
+        Path to a checkpoint file,
+        e.g., one generated by a previous run of ``vak.train``.
+        If specified, this checkpoint is loaded into the model.
+        Used when continuing training.
+    spect_scaler_path : str, pathlib.Path, optional
+        Path to a saved ``SpectScaler`` used to normalize spectrograms,
+        e.g., one generated by a previous run of ``vak.train``.
+    results_path : str, pathlib.Path
+        Directory where results will be saved.
+    shuffle : bool
+        If True, shuffle the training data before each epoch. Default is True.
+    val_step : int, optional
+        If set, load the validation split and compute metrics on it
+        during training. Requires that the dataset contains a validation set.
+    ckpt_step : int, optional
+        Save a checkpoint every ``ckpt_step`` global steps.
+    device : str, optional
+        Device on which to run the model and load data.
+        Default is None, in which case a default device is chosen.
+    subset : str, optional
+        Name of a subset of the training split to use,
+        e.g., when preparing a learning curve.
+
+    Returns
+    -------
+    None
+    """
+    for path, path_name in zip(
+        (checkpoint_path, spect_scaler_path),
+        ("checkpoint_path", "spect_scaler_path"),
+    ):
+        if path is not None:
+            if not validators.is_a_file(path):
+                raise FileNotFoundError(
+                    f"value for ``{path_name}`` not recognized as a file: {path}"
+                )
+
+    dataset_path = pathlib.Path(dataset_path)
+    if not dataset_path.exists() or not dataset_path.is_dir():
+        raise NotADirectoryError(
+            f"`dataset_path` not found or not recognized as a directory: {dataset_path}"
+        )
+
+    logger.info(
+        f"Loading dataset from path: {dataset_path}",
+    )
+    metadata = datasets.vae.Metadata.from_dataset_path(
+        dataset_path
+    )
+    dataset_csv_path = dataset_path / metadata.dataset_csv_filename
+    dataset_df = pd.read_csv(dataset_csv_path)
+    # ---------------- pre-conditions ----------------------------------------------------------------------------------
+    if val_step and not dataset_df["split"].str.contains("val").any():
+        raise ValueError(
+            f"val_step set to {val_step} but dataset does not contain a validation set; "
+            f"please run `vak prep` with a config.toml file that specifies a duration for the validation set."
+        )
+
+    # ---- set up directory to save output -----------------------------------------------------------------------------
+    results_path = pathlib.Path(results_path).expanduser().resolve()
+    if not results_path.is_dir():
+        raise NotADirectoryError(
+            f"results_path not recognized as a directory: {results_path}"
+        )
+
+    # ---------------- load training data -----------------------------------------------------------------------------
+    logger.info(f"Using training split from dataset: {dataset_path}")
+    train_dur = get_split_dur(dataset_df, "train")
+    logger.info(
+        f"Total duration of training split from dataset (in s): {train_dur}",
+    )
+
+    if train_transform_params is None:
+        train_transform_params = {}
+    transform = transforms.defaults.get_default_transform(
+        model_name, "train", train_transform_params
+    )
+
+    if train_dataset_params is None:
+        train_dataset_params = {}
+    if metadata.dataset_type == 'vae-segment':
+        train_dataset = SegmentDataset.from_dataset_path(
+            dataset_path=dataset_path,
+            split="train",
+            subset=subset,
+            transform=transform,
+            **train_dataset_params,
+        )
+    elif metadata.dataset_type == 'vae-window':
+        train_dataset = WindowDataset.from_dataset_path(
+            dataset_path=dataset_path,
+            split="train",
+            subset=subset,
+            transform=transform,
+            **train_dataset_params,
+        )
+
+    train_loader = torch.utils.data.DataLoader(
+        dataset=train_dataset,
+        shuffle=shuffle,
+        batch_size=batch_size,
+        num_workers=num_workers,
+    )
+
+    # ---------------- load validation set (if there is one) -----------------------------------------------------------
+    if val_step:
+        if val_transform_params is None:
+            val_transform_params = {}
+        transform = transforms.defaults.get_default_transform(
+            model_name, "eval", val_transform_params
+        )
+        if val_dataset_params is None:
+            val_dataset_params = {}
+        if metadata.dataset_type == 'vae-segment':
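+            # as with the training split, pick the dataset class that
+            # matches the dataset type recorded in the dataset's metadata
+            val_dataset = 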
SegmentDataset.from_dataset_path(
+                dataset_path=dataset_path,
+                split="val",
+                transform=transform,
+                **val_dataset_params,
+            )
+        elif metadata.dataset_type == 'vae-window':
+            val_dataset = WindowDataset.from_dataset_path(
+                dataset_path=dataset_path,
+                split="val",
+                transform=transform,
+                **val_dataset_params,
+            )
+        logger.info(
+            f"Duration of dataset used for validation, in seconds: {val_dataset.duration}",
+        )
+        val_loader = torch.utils.data.DataLoader(
+            dataset=val_dataset,
+            shuffle=False,
+            batch_size=batch_size,
+            num_workers=num_workers,
+        )
+    else:
+        # no validation during training; ``lightning`` accepts None for val_dataloaders
+        val_loader = None
+
+    if device is None:
+        device = get_default_device()
+
+    model = models.get(
+        model_name,
+        config=model_config,
+        input_shape=train_dataset.shape,
+    )
+
+    if checkpoint_path is not None:
+        logger.info(
+            f"loading checkpoint for {model_name} from path: {checkpoint_path}",
+        )
+        model.load_state_dict_from_path(checkpoint_path)
+
+    results_model_root = results_path.joinpath(model_name)
+    results_model_root.mkdir()
+    ckpt_root = results_model_root.joinpath("checkpoints")
+    ckpt_root.mkdir(exist_ok=True)
+    logger.info(f"Training model: {model_name}")
+    trainer = get_trainer(
+        max_epochs=num_epochs,
+        log_save_dir=results_model_root,
+        device=device,
+        ckpt_root=ckpt_root,
+        ckpt_step=ckpt_step,
+    )
+    train_time_start = datetime.datetime.now()
+    logger.info(f"Training start time: {train_time_start.isoformat()}")
+    trainer.fit(
+        model=model,
+        train_dataloaders=train_loader,
+        val_dataloaders=val_loader,
+    )
+    train_time_stop = datetime.datetime.now()
+    logger.info(f"Training stop time: {train_time_stop.isoformat()}")
+    elapsed = train_time_stop - train_time_start
+    logger.info(f"Elapsed training time: {elapsed}")
diff --git a/src/vak/transforms/defaults/get.py b/src/vak/transforms/defaults/get.py
index 0851d515c..e7a6e2abd 100644
--- a/src/vak/transforms/defaults/get.py
+++ b/src/vak/transforms/defaults/get.py
@@ -2,7 +2,7 @@
 from __future__ import annotations
 
 from ... import models
-from . import frame_classification, parametric_umap
+from . import frame_classification, parametric_umap, vae
 
 
 def get_default_transform(
@@ -44,3 +44,8 @@ def get_default_transform(
         return parametric_umap.get_default_parametric_umap_transform(
             transform_kwargs
         )
+
+    elif model_family == "VAEModel":
+        return vae.get_default_vae_transform(
+            transform_kwargs
+        )
diff --git a/src/vak/transforms/defaults/parametric_umap.py b/src/vak/transforms/defaults/parametric_umap.py
index 83c568b06..be4b51864 100644
--- a/src/vak/transforms/defaults/parametric_umap.py
+++ b/src/vak/transforms/defaults/parametric_umap.py
@@ -9,7 +9,7 @@
 def get_default_parametric_umap_transform(
     transform_kwargs,
 ) -> torchvision.transforms.Compose:
-    """Get default transform for frame classification model.
+    """Get default transform for Parametric UMAP model.
 
     Parameters
     ----------
diff --git a/src/vak/transforms/defaults/vae.py b/src/vak/transforms/defaults/vae.py
new file mode 100644
index 000000000..53126c33f
--- /dev/null
+++ b/src/vak/transforms/defaults/vae.py
@@ -0,0 +1,26 @@
+"""Default transforms for VAE models."""
+from __future__ import annotations
+
+import torchvision.transforms
+
+from .. import transforms as vak_transforms
+
+
+def get_default_vae_transform(
+    transform_kwargs,
+) -> torchvision.transforms.Compose:
+    """Get default transform for VAE model.
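+
+    The default is minimal: convert the spectrogram to a
+    ``float`` tensor, then add a channel dimension.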
+
+    Parameters
+    ----------
+    transform_kwargs : dict
+        Keyword arguments for transforms.
+        Currently unused by this function; accepted for consistency
+        with the other ``get_default_*_transform`` functions.
+
+    Returns
+    -------
+    transform : torchvision.transforms.Compose
+        That converts a spectrogram to a ``float`` tensor
+        and adds a channel dimension.
+    """
+    transforms = [
+        vak_transforms.ToFloatTensor(),
+        vak_transforms.AddChannel(),
+    ]
+    return torchvision.transforms.Compose(transforms)
diff --git a/tests/data_for_tests/configs/AVA_segment_vae_learncurve_audio_cbin_annot_notmat.toml b/tests/data_for_tests/configs/AVA_segment_vae_learncurve_audio_cbin_annot_notmat.toml
new file mode 100644
index 000000000..ac6d86a77
--- /dev/null
+++ b/tests/data_for_tests/configs/AVA_segment_vae_learncurve_audio_cbin_annot_notmat.toml
@@ -0,0 +1,38 @@
+[PREP]
+dataset_type = "vae-segment"
+input_type = "spect"
+data_dir = "./tests/data_for_tests/source/audio_cbin_annot_notmat/gy6or6/032312"
+output_dir = "./tests/data_for_tests/generated/prep/train/audio_cbin_annot_notmat/AVA_segment_vae"
+audio_format = "cbin"
+annot_format = "notmat"
+labelset = "iabcdefghjk"
+train_dur = 20
+val_dur = 5
+test_dur = 10
+context_s = 0.01
+max_dur = 0.2
+target_shape = [ 128, 128,]
+train_set_durs = [ 4, 6,]
+num_replicates = 2
+
+[SPECT_PARAMS]
+fft_size = 512
+step_size = 256
+transform_type = "log_spect"
+freq_cutoffs = [ 400, 10000,]
+normalize = false
+min_val = -6.0
+max_val = 0.0
+
+[LEARNCURVE]
+model = "AVA"
+batch_size = 64
+num_epochs = 150
+val_step = 500
+ckpt_step = 1000
+num_workers = 16
+device = "cuda"
+root_results_dir = "./tests/data_for_tests/generated/results/train/audio_cbin_annot_notmat/AVA"
+
+[AVA.optimizer]
+lr = 0.001
diff --git a/tests/data_for_tests/configs/AVA_segment_vae_train_audio_cbin_annot_notmat.toml b/tests/data_for_tests/configs/AVA_segment_vae_train_audio_cbin_annot_notmat.toml
new file mode 100644
index 000000000..52414440b
--- /dev/null
+++ b/tests/data_for_tests/configs/AVA_segment_vae_train_audio_cbin_annot_notmat.toml
@@ -0,0 +1,37 @@
+[PREP]
+dataset_type = "vae-segment"
+input_type = "spect"
+data_dir = "./tests/data_for_tests/source/audio_cbin_annot_notmat/gy6or6/032312"
+output_dir = "./tests/data_for_tests/generated/prep/train/audio_cbin_annot_notmat/AVA_segment_vae"
+audio_format = "cbin"
+annot_format = "notmat"
+labelset = "iabcdefghjk"
+train_dur = 20
+val_dur = 5
+test_dur = 10
+context_s = 0.01
+max_dur = 0.2
+target_shape = [ 128, 128,]
+
+[SPECT_PARAMS]
+fft_size = 512
+step_size = 256
+transform_type = "log_spect"
+freq_cutoffs = [ 400, 10000,]
+normalize = false
+min_val = -6.0
+max_val = 0.0
+
+[TRAIN]
+model = "AVA"
+batch_size = 64
+num_epochs = 150
+val_step = 500
+ckpt_step = 1000
+num_workers = 16
+device = "cuda"
+root_results_dir = "./tests/data_for_tests/generated/results/train/audio_cbin_annot_notmat/AVA"
+dataset_path = "tests/data_for_tests/generated/prep/train/audio_cbin_annot_notmat/AVA_segment_vae/032312-vak-vae-dataset-generated-240111_222325"
+
+[AVA.optimizer]
+lr = 0.001
diff --git a/tests/data_for_tests/configs/AVA_window_vae_train_audio_cbin_annot_notmat.toml b/tests/data_for_tests/configs/AVA_window_vae_train_audio_cbin_annot_notmat.toml
new file mode 100644
index 000000000..7b308915f
--- /dev/null
+++ b/tests/data_for_tests/configs/AVA_window_vae_train_audio_cbin_annot_notmat.toml
@@ -0,0 +1,39 @@
+[PREP]
+dataset_type = "vae-window"
+input_type = "spect"
+data_dir = "./tests/data_for_tests/source/audio_cbin_annot_notmat/gy6or6/032312"
+output_dir = "./tests/data_for_tests/generated/prep/train/audio_cbin_annot_notmat/AVA_segment_vae"
+audio_format = "cbin"
+annot_format = "notmat"
+labelset = "iabcdefghjk"
+train_dur = 0.5
+val_dur = 0.2
+test_dur = 0.25
+
+[SPECT_PARAMS]
+fft_size = 
512
+step_size = 32
+transform_type = "log_spect_plus_one"
+
+[TRAIN]
+model = "AVA"
+batch_size = 64
+num_epochs = 1
+val_step = 1
+ckpt_step = 1000
+num_workers = 16
+device = "cuda"
+root_results_dir = "./tests/data_for_tests/generated/results/train/audio_cbin_annot_notmat/AVA"
+dataset_path = "tests/data_for_tests/generated/prep/train/audio_cbin_annot_notmat/AVA_segment_vae/032312-vak-vae-dataset-generated-240103_115504"
+
+#[AVA.network]
+#conv1_filters = 8
+#conv2_filters = 16
+#conv_kernel_size = 3
+#conv_stride = 2
+#conv_padding = 1
+#n_features_linear = 32
+#n_components = 2
+
+[AVA.optimizer]
+lr = 0.001
diff --git a/tests/test_datasets/test_vae/__init__.py b/tests/test_datasets/test_vae/__init__.py
new file mode 100644
index 000000000..e69de29bb
diff --git a/tests/test_datasets/test_vae/test_vae.py b/tests/test_datasets/test_vae/test_vae.py
new file mode 100644
index 000000000..751ce982a
--- /dev/null
+++ b/tests/test_datasets/test_vae/test_vae.py
@@ -0,0 +1,4 @@
+class TestVAEDataset:
+    def test_init(self):
+        # TODO: write these tests
+        assert False
diff --git a/tests/test_models/test_vae.py b/tests/test_models/test_vae.py
new file mode 100644
index 000000000..e56cd1fc9
--- /dev/null
+++ b/tests/test_models/test_vae.py
@@ -0,0 +1,38 @@
+import pytest
+
+import vak
+
+
+class TestConvEncoderUMAP:
+    @pytest.mark.parametrize(
+        'input_shape',
+        [
+            (1, 32, 32),
+            (1, 64, 64),
+        ]
+    )
+    def test_init(self, input_shape):
+        # TODO: actually write this test for a VAE model;
+        # the body below is copied from the ConvEncoderUMAP tests
+        # and is unreachable until this placeholder is removed
+        assert False
+        network = {
+            'encoder': vak.models.ConvEncoderUMAP.definition.network['encoder'](input_shape=input_shape)
+        }
+        model = vak.models.ConvEncoderUMAP(network=network)
+        assert isinstance(model, vak.models.ConvEncoderUMAP)
+        for attr in ('network', 'loss', 'optimizer'):
+            assert hasattr(model, attr)
+            attr_from_definition = getattr(vak.models.convencoder_umap.ConvEncoderUMAP.definition, attr)
+            if isinstance(attr_from_definition, dict):
+                attr_from_model = getattr(model, attr)
+                assert isinstance(attr_from_model, dict)
+                assert attr_from_model.keys() == attr_from_definition.keys()
+                for net_name, net_instance in attr_from_model.items():
+                    assert isinstance(net_instance, attr_from_definition[net_name])
+            else:
+                assert isinstance(getattr(model, attr),
+                                  getattr(vak.models.convencoder_umap.ConvEncoderUMAP.definition, attr))
+        assert hasattr(model, 'metrics')
+        assert isinstance(model.metrics, dict)
+        for metric_name, metric_callable in model.metrics.items():
+            assert isinstance(metric_callable,
+                              vak.models.convencoder_umap.ConvEncoderUMAP.definition.metrics[metric_name])
diff --git a/tests/test_nets/test_ava.py b/tests/test_nets/test_ava.py
new file mode 100644
index 000000000..fcf10bc41
--- /dev/null
+++ b/tests/test_nets/test_ava.py
@@ -0,0 +1,61 @@
+import torch
+import pytest
+
+import vak.nets
+
+
+class TestAVA:
+
+    @pytest.mark.parametrize(
+        'input_shape',
+        [
+            (
+                1, 128, 128,
+            ),
+            (
+                1, 256, 256,
+            ),
+        ]
+    )
+    def test_init(self, input_shape):
+        """test we can instantiate AVA
+        and it has the expected attributes"""
+        net = vak.nets.AVA(input_shape)
+        assert isinstance(net, vak.nets.AVA)
+        for expected_attr, expected_type in (
+            ('input_shape', tuple),
+            ('in_channels', int),
+            ('x_shape', tuple),
+            ('x_dim', int),
+            ('encoder', torch.nn.Module),
+            ('shared_encoder_fc', torch.nn.Module),
+            ('mu_fc', torch.nn.Module),
+            ('cov_factor_fc', torch.nn.Module),
+            ('cov_diag_fc', torch.nn.Module),
+            ('decoder_fc', torch.nn.Module),
+            ('decoder', torch.nn.Module),
+        ):
+            assert hasattr(net, expected_attr)
+            assert 
isinstance(getattr(net, expected_attr), expected_type)
+
+        assert net.input_shape == input_shape
+
+    @pytest.mark.parametrize(
+        'input_shape, batch_size',
+        [
+            ((1, 128, 128,), 32),
+            ((1, 256, 256,), 64),
+        ]
+    )
+    def test_forward(self, input_shape, batch_size):
+        """test we can forward a tensor through an AVA instance
+        and get the expected output"""
+
+        input = torch.rand(batch_size, *input_shape)  # a "batch"
+        net = vak.nets.AVA(input_shape)
+        out = net(input)
+        assert len(out) == 3
+        x_rec, z, latent_dist = out
+        for tensor in (x_rec, z):
+            assert isinstance(tensor, torch.Tensor)
+        assert isinstance(latent_dist, torch.distributions.LowRankMultivariateNormal)
diff --git a/tests/test_nets/test_convencoder.py b/tests/test_nets/test_convencoder.py
index eaa3b6f6d..f89646139 100644
--- a/tests/test_nets/test_convencoder.py
+++ b/tests/test_nets/test_convencoder.py
@@ -1,5 +1,3 @@
-import inspect
-
 import torch
 import pytest
 
@@ -50,4 +48,3 @@ def test_forward(self, input_shape, batch_size):
         net = vak.nets.ConvEncoder(input_shape)
         out = net(input)
         assert isinstance(out, torch.Tensor)
-
diff --git a/tests/test_prep/test_segment_dataset/__init__.py b/tests/test_prep/test_segment_dataset/__init__.py
new file mode 100644
index 000000000..e69de29bb
diff --git a/tests/test_prep/test_segment_dataset/test_learncurve.py b/tests/test_prep/test_segment_dataset/test_learncurve.py
new file mode 100644
index 000000000..10d571190
--- /dev/null
+++ b/tests/test_prep/test_segment_dataset/test_learncurve.py
@@ -0,0 +1,2 @@
+def test_make_subsets_from_dataset_df():
+    assert False
diff --git a/tests/test_prep/test_segment_dataset/test_make_splits.py b/tests/test_prep/test_segment_dataset/test_make_splits.py
new file mode 100644
index 000000000..8bd452381
--- /dev/null
+++ b/tests/test_prep/test_segment_dataset/test_make_splits.py
@@ -0,0 +1,2 @@
+def test_make_splits():
+    assert False
diff --git a/tests/test_prep/test_segment_dataset/test_segment_dataset.py b/tests/test_prep/test_segment_dataset/test_segment_dataset.py
new file mode 100644
index 000000000..3f681ad63
--- /dev/null
+++ b/tests/test_prep/test_segment_dataset/test_segment_dataset.py
@@ -0,0 +1,41 @@
+class TestSegment:
+    def test_init(self):
+        assert False
+
+
+def test_get_segment_list():
+    assert False
+
+
+def test_spectrogram_from_segment():
+    # TODO: mock calling spectrogram
+    assert False
+
+
+class TestSpectToSave:
+    def test_init(self):
+        assert False
+
+
+def test_save_spect():
+    assert False
+
+
+def test_abspath():
+    assert False
+
+
+def test_make_spect_return_record():
+    assert False
+
+
+def test_pad_spectrogram():
+    assert False
+
+
+def test_interp_spectrogram():
+    assert False
+
+
+def test_prep_segment_dataset():
+    assert False
diff --git a/train-ava.ipynb b/train-ava.ipynb
new file mode 100644
index 000000000..abe21ef16
--- /dev/null
+++ b/train-ava.ipynb
@@ -0,0 +1,365 @@
+{
+ "cells": [
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "c88ee0dc-7579-40af-9c5d-68bef9b30c49",
+   "metadata": {},
+   "outputs": [
+    {
+     "name": "stderr",
+     "output_type": "stream",
+     "text": [
+      "/home/pimienta/Documents/repos/coding/vocalpy/vak-vocalpy/.venv/lib/python3.10/site-packages/tqdm/auto.py:21: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. 
See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n", + " from .autonotebook import tqdm as notebook_tqdm\n" + ] + } + ], + "source": [ + "from __future__ import annotations\n", + "\n", + "import datetime\n", + "import logging\n", + "import pathlib\n", + "\n", + "import numpy as np\n", + "import pandas as pd\n", + "import pytorch_lightning as lightning\n", + "import torch.utils.data\n", + "\n", + "from src import vak" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "190f11c3-c115-408e-92de-3d87cd421748", + "metadata": {}, + "outputs": [], + "source": [ + "def get_split_dur(df: pd.DataFrame, split: str) -> float:\n", + " \"\"\"Get duration of a split in a dataset from a pandas DataFrame representing the dataset.\"\"\"\n", + " return df[df[\"split\"] == split][\"duration\"].sum()" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "166b6c26-ea53-46cc-82df-0f3057c801b8", + "metadata": {}, + "outputs": [], + "source": [ + "def get_trainer(\n", + " max_epochs: int,\n", + " ckpt_root: str | pathlib.Path,\n", + " ckpt_step: int,\n", + " log_save_dir: str | pathlib.Path,\n", + " device: str = \"cuda\",\n", + ") -> lightning.Trainer:\n", + " \"\"\"Returns an instance of ``lightning.Trainer``\n", + " with a default set of callbacks.\n", + " Used by ``vak.core`` functions.\"\"\"\n", + " # TODO: use accelerator parameter, https://github.com/vocalpy/vak/issues/691\n", + " if device == \"cuda\":\n", + " accelerator = \"gpu\"\n", + " else:\n", + " accelerator = \"auto\"\n", + "\n", + " ckpt_callback = lightning.callbacks.ModelCheckpoint(\n", + " dirpath=ckpt_root,\n", + " filename=\"checkpoint\",\n", + " every_n_train_steps=ckpt_step,\n", + " save_last=True,\n", + " verbose=True,\n", + " )\n", + " ckpt_callback.CHECKPOINT_NAME_LAST = \"checkpoint\"\n", + " ckpt_callback.FILE_EXTENSION = \".pt\"\n", + "\n", + " val_ckpt_callback = lightning.callbacks.ModelCheckpoint(\n", + " monitor=\"val_loss\",\n", + " dirpath=ckpt_root,\n", + " save_top_k=1,\n", + " mode=\"min\",\n", + " filename=\"min-val-loss-checkpoint\",\n", + " auto_insert_metric_name=False,\n", + " verbose=True,\n", + " )\n", + " val_ckpt_callback.FILE_EXTENSION = \".pt\"\n", + "\n", + " callbacks = [\n", + " ckpt_callback,\n", + " val_ckpt_callback,\n", + " ]\n", + "\n", + " logger = lightning.loggers.TensorBoardLogger(save_dir=log_save_dir)\n", + "\n", + " trainer = lightning.Trainer(\n", + " max_epochs=max_epochs,\n", + " accelerator=accelerator,\n", + " logger=logger,\n", + " callbacks=callbacks,\n", + " )\n", + " return trainer" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "fcf49554-27a4-4baf-b5ae-a2fb9288f4e4", + "metadata": {}, + "outputs": [], + "source": [ + "class SpectrogramPipe(torch.utils.data.Dataset):\n", + " \"\"\"Pipeline for loading samples from a dataset of spectrograms\n", + " \n", + " This is a simplified version of ``vak.datasets.parametric_umap.ParametricUmapInferenceDataset``.\n", + " \"\"\"\n", + " def __init__(\n", + " self,\n", + " data: npt.NDArray,\n", + " dataset_df: pd.DataFrame,\n", + " transform: Callable | None = None,\n", + " ):\n", + " self.data = data\n", + " self.dataset_df = dataset_df\n", + " self.transform = transform\n", + "\n", + " @property\n", + " def duration(self):\n", + " return self.dataset_df[\"duration\"].sum()\n", + "\n", + " def __len__(self):\n", + " return self.data.shape[0]\n", + "\n", + " @property\n", + " def shape(self):\n", + " tmp_x_ind = 0\n", + " tmp_item = self.__getitem__(tmp_x_ind)\n", + " return 
tmp_item[\"x\"].shape\n", + "\n", + " def __getitem__(self, index):\n", + " x = self.data[index]\n", + " df_index = self.dataset_df.index[index]\n", + " if self.transform:\n", + " x = self.transform(x)\n", + " return {\"x\": x, \"df_index\": df_index}\n", + "\n", + " @classmethod\n", + " def from_dataset_path(\n", + " cls,\n", + " dataset_path: str | pathlib.Path,\n", + " split: str,\n", + " transform: Callable | None = None,\n", + " ):\n", + " import vak.datasets # import here just to make classmethod more explicit\n", + "\n", + " dataset_path = pathlib.Path(dataset_path)\n", + " metadata = vak.datasets.parametric_umap.Metadata.from_dataset_path(\n", + " dataset_path\n", + " )\n", + "\n", + " dataset_csv_path = dataset_path / metadata.dataset_csv_filename\n", + " dataset_df = pd.read_csv(dataset_csv_path)\n", + " split_df = dataset_df[dataset_df.split == split]\n", + "\n", + " data = np.stack(\n", + " [\n", + " np.load(dataset_path / spect_path)\n", + " for spect_path in split_df.spect_path.values\n", + " ]\n", + " )\n", + " return cls(\n", + " data,\n", + " split_df,\n", + " transform=transform,\n", + " )" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "f1c9c5bb-0b94-4649-80f7-db3185e9b480", + "metadata": {}, + "outputs": [], + "source": [ + "dataset_path = pathlib.Path(\n", + " './tests/data_for_tests/generated/prep/train/audio_cbin_annot_notmat/ConvEncoderUMAP/032312-vak-dimensionality-reduction-dataset-generated-231010_165846/'\n", + ")\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "7b1a3d7b-dd9c-4977-9de1-bc31049a06ca", + "metadata": {}, + "outputs": [], + "source": [ + "metadata = vak.datasets.parametric_umap.Metadata.from_dataset_path(\n", + " dataset_path\n", + ")\n", + "dataset_csv_path = dataset_path / metadata.dataset_csv_filename\n", + "dataset_df = pd.read_csv(dataset_csv_path)\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "565cdc37-738f-42a1-b65d-4348fd567d99", + "metadata": {}, + "outputs": [], + "source": [ + "val_step = 2000" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "3d4a4216-ebba-40ea-ab26-8a5d88211c4d", + "metadata": {}, + "outputs": [], + "source": [ + "results_path = pathlib.Path(\n", + " './tests/data_for_tests/generated/results/train/audio_cbin_annot_notmat/AVA'\n", + ")\n", + "results_path.mkdir(exist_ok=True)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "0f0ac2fb-fa28-4645-8397-50f5cc1c6bb8", + "metadata": {}, + "outputs": [], + "source": [ + "# ---------------- load training data -----------------------------------------------------------------------------\n", + "\n", + "# below, if we're going to train network to predict unlabeled segments, then\n", + "# we need to include a class for those unlabeled segments in labelmap,\n", + "# the mapping from labelset provided by user to a set of consecutive\n", + "# integers that the network learns to predict\n", + "train_dur = get_split_dur(dataset_df, \"train\")\n", + "print(\n", + " f\"Total duration of training split from dataset (in s): {train_dur}\",\n", + ")\n", + "\n", + "\n", + "train_transform_params = {}\n", + "transform = vak.transforms.defaults.get_default_transform(\n", + " \"ConvEncoderUMAP\", \"train\", train_transform_params\n", + ")\n", + "\n", + "\n", + "train_dataset_params = {}\n", + "train_dataset = SpectrogramPipe.from_dataset_path(\n", + " dataset_path=dataset_path,\n", + " split=\"train\",\n", + " transform=transform,\n", + " **train_dataset_params,\n", + ")\n", + 
"\n", + "train_loader = torch.utils.data.DataLoader(\n", + " dataset=train_dataset,\n", + " shuffle=True,\n", + " batch_size=64,\n", + " num_workers=16,\n", + ")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "925e06c9-91f8-4614-b56d-56b07e12f564", + "metadata": {}, + "outputs": [], + "source": [ + "# ---------------- load validation set (if there is one) -----------------------------------------------------------\n", + "\n", + "\n", + "val_transform_params = {}\n", + "transform = vak.transforms.defaults.get_default_transform(\n", + " \"ConvEncoderUMAP\", \"eval\", val_transform_params\n", + ")\n", + "val_dataset_params = {}\n", + "val_dataset = SpectrogramPipe.from_dataset_path(\n", + " dataset_path=dataset_path,\n", + " split=\"val\",\n", + " transform=transform,\n", + " **val_dataset_params,\n", + ")\n", + "print(\n", + " f\"Duration of ParametricUMAPDataset used for validation, in seconds: {val_dataset.duration}\",\n", + ")\n", + "val_loader = torch.utils.data.DataLoader(\n", + " dataset=val_dataset,\n", + " shuffle=False,\n", + " batch_size=64,\n", + " num_workers=16,\n", + ")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "6cb39ce3-72db-4b53-b4b9-a08b449b37d4", + "metadata": {}, + "outputs": [], + "source": [ + "device = vak.common.device.get_default()\n", + "\n", + "model = vak.models.get(\n", + " \"AVA\",\n", + " config={\"network\": {}, \"optimizer\": {\"lr\": 0.001}},\n", + " input_shape=train_dataset.shape,\n", + ")\n", + "\n", + "results_model_root = results_path.joinpath(\"AVA\")\n", + "results_model_root.mkdir(exist_ok=True)\n", + "ckpt_root = results_model_root.joinpath(\"checkpoints\")\n", + "ckpt_root.mkdir(exist_ok=True)\n", + "\n", + "trainer = get_trainer(\n", + " max_epochs=50,\n", + " log_save_dir=results_model_root,\n", + " device=device,\n", + " ckpt_root=ckpt_root,\n", + " ckpt_step=250,\n", + ")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "fe993c8b-4571-4892-b645-91d778df2fb6", + "metadata": {}, + "outputs": [], + "source": [ + "trainer.fit(\n", + " model=model,\n", + " train_dataloaders=train_loader,\n", + " val_dataloaders=val_loader,\n", + ")" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3 (ipykernel)", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.10.12" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +}