diff --git a/pyannote/metrics/base.py b/pyannote/metrics/base.py
index 5d1dba7..49c22c1 100755
--- a/pyannote/metrics/base.py
+++ b/pyannote/metrics/base.py
@@ -25,6 +25,7 @@
 # AUTHORS
 # Hervé BREDIN - http://herve.niderb.fr
 
+import inspect
 from typing import List, Union, Optional, Set, Tuple
 import warnings
 
@@ -36,6 +37,49 @@ from pyannote.metrics.types import Details, MetricComponents
 
 
+def clone(metric):
+    """Construct a new empty metric with the same parameters as `metric`.
+
+    Clone does a deep copy of the metric without actually copying any
+    results or accumulators. It returns a new metric (with the same
+    parameters) that has not yet been used to evaluate any data.
+
+    Parameters
+    ----------
+    metric : metric instance
+        The metric to be cloned.
+
+    Returns
+    -------
+    metric : object
+        The deep copy of the metric.
+    """
+    if hasattr(metric, '__pyannote_clone__') and not inspect.isclass(metric):
+        # If metric implements a custom cloning method, default to that.
+        return metric.__pyannote_clone__()
+    cls = metric.__class__
+    new_metric_params = metric.get_params()
+    new_metric = cls(**new_metric_params)
+
+    # Sanity check that all parameters were saved by the constructor without
+    # modification.
+    actual_params = new_metric.get_params()
+    for pname in new_metric_params:
+        error_msg = (
+            f'Cannot clone metric {repr(metric)}, as the constructor either '
+            f'does not set or modifies parameter {pname}.')
+        expected_param = new_metric_params[pname]
+        try:
+            actual_param = actual_params[pname]
+        except KeyError:
+            # Constructor did not store this parameter at all.
+            raise RuntimeError(error_msg)
+        if expected_param != actual_param:
+            raise RuntimeError(error_msg)
+
+    return new_metric
+
+
 class BaseMetric:
     """
     :class:`BaseMetric` is the base class for most pyannote evaluation metrics.
@@ -66,6 +110,34 @@ def __init__(self, **kwargs):
         self.components_: Set[str] = set(self.__class__.metric_components())
         self.reset()
 
+    @classmethod
+    def _get_param_names(cls):
+        """Return names of all constructor parameters, collected over the MRO."""
+        pnames = set()
+        for cls_ in cls.__mro__:
+            init = cls_.__init__
+            init_signature = inspect.signature(init)
+            for p in init_signature.parameters.values():
+                if p.name == 'self':
+                    continue
+                if p.kind in {p.VAR_POSITIONAL, p.VAR_KEYWORD}:
+                    # Skip *args/**kwargs.
+                    continue
+                pnames.add(p.name)
+        return pnames
+
+    def get_params(self):
+        """Return parameters for this metric.
+
+        Returns
+        -------
+        params : dict
+            Mapping from parameter names to values.
+        """
+        return {pname: getattr(self, pname)
+                for pname in self._get_param_names()}
+
+
     def init_components(self):
         return {value: 0.0 for value in self.components_}
 
@@ -81,9 +153,6 @@ def name(self):
         """Metric name."""
         return self.metric_name()
 
-    # TODO: use joblib/locky to allow parallel processing?
-    # TODO: signature could be something like __call__(self, reference_iterator, hypothesis_iterator, ...)
-
     def __call__(self, reference: Union[Timeline, Annotation],
                  hypothesis: Union[Timeline, Annotation],
                  detailed: bool = False, uri: Optional[str] = None,
                  **kwargs):
@@ -219,6 +287,13 @@ def __str__(self):
             sparsify=False, float_format=lambda f: "{0:.2f}".format(f)
         )
 
+    def __repr__(self):
+        cls = self.__class__.__name__
+        params = self.get_params()
+        pnames = sorted(params)
+        signature = ', '.join(f'{pname}={params[pname]}' for pname in pnames)
+        return f'{cls}({signature})'
+
     def __abs__(self):
         """Compute metric value from accumulated components"""
         return self.compute_metric(self.accumulated_)
@@ -247,6 +322,21 @@ def __iter__(self):
 
         for uri, component in self.results_:
             yield uri, component
 
+    def __add__(self, other):
+        cls = self.__class__
+        # Preserve constructor parameters (cf. clone): a bare cls() would
+        # silently reset them on the summed metric.
+        result = cls(**self.get_params())
+        result.results_ = self.results_ + other.results_
+        for cname in self.components_:
+            result.accumulated_[cname] += self.accumulated_[cname]
+            result.accumulated_[cname] += other.accumulated_[cname]
+        return result
+
+    def __radd__(self, other):
+        if other == 0:
+            return self
+        else:
+            return self.__add__(other)
+
     def compute_components(self, reference: Union[Timeline, Annotation],
                            hypothesis: Union[Timeline, Annotation],
@@ -320,12 +410,12 @@ def confidence_interval(self, alpha: float = 0.9) \
 
         if len(values) == 0:
             raise ValueError("Please evaluate a bunch of files before computing confidence interval.")
-        
+
         elif len(values) == 1:
             warnings.warn("Cannot compute a reliable confidence interval out of just one file.")
             center = lower = upper = values[0]
             return center, (lower, upper)
-        
+
         else:
             return scipy.stats.bayes_mvs(values, alpha=alpha)[0]
diff --git a/tests/test_base.py b/tests/test_base.py
new file mode 100644
index 0000000..60768a4
--- /dev/null
+++ b/tests/test_base.py
@@ -0,0 +1,157 @@
+#!/usr/bin/env python
+# encoding: utf-8
+
+# The MIT License (MIT)
+
+# Copyright (c) 2020 CNRS
+
+# Permission is hereby granted, free of charge, to any person obtaining a copy
+# of this software and associated 
documentation files (the "Software"), to deal
+# in the Software without restriction, including without limitation the rights
+# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+# copies of the Software, and to permit persons to whom the Software is
+# furnished to do so, subject to the following conditions:
+
+# The above copyright notice and this permission notice shall be included in
+# all copies or substantial portions of the Software.
+
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+# SOFTWARE.
+
+# AUTHORS
+# Hervé BREDIN - http://herve.niderb.fr
+
+
+import pytest
+
+from pyannote.core import Annotation
+from pyannote.core import Segment
+from pyannote.core import Timeline
+from pyannote.metrics.base import clone, BaseMetric
+from pyannote.metrics.detection import DetectionAccuracy
+
+
+import numpy.testing as npt
+
+# rec1: ref and hyp agree on 4 of the 6 UEM seconds (speech [1,2], non-speech [3,6])
+#
+# Time       0  1  2  3  4  5  6
+# Reference  |-----|
+# Hypothesis    |-----|
+# UEM        |-----------------|
+
+# rec2: ref and hyp agree on 5 of the 7 UEM seconds (speech [3,4], non-speech [0,1]+[4,7])
+#
+# Time       0  1  2  3  4  5  6  7
+# Reference     |--------|
+# Hypothesis          |-----|
+# UEM        |--------------------|
+
+
+@pytest.fixture
+def reference():
+    reference = {}
+    reference['rec1'] = Annotation()
+    reference['rec1'][Segment(0, 2)] = 'A'
+    reference['rec2'] = Annotation()
+    reference['rec2'][Segment(1, 4)] = 'A'
+    return reference
+
+
+@pytest.fixture
+def hypothesis():
+    hypothesis = {}
+    hypothesis['rec1'] = Annotation()
+    hypothesis['rec1'][Segment(1, 3)] = 'A'
+    hypothesis['rec2'] = Annotation()
+    hypothesis['rec2'][Segment(3, 4)] = 'A'
+    return hypothesis
+
+
+@pytest.fixture
+def uem():
+    return {
+        'rec1': Timeline([Segment(0, 6)]),
+        'rec2': Timeline([Segment(0, 7)])}
+
+
+def test_summation(reference, hypothesis, uem):
+    # Expected detection accuracy: (4 + 5) correct seconds over (6 + 7) UEM seconds.
+    expected = 9 / 13
+
+    # __add__
+    m1 = DetectionAccuracy()
+    m1(reference['rec1'], hypothesis['rec1'], uem=uem['rec1'])
+    m2 = DetectionAccuracy()
+    m2(reference['rec2'], hypothesis['rec2'], uem=uem['rec2'])
+    npt.assert_almost_equal(abs(m1 + m2), expected, decimal=3)
+
+    # __radd__ (sum() starts from 0, exercising the `other == 0` branch).
+    m = sum([m1, m2])
+    npt.assert_almost_equal(abs(m), expected, decimal=3)
+
+
+class M1(BaseMetric):
+    def __init__(self, a=9, **kwargs):
+        super().__init__(**kwargs)
+        self.a = a
+
+    @classmethod
+    def metric_name(cls):
+        return 'M1'
+
+    @classmethod
+    def metric_components(cls):
+        return ['c1']
+
+    def compute_metric(self, foo):
+        return 1.
+
+
+class M2(M1):
+    def __init__(self, b=10, **kwargs):
+        super().__init__(**kwargs)
+        self.b = b
+
+    @classmethod
+    def metric_name(cls):
+        return 'M2'
+
+
+def test_get_params():
+    # Subclass of BaseMetric.
+    m = M1(a=100)
+    expected = {'a': 100}
+    actual = m.get_params()
+    assert actual == expected
+
+    # Subclass of subclass of BaseMetric (params gathered across the MRO).
+    m = M2(a=100, b=1000)
+    expected = {'a': 100, 'b': 1000}
+    actual = m.get_params()
+    assert actual == expected
+
+
+def test_clone():
+    # Tests that clone creates deep copy of "unfit" metric.
+    metric = M1(a=10)
+    metric_new = clone(metric)
+    assert metric is not metric_new
+    assert metric.get_params() == metric_new.get_params()
+
+    # Tests that clone doesn't copy anything beyond the parameters; e.g.,
+    # results_ or accumulated_.
+    metric = M1(a=10)
+    metric.accumulated_['c1'] = 999999
+    metric_new = clone(metric)
+    assert metric_new.accumulated_['c1'] == 0
+
+
+def test_repr():
+    m = M1(a=10)
+    assert repr(m) == 'M1(a=10)'