Skip to content

Improve label customization classes and remove WIP markers #2456

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 14 commits into
base: main
Choose a base branch
from
Open
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
214 changes: 142 additions & 72 deletions arviz/labels.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,77 +12,77 @@
"NoModelLabeller",
]


def mix_labellers(labellers, class_name="MixtureLabeller"):
"""Combine Labeller classes dynamically.

The Labeller class aims to split plot labeling in ArviZ into atomic tasks to maximize
extensibility, and the few classes provided are designed with small deviations
from the base class, in many cases only one method is modified by the child class.
It is to be expected then to want to use multiple classes "at once".

This functions helps combine classes dynamically.

For a general overview of ArviZ label customization, including
``mix_labellers``, see the :ref:`label_guide` page.
Allows dynamic creation of new labeller classes by combining multiple
subclasses of BaseLabeller. Used to customize labeling behavior by stacking
simple modifications.

Parameters
----------
labellers : iterable of types
Iterable of Labeller types to combine
Iterable of Labeller types to combine.
class_name : str, optional
The name of the generated class
The name of the generated class.

Returns
-------
type
Mixture class object. **It is not initialized**, and it should be
initialized before passing it to ArviZ functions.

Examples
--------
Combine the :class:`~arviz.labels.DimCoordLabeller` with the
:class:`~arviz.labels.MapLabeller` to generate labels in the style of the
``DimCoordLabeller`` but using the mappings defined by ``MapLabeller``.
Note that this works even though both modify the same methods because
``MapLabeller`` implements the mapping and then calls `super().method`.

.. jupyter-execute::

from arviz.labels import mix_labellers, DimCoordLabeller, MapLabeller
l1 = DimCoordLabeller()
sel = {"dim1": "a", "dim2": "top"}
print(f"Output of DimCoordLabeller alone > {l1.sel_to_str(sel, sel)}")
l2 = MapLabeller(dim_map={"dim1": "$d_1$", "dim2": r"$d_2$"})
print(f"Output of MapLabeller alone > {l2.sel_to_str(sel, sel)}")
l3 = mix_labellers(
(MapLabeller, DimCoordLabeller)
)(dim_map={"dim1": "$d_1$", "dim2": r"$d_2$"})
print(f"Output of mixture labeller > {l3.sel_to_str(sel, sel)}")

We can see how the mappings are taken into account as well as the dim+coord style. However,
he order in the ``labellers`` arg iterator is important! See for yourself:

.. jupyter-execute::

l4 = mix_labellers(
(DimCoordLabeller, MapLabeller)
)(dim_map={"dim1": "$d_1$", "dim2": r"$d_2$"})
print(f"Output of inverted mixture labeller > {l4.sel_to_str(sel, sel)}")
type
Dynamically created class combining provided labeller classes.

Notes
-----
The returned class is *not* initialized.
"""
return type(class_name, labellers, {})


class BaseLabeller:
"""WIP."""
"""
Base class for generating plot labels from xarray-like selections.

Provides methods for constructing human-readable labels for variables,
dimensions, coordinates, and model components when plotting data
stored in InferenceData objects.
"""

def dim_coord_to_str(self, dim, coord_val, coord_idx):
"""WIP."""
"""
Generate a label string for a single dimension/coordinate pair.

Parameters
----------
dim : str
The name of the dimension.
coord_val : any
The value of the coordinate along that dimension.
coord_idx : int
The index of the coordinate value in the dimension.

Returns
-------
str
A string representation of the coordinate value.
"""
return f"{coord_val}"

def sel_to_str(self, sel: dict, isel: dict):
"""WIP."""
"""
Convert selection dictionaries to a formatted string.

Parameters
----------
sel : dict
Dictionary of dimension name to coordinate value.
isel : dict
Dictionary of dimension name to index value.

Returns
-------
str
String representation of the selection.
"""
if sel:
return ", ".join(
[
Expand All @@ -93,23 +93,68 @@ def sel_to_str(self, sel: dict, isel: dict):
return ""

def var_name_to_str(self, var_name: Union[str, None]):
"""WIP."""
"""
Convert a variable name to a display label.

Parameters
----------
var_name : str or None
The name of the variable.

Returns
-------
str or None
The formatted variable name.
"""
return var_name

def var_pp_to_str(self, var_name, pp_var_name):
"""WIP."""
"""
Convert a pair of variable names (e.g., prior and posterior) to a combined label.

Parameters
----------
var_name : str
The name of the posterior variable.
pp_var_name : str
The name of the prior predictive variable.

Returns
-------
str
Combined label.
"""
var_name_str = self.var_name_to_str(var_name)
pp_var_name_str = self.var_name_to_str(pp_var_name)
if var_name_str == pp_var_name_str:
return f"{var_name_str}"
return f"{var_name_str} / {pp_var_name_str}"

def model_name_to_str(self, model_name):
"""WIP."""
"""
Convert a model name to a display label.

Parameters
----------
model_name : str
The name of the model.

Returns
-------
str
Display label for the model.
"""
return model_name

def make_label_vert(self, var_name: Union[str, None], sel: dict, isel: dict):
"""WIP."""
"""
Create a multiline (vertical) label for a variable and its selection.

Returns
-------
str
Label with variable name and selection.
"""
var_name_str = self.var_name_to_str(var_name)
sel_str = self.sel_to_str(sel, isel)
if not sel_str:
Expand All @@ -119,7 +164,14 @@ def make_label_vert(self, var_name: Union[str, None], sel: dict, isel: dict):
return f"{var_name_str}\n{sel_str}"

def make_label_flat(self, var_name: str, sel: dict, isel: dict):
"""WIP."""
"""
Create a flat (single-line) label with indexing format.

Returns
-------
str
Label in the format "var[dim:coord,...]".
"""
var_name_str = self.var_name_to_str(var_name)
sel_str = self.sel_to_str(sel, isel)
if not sel_str:
Expand All @@ -129,12 +181,26 @@ def make_label_flat(self, var_name: str, sel: dict, isel: dict):
return f"{var_name_str}[{sel_str}]"

def make_pp_label(self, var_name, pp_var_name, sel, isel):
"""WIP."""
"""
Create label for a prior-posterior pair.

Returns
-------
str
Multiline label showing both variable names and selection.
"""
names = self.var_pp_to_str(var_name, pp_var_name)
return self.make_label_vert(names, sel, isel)

def make_model_label(self, model_name, label):
"""WIP."""
"""
Create a model label combined with a component label.

Returns
-------
str
Combined model/component label.
"""
model_name_str = self.model_name_to_str(model_name)
if model_name_str is None:
return label
Expand All @@ -144,67 +210,71 @@ def make_model_label(self, model_name, label):


class DimCoordLabeller(BaseLabeller):
"""WIP."""
"""
Labeller that includes dimension names with coordinate values.
"""

def dim_coord_to_str(self, dim, coord_val, coord_idx):
"""WIP."""
return f"{dim}: {coord_val}"


class IdxLabeller(BaseLabeller):
"""WIP."""
"""
Labeller that uses only coordinate indices.
"""

def dim_coord_to_str(self, dim, coord_val, coord_idx):
"""WIP."""
return f"{coord_idx}"


class DimIdxLabeller(BaseLabeller):
"""WIP."""
"""
Labeller that combines dimension name with index.
"""

def dim_coord_to_str(self, dim, coord_val, coord_idx):
"""WIP."""
return f"{dim}#{coord_idx}"


class MapLabeller(BaseLabeller):
"""WIP."""
"""
Labeller that maps names and values using user-provided dictionaries.
"""

def __init__(self, var_name_map=None, dim_map=None, coord_map=None, model_name_map=None):
"""WIP."""
self.var_name_map = {} if var_name_map is None else var_name_map
self.dim_map = {} if dim_map is None else dim_map
self.coord_map = {} if coord_map is None else coord_map
self.model_name_map = {} if model_name_map is None else model_name_map

def dim_coord_to_str(self, dim, coord_val, coord_idx):
"""WIP."""
dim_str = self.dim_map.get(dim, dim)
coord_str = self.coord_map.get(dim, {}).get(coord_val, coord_val)
return super().dim_coord_to_str(dim_str, coord_str, coord_idx)

def var_name_to_str(self, var_name):
"""WIP."""
var_name_str = self.var_name_map.get(var_name, var_name)
return super().var_name_to_str(var_name_str)

def model_name_to_str(self, model_name):
"""WIP."""
model_name_str = self.var_name_map.get(model_name, model_name)
model_name_str = self.model_name_map.get(model_name, model_name)
return super().model_name_to_str(model_name_str)


class NoVarLabeller(BaseLabeller):
"""WIP."""
"""
Labeller that omits variable names.
"""

def var_name_to_str(self, var_name):
"""WIP."""
return None


class NoModelLabeller(BaseLabeller):
"""WIP."""
"""
Labeller that omits model labels entirely.
"""

def make_model_label(self, model_name, label):
"""WIP."""
return label

Loading