"""Miscellaneous types, type aliases, type guards, and other related utilities
used throughout the package."""
from __future__ import annotations
from collections.abc import Collection
from typing import TypeAlias
import numpy as np
from typing_extensions import Any, Literal, TypeIs, TypeVar
# Snippet used to generate and store the image artifacts:
# >>> def save_fig(fig, name):
# ... height = 400
# ... width = 800
# ... fig.update_layout(
# ... height=height,
# ... width=width,
# ... margin=dict(l=0, r=0, t=40, b=0),
# ... showlegend=False,
# ... )
# ... out = f"docs/_static/img/api/types/{name}.webp"
# ... print(f"Writing to: {out}")
# ... fig.write_image(
# ... out,
# ... format="webp",
# ... width=width,
# ... height=height,
# ... scale=2,
# ... engine="kaleido",
# ... )
# Unconstrained type variable standing for the element type of the generic
# ``CollectionL*`` aliases and of ``nest_shallow_collection``.
_T = TypeVar("_T")
# ========================================================
# --- Miscellaneous types
# ========================================================
Color: TypeAlias = str | tuple[float, float, float]
"""A color can be represented by a tuple of ``(r, g, b)`` values or any valid
CSS color string - including hex, rgb/a, hsl/a, hsv/a, and named CSS colors."""
ColorScale: TypeAlias = Collection[tuple[float, Color]]
"""The canonical form for a color scale is represented by a list of tuples of
two elements:
0. the first element (a *scale value*) is a float bounded to the
interval ``[0, 1]``
1. the second element should be a valid :data:`Color` representation.
For instance, the Viridis color scale can be represented as:
>>> viridis: ColorScale = [
(0.0, 'rgb(68, 1, 84)'),
(0.1111111111111111, 'rgb(72, 40, 120)'),
(0.2222222222222222, 'rgb(62, 73, 137)'),
(0.3333333333333333, 'rgb(49, 104, 142)'),
(0.4444444444444444, 'rgb(38, 130, 142)'),
(0.5555555555555556, 'rgb(31, 158, 137)'),
(0.6666666666666666, 'rgb(53, 183, 121)'),
(0.7777777777777777, 'rgb(110, 206, 88)'),
(0.8888888888888888, 'rgb(181, 222, 43)'),
(1.0, 'rgb(253, 231, 37)')
]
"""
NormalisationOption: TypeAlias = Literal["probability", "percent"]
"""A :data:`~typing.Literal` type that represents the normalisation options
available for the ridgeplot. See :paramref:`ridgeplot.ridgeplot.norm` for more
details."""
# ========================================================
# --- Base nested Collection types (ragged arrays)
# ========================================================
# All three aliases stay generic over ``_T`` so that they can be parametrised
# further down (e.g. ``CollectionL2[str]``); they differ only in nesting depth.
# ``Collection`` is used rather than ``list`` so that tuples, sets, etc. are
# accepted as well, and inner collections may have different lengths (ragged).
CollectionL1: TypeAlias = Collection[_T]
"""A :data:`~typing.TypeAlias` for a 1-level-deep :class:`~collections.abc.Collection`.
Example
-------
>>> c1 = [1, 2, 3]
"""
CollectionL2: TypeAlias = Collection[Collection[_T]]
"""A :data:`~typing.TypeAlias` for a 2-level-deep :class:`~collections.abc.Collection`.
Example
-------
>>> c2 = [[1, 2, 3], [4, 5, 6]]
"""
CollectionL3: TypeAlias = Collection[Collection[Collection[_T]]]
"""A :data:`~typing.TypeAlias` for a 3-level-deep :class:`~collections.abc.Collection`.
Example
-------
>>> c3 = [
... [[1, 2], [3, 4]],
... [[5, 6], [7, 8]],
... ]
"""
# ========================================================
# --- Numeric types
# ========================================================
Float: TypeAlias = float | np.floating[Any]
"""A :data:`~typing.TypeAlias` for float types."""
Int: TypeAlias = int | np.integer[Any]
"""A :data:`~typing.TypeAlias` for a int types."""
Numeric: TypeAlias = Int | Float
"""A :data:`~typing.TypeAlias` for *numeric* types."""
NumericT = TypeVar("NumericT", bound=Numeric)
"""A :class:`~typing.TypeVar` variable bound to :data:`Numeric` types."""
[docs]
def _is_numeric(obj: Any) -> TypeIs[Numeric]:
"""Type guard for :data:`Numeric`.
Examples
--------
>>> _is_numeric(42)
True
>>> _is_numeric(12.3)
True
>>> _is_numeric(np.int64(17))
True
>>> _is_numeric(np.float64(3.14))
True
>>> _is_numeric("42")
False
>>> _is_numeric("12.3")
False
>>> _is_numeric([42])
False
>>> _is_numeric(None)
False
"""
return isinstance(obj, (int, float, np.number))
# ========================================================
# --- `Densities` array
# ========================================================
# The runtime counterpart of this alias is ``is_xy_coord``, which requires an
# actual ``tuple`` of length 2 (a 2-element list does not qualify).
XYCoordinate: TypeAlias = tuple[Numeric, Numeric]
"""A 2D :math:`(x, y)` coordinate, represented as a :class:`~tuple` of
two :data:`Numeric` values.
Example
-------
>>> xy_coord = (1, 2)
"""
# NOTE: the alias (and the ``is_density_trace`` guard) only describe the
# *shape* of the data; the "non-repeating and increasing x" convention
# mentioned in the docstring below is not checked by either of them.
DensityTrace: TypeAlias = CollectionL1[XYCoordinate]
r"""A 2D line/trace represented as a collection of :math:`(x, y)` coordinates
(i.e. :data:`XYCoordinate`\s).
These are equivalent:
- ``DensityTrace``
- ``CollectionL1[XYCoordinate]``
- ``Collection[tuple[Numeric, Numeric]]``
By convention, the :math:`x` values should be non-repeating and increasing. For
instance, the following is a valid 2D line trace:
.. tab-set::
.. tab-item:: Code example
>>> density_trace = [(0, 0), (1, 1), (2, 2), (3, 1), (4, 0)]
.. tab-item:: Graphical representation
..
The plot below was generated using the following code:
>>> save_fig(ridgeplot(densities=[[density_trace]]), "density_trace")
.. image:: /_static/img/api/types/density_trace.webp
"""
DensitiesRow: TypeAlias = CollectionL1[DensityTrace]
r"""A :data:`DensitiesRow` represents a set of :data:`DensityTrace`\s that
are to be plotted on a given row of a ridgeplot.

These are equivalent:

- ``DensitiesRow``
- ``CollectionL2[XYCoordinate]``
- ``Collection[Collection[tuple[Numeric, Numeric]]]``

Example
-------

.. tab-set::

    .. tab-item:: Code example

        >>> densities_row = [
        ...     [(0, 0), (1, 1), (2, 0)],                   # Trace 1
        ...     [(1, 0), (2, 1), (3, 2), (4, 1)],           # Trace 2
        ...     [(3, 0), (4, 1), (5, 2), (6, 1), (7, 0)],   # Trace 3
        ... ]

    .. tab-item:: Graphical representation

        ..
            The plot below was generated using the following code:
            >>> save_fig(ridgeplot(densities=[densities_row]), "densities_row")

        .. image:: /_static/img/api/types/densities_row.webp
"""

Densities: TypeAlias = CollectionL1[DensitiesRow]
r"""The :data:`Densities` type represents the entire collection of traces that
are to be plotted on a ridgeplot.

In a ridgeplot, several traces can be plotted on different rows. Each row is
represented by a :data:`DensitiesRow` object which, in turn, is a collection of
:data:`DensityTrace`\s. Therefore, the :data:`Densities` type is a collection
of :data:`DensitiesRow`\s.

These are equivalent:

- ``Densities``
- ``CollectionL1[DensitiesRow]``
- ``CollectionL3[XYCoordinate]``
- ``Collection[Collection[Collection[tuple[Numeric, Numeric]]]]``

Example
-------

.. tab-set::

    .. tab-item:: Code example

        >>> densities = [
        ...     [                                               # Row 1
        ...         [(0, 0), (1, 1), (2, 0)],                   # Trace 1
        ...         [(1, 0), (2, 1), (3, 2), (4, 1)],           # Trace 2
        ...         [(3, 0), (4, 1), (5, 2), (6, 1), (7, 0)],   # Trace 3
        ...     ],
        ...     [                                               # Row 2
        ...         [(-2, 0), (-1, 1), (0, 0)],                 # Trace 4
        ...         [(0, 0), (1, 1), (2, 1), (3, 0)],           # Trace 5
        ...     ],
        ... ]

    .. tab-item:: Graphical representation

        ..
            The plot below was generated using the following code:
            >>> save_fig(ridgeplot(densities=densities, spacing=1), "densities")

        .. image:: /_static/img/api/types/densities.webp
"""
ShallowDensities: TypeAlias = CollectionL1[DensityTrace]
"""Shallow type for :data:`Densities` where each row of the ridgeplot contains
only a single trace.

These are equivalent:

- ``ShallowDensities``
- ``CollectionL1[DensityTrace]``
- ``CollectionL2[XYCoordinate]``
- ``Collection[Collection[tuple[Numeric, Numeric]]]``

Example
-------

.. tab-set::

    .. tab-item:: Code example

        >>> shallow_densities = [
        ...     [(0, 0), (1, 1), (2, 0)],   # Trace 1
        ...     [(1, 0), (2, 1), (3, 0)],   # Trace 2
        ...     [(2, 0), (3, 1), (4, 0)],   # Trace 3
        ... ]

    .. tab-item:: Graphical representation

        ..
            The plot below was generated using the following code:
            >>> save_fig(ridgeplot(densities=shallow_densities), "shallow_densities")

        .. image:: /_static/img/api/types/shallow_densities.webp
"""
[docs]
def is_xy_coord(obj: Any) -> TypeIs[XYCoordinate]:
    """Type guard for :data:`XYCoordinate`.

    An object qualifies only if it is a :class:`tuple` holding exactly two
    items that each pass :func:`_is_numeric`.
    """
    # Reject anything that is not a 2-tuple before inspecting its items.
    if not isinstance(obj, tuple) or len(obj) != 2:
        return False
    return all(_is_numeric(item) for item in obj)
[docs]
def is_density_trace(obj: Any) -> TypeIs[DensityTrace]:
    """Type guard for :data:`DensityTrace`.

    The object must be a :class:`~collections.abc.Collection` in which every
    item passes :func:`is_xy_coord`. The ordering of the x values is not
    inspected.
    """
    if not isinstance(obj, Collection):
        return False
    return all(is_xy_coord(point) for point in obj)
[docs]
def is_shallow_densities(obj: Any) -> TypeIs[ShallowDensities]:
    """Type guard for :data:`ShallowDensities`.

    The object must be a :class:`~collections.abc.Collection` in which every
    item passes :func:`is_density_trace`.

    Examples
    --------
    >>> is_shallow_densities("definitely not")
    False
    >>> is_shallow_densities([["also"], ["not"]])
    False
    >>> deep_density = [
    ...     [
    ...         [(0, 0), (1, 1), (2, 2), (3, 3)],
    ...         [(0, 0), (1, 1), (2, 2)],
    ...         [(0, 0), (1, 1), (2, 2), (3, 3), (4, 4)],
    ...     ],
    ...     [
    ...         [(-2, 2), (-1, 1), (0, 1)],
    ...         [(2, 2), (3, 1), (4, 1)],
    ...     ],
    ... ]
    >>> is_shallow_densities(deep_density)
    False
    >>> shallow_density = [[(0, 0), (1, 1)], [(2, 2), (3, 1)]]
    >>> is_shallow_densities(shallow_density)
    True
    >>> shallow_samples = [[0, 1, 2], [2, 3, 4]]
    >>> is_shallow_densities(shallow_samples)
    False
    """
    if not isinstance(obj, Collection):
        return False
    return all(is_density_trace(trace) for trace in obj)
# ========================================================
# --- `Samples` array
# ========================================================
# Both aliases below are built on ``CollectionL1`` (i.e. ``Collection``), so
# any sized, iterable container of the right element type is accepted.
SamplesTrace: TypeAlias = CollectionL1[Numeric]
"""A :data:`SamplesTrace` is a collection of numeric values representing a
set of samples from which a :data:`DensityTrace` can be estimated via KDE.
Example
-------
.. tab-set::
.. tab-item:: Code example
>>> samples_trace = [0, 1, 1, 2, 2, 2, 3, 3, 4]
.. tab-item:: Graphical representation
..
The plot below was generated using the following code:
>>> save_fig(ridgeplot(samples=[[samples_trace]]), "samples_trace")
.. image:: /_static/img/api/types/samples_trace.webp
"""
# One nesting level above ``SamplesTrace``: equivalent to ``CollectionL2[Numeric]``.
SamplesRow: TypeAlias = CollectionL1[SamplesTrace]
r"""A :data:`SamplesRow` represents a set of :data:`SamplesTrace`\s that are to be
plotted on a given row of a ridgeplot.
i.e. a :data:`SamplesRow` is a collection of :data:`SamplesTrace`\s and can be
converted into a :data:`DensitiesRow` by applying KDE to each trace.
Example
-------
.. tab-set::
.. tab-item:: Code example
>>> samples_row = [
... [0, 1, 1, 2, 2, 2, 3, 3, 4], # Trace 1
... [1, 2, 2, 3, 3, 3, 4, 4, 5], # Trace 2
... ]
.. tab-item:: Graphical representation
..
The plot below was generated using the following code:
>>> save_fig(ridgeplot(samples=[samples_row]), "samples_row")
.. image:: /_static/img/api/types/samples_row.webp
"""
Samples: TypeAlias = CollectionL1[SamplesRow]
r"""The :data:`Samples` type represents the entire collection of samples that
are to be plotted on a ridgeplot.

It is a collection of :data:`SamplesRow` objects. Each row is represented by a
:data:`SamplesRow` type which, in turn, is a collection of :data:`SamplesTrace`\s
which can be converted into :data:`DensityTrace`\s by applying a kernel density
estimation algorithm.

Therefore, the :data:`Samples` type can be converted into a :data:`Densities`
type by applying a kernel density estimation (KDE) algorithm to each trace.

See :data:`Densities` for more details.

Example
-------

.. tab-set::

    .. tab-item:: Code example

        >>> samples = [
        ...     [                                   # Row 1
        ...         [0, 1, 1, 2, 2, 2, 3, 3, 4],    # Trace 1
        ...         [1, 2, 2, 3, 3, 3, 4, 4, 5],    # Trace 2
        ...     ],
        ...     [                                   # Row 2
        ...         [2, 3, 3, 4, 4, 4, 5, 5, 6],    # Trace 3
        ...         [3, 4, 4, 5, 5, 5, 6, 6, 7],    # Trace 4
        ...     ],
        ... ]

    .. tab-item:: Graphical representation

        ..
            The plot below was generated using the following code:
            >>> save_fig(ridgeplot(samples=samples), "samples")

        .. image:: /_static/img/api/types/samples.webp
"""
# Structurally identical to ``SamplesRow`` (both are
# ``CollectionL1[SamplesTrace]``); the separate name records the different
# meaning described in the docstring: one trace per row rather than several
# traces sharing a row.
ShallowSamples: TypeAlias = CollectionL1[SamplesTrace]
"""Shallow type for :data:`Samples` where each row of the ridgeplot contains
only a single trace.
Example
-------
.. tab-set::
.. tab-item:: Code example
>>> shallow_samples = [
... [0, 1, 1, 2, 2, 2, 3, 3, 4], # Trace 1
... [1, 2, 2, 3, 3, 3, 4, 4, 5], # Trace 2
... ]
.. tab-item:: Graphical representation
..
The plot below was generated using the following code:
>>> save_fig(ridgeplot(samples=shallow_samples), "shallow_samples")
.. image:: /_static/img/api/types/shallow_samples.webp
"""
[docs]
def is_trace_samples(obj: Any) -> TypeIs[SamplesTrace]:
    """Type guard for :data:`SamplesTrace`.

    The object must be a :class:`~collections.abc.Collection` in which every
    item passes :func:`_is_numeric`.
    """
    if not isinstance(obj, Collection):
        return False
    return all(_is_numeric(sample) for sample in obj)
[docs]
def is_shallow_samples(obj: Any) -> TypeIs[ShallowSamples]:
    """Type guard for :data:`ShallowSamples`.

    The object must be a :class:`~collections.abc.Collection` in which every
    item passes :func:`is_trace_samples`.

    Examples
    --------
    >>> is_shallow_samples("definitely not")
    False
    >>> is_shallow_samples([["also"], ["not"]])
    False
    >>> deep_samples = [
    ...     [
    ...         [0, 0, 1, 1, 2, 2, 3, 3],
    ...         [0, 0, 1, 1, 2, 2],
    ...         [0, 0, 1, 1, 2, 2, 3, 3, 4, 4],
    ...     ],
    ...     [
    ...         [-2, 2, -1, 1, 0, 1],
    ...         [2, 2, 3, 1, 4, 1],
    ...     ],
    ... ]
    >>> is_shallow_samples(deep_samples)
    False
    >>> shallow_samples = [[0, 1, 2], [2, 3, 4]]
    >>> is_shallow_samples(shallow_samples)
    True
    >>> shallow_density = [[(0, 0), (1, 1)], [(2, 2), (3, 1)]]
    >>> is_shallow_samples(shallow_density)
    False
    """
    if not isinstance(obj, Collection):
        return False
    return all(is_trace_samples(trace) for trace in obj)
# ========================================================
# --- Other array types
# ========================================================
# Trace types ---
# ``is_trace_type`` derives its accepted values from this alias at runtime
# (via ``get_args``), so adding a literal here automatically extends the guard.
TraceType: TypeAlias = Literal["area", "bar"]
"""The type of trace to draw in a ridgeplot. See
:paramref:`ridgeplot.ridgeplot.trace_type` for more information."""
# Two levels deep: one inner collection of trace types per ridgeplot row.
TraceTypesArray: TypeAlias = CollectionL2[TraceType]
"""A :data:`TraceTypesArray` represents the types of traces in a ridgeplot.
Example
-------
>>> trace_types_array: TraceTypesArray = [
... ["area", "bar", "area"],
... ["bar", "area"],
... ]
"""
ShallowTraceTypesArray: TypeAlias = CollectionL1[TraceType]
"""Shallow type for :data:`TraceTypesArray`.
Example
-------
>>> trace_types_array: ShallowTraceTypesArray = ["area", "bar", "area"]
"""
[docs]
def is_trace_type(obj: Any) -> TypeIs[TraceType]:
    """Type guard for :data:`TraceType`.

    The object must be a :class:`str` equal to one of the literal values
    declared by :data:`TraceType`.

    Examples
    --------
    >>> is_trace_type("area")
    True
    >>> is_trace_type("bar")
    True
    >>> is_trace_type("foo")
    False
    >>> is_trace_type(42)
    False
    """
    from typing_extensions import get_args

    if not isinstance(obj, str):
        return False
    # The allowed values are read from the alias itself rather than repeated here.
    allowed_values = get_args(TraceType)
    return obj in allowed_values
[docs]
def is_shallow_trace_types_array(obj: Any) -> TypeIs[ShallowTraceTypesArray]:
    """Type guard for :data:`ShallowTraceTypesArray`.

    The object must be a :class:`~collections.abc.Collection` in which every
    item passes :func:`is_trace_type`.

    Examples
    --------
    >>> is_shallow_trace_types_array(["area", "bar", "area"])
    True
    >>> is_shallow_trace_types_array(["area", "bar", "foo"])
    False
    >>> is_shallow_trace_types_array([1, 2, 3])
    False
    """
    if not isinstance(obj, Collection):
        return False
    return all(is_trace_type(item) for item in obj)
[docs]
def is_trace_types_array(obj: Any) -> TypeIs[TraceTypesArray]:
    """Type guard for :data:`TraceTypesArray`.

    The object must be a :class:`~collections.abc.Collection` in which every
    item passes :func:`is_shallow_trace_types_array`.

    Examples
    --------
    >>> is_trace_types_array([["area", "bar"], ["area", "bar"]])
    True
    >>> is_trace_types_array([["area", "bar"], ["area", "foo"]])
    False
    >>> is_trace_types_array([["area", "bar"], ["area", 42]])
    False
    """
    if not isinstance(obj, Collection):
        return False
    return all(is_shallow_trace_types_array(row) for row in obj)
# Labels ---
# Two levels deep: one inner collection of labels per ridgeplot row.
LabelsArray: TypeAlias = CollectionL2[str]
"""A :data:`LabelsArray` represents the labels of traces in a ridgeplot.
Example
-------
>>> labels_array: LabelsArray = [
... ["trace 1", "trace 2", "trace 3"],
... ["trace 4", "trace 5"],
... ]
"""
# A plain ``str`` is itself a ``Collection`` of strings; at runtime use
# ``is_flat_str_collection`` (which rejects bare strings) to tell a single
# label apart from a collection of labels.
ShallowLabelsArray: TypeAlias = CollectionL1[str]
"""Shallow type for :data:`LabelsArray`.
Example
-------
>>> labels_array: ShallowLabelsArray = ["trace 1", "trace 2", "trace 3"]
"""
# Sample weights ---
# NOTE(review): ``None`` is a valid member of this alias; presumably it stands
# for "no explicit weights for this trace" - confirm against the KDE code.
SampleWeights: TypeAlias = CollectionL1[Numeric] | None
"""An array of KDE weights corresponding to each sample."""
# Each leaf of the nested arrays below is one ``SampleWeights`` value, i.e.
# either a collection of numbers or ``None``.
SampleWeightsArray: TypeAlias = CollectionL2[SampleWeights]
"""A :data:`SampleWeightsArray` represents the weights of the datapoints in a
:data:`Samples` array. The shape of the :data:`SampleWeightsArray` array should
match the shape of the corresponding :data:`Samples` array."""
ShallowSampleWeightsArray: TypeAlias = CollectionL1[SampleWeights]
"""Shallow type for :data:`SampleWeightsArray`."""
# ========================================================
# --- More type guards and other utilities
# ========================================================
[docs]
def is_flat_str_collection(obj: Any) -> TypeIs[CollectionL1[str]]:
    """Type guard for :data:`CollectionL1[str]`.

    A bare string is deliberately rejected, even though it can be seen as a
    collection of one-character strings.

    Examples
    --------
    >>> is_flat_str_collection(["a", "b", "c"])
    True
    >>> is_flat_str_collection("abc")
    False
    >>> is_flat_str_collection(["a", "b", 1])
    False
    >>> is_flat_str_collection({"also", "a", "collection"})
    True
    >>> is_flat_str_collection((1, 2))
    False
    """
    # A ``str`` would otherwise pass both checks below, since iterating it
    # yields strings; rule it out together with non-collections.
    if isinstance(obj, str) or not isinstance(obj, Collection):
        return False
    for item in obj:
        if not isinstance(item, str):
            return False
    return True
[docs]
def is_flat_numeric_collection(obj: Any) -> TypeIs[CollectionL1[Numeric]]:
    """Type guard for :data:`CollectionL1[Numeric]`.

    The object must be a :class:`~collections.abc.Collection` in which every
    item passes :func:`_is_numeric`.

    Examples
    --------
    >>> is_flat_numeric_collection({1, 2, 3.14})
    True
    >>> is_flat_numeric_collection((1, np.nan, np.inf))
    True
    >>> is_flat_numeric_collection([3.14, np.float64(2.71), np.int64(42)])
    True
    >>> is_flat_numeric_collection([1, 2, "3"])
    False
    >>> is_flat_numeric_collection([1, 2, None])
    False
    >>> is_flat_numeric_collection("definitely not")
    False
    """
    if not isinstance(obj, Collection):
        return False
    return all(_is_numeric(item) for item in obj)
[docs]
def nest_shallow_collection(shallow_collection: Collection[_T]) -> Collection[Collection[_T]]:
    """Convert a *shallow* collection type into a *deep* collection type.

    Each item of the input is wrapped in its own single-element list, and the
    wrapped items are returned, in iteration order, as a new list.

    This function should really only be used in the :mod:`ridgeplot._ridgeplot`
    module to normalise user input.

    Examples
    --------
    >>> nest_shallow_collection([1, "2", {"a": 3}])
    [[1], ['2'], [{'a': 3}]]
    """
    nested = []
    for item in shallow_collection:
        nested.append([item])
    return nested