Skip to content

Commit 5ce41ae

Browse files
authored
feat(optimize): add option to pass keyword arguments to pipeline during optimization (#58)
1 parent 897bdc9 commit 5ce41ae

File tree

2 files changed

+12
-2
lines changed

2 files changed

+12
-2
lines changed

doc/source/changelog.rst

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@ Changelog
55
Version 4.0.0rc1 (2025-02-11)
66
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
77

8+
- feat(optimize): add option to pass keyword arguments to pipeline during optimization
89
- BREAKING: drop support for `Python` < 3.10
910
- BREAKING: switch to native namespace package
1011
- BREAKING: remove `pyannote.pipeline.blocks` submodule

src/pyannote/pipeline/optimizer.py

Lines changed: 11 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -30,7 +30,7 @@
3030
import time
3131
import warnings
3232
from pathlib import Path
33-
from typing import Iterable, Optional, Callable, Generator, Union, Dict
33+
from typing import Iterable, Optional, Callable, Generator, Mapping, Union, Dict
3434

3535
import numpy as np
3636
import optuna.logging
@@ -231,7 +231,16 @@ def objective(trial: Trial) -> float:
231231
# process input with pipeline
232232
# (and keep track of processing time)
233233
before_processing = time.time()
234-
output = pipeline(input)
234+
235+
# get optional kwargs to be passed to the pipeline
236+
# (e.g. num_speakers for speaker diarization). they
237+
# must be stored in a 'pipeline_kwargs' key in the
238+
# `input` dictionary.
239+
if isinstance(input, Mapping):
240+
pipeline_kwargs = input.get("pipeline_kwargs", {})
241+
else:
242+
pipeline_kwargs = {}
243+
output = pipeline(input, **pipeline_kwargs)
235244
after_processing = time.time()
236245
processing_time.append(after_processing - before_processing)
237246

0 commit comments

Comments
 (0)