diff --git a/CHANGELOG.md b/CHANGELOG.md index 685555e..97f15fe 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -11,6 +11,7 @@ All notable changes to this project will be documented in this file. The format ## [Unreleased] ### Added Features and Improvements 🙌: - Added support for Matplotlib 3.11 +- `pplt.legend` now merges entries sharing the same label and color but differing in appearance (e.g. line vs. marker) or only in alpha into a single colored box. ## [0.13.2] - 2026-05-07 diff --git a/src/prettypyplot/pyplot.py b/src/prettypyplot/pyplot.py index b29e9e9..f4f138b 100644 --- a/src/prettypyplot/pyplot.py +++ b/src/prettypyplot/pyplot.py @@ -373,9 +373,11 @@ def _legend_deduplicate(handles, labels): Entries that share the same label and identical visual appearance are collapsed to a single entry (keeping the first occurrence). Entries that - share the same label and the same color but have *different* visual - appearances (e.g. a line, a marker, a bar, a patch) are all replaced by a - single filled-square [matplotlib.patches.Patch][] of that color. + share the same label and the same color (ignoring alpha) but have + *different* visual appearances are all replaced by a single filled-square + [matplotlib.patches.Patch][] of that color. A differing appearance can be + a different handle type (e.g. a line, a marker, a bar, a patch) or the same + color differing only in its alpha (transparency). Parameters ---------- @@ -392,15 +394,19 @@ def _legend_deduplicate(handles, labels): labels : list of str Deduplicated labels. """ - # group by (label, color), preserving first-seen insertion order - groups = {} # (label, color_key) -> {'handle_keys': set, 'first_handle': handle} - order = [] # insertion-order list of (label, color_key) + # group by (label, rgb) ignoring alpha, preserving first-seen insertion order + groups = {} # (label, rgb_key) -> {'handle_keys': set, 'first_handle': handle} + order = [] # insertion-order list of (label, rgb_key) for handle, label in zip(handles, labels): color = _legend_handle_color(handle) - color_key = tuple(round(c, 6) for c in color) if color is not None else None - group_key = (label, color_key) - hkey = _legend_handle_key(handle) + # group by rgb only so entries differing only in alpha are merged + rgb_key = tuple(round(c, 6) for c in color[:3]) if color is not None else None + # appearance key includes the full rgba so an alpha-only difference + # registers as a distinct appearance and triggers the filled square + rgba_key = tuple(round(c, 6) for c in color) if color is not None else None + group_key = (label, rgb_key) + hkey = (_legend_handle_key(handle), rgba_key) if group_key not in groups: groups[group_key] = {'handle_keys': {hkey}, 'first_handle': handle} @@ -410,11 +416,11 @@ def _legend_deduplicate(handles, labels): unique_handles, unique_labels = [], [] for group_key in order: - label, color_key = group_key + label, rgb_key = group_key entry = groups[group_key] - if len(entry['handle_keys']) > 1 and color_key is not None: + if len(entry['handle_keys']) > 1 and rgb_key is not None: # same label, same color, different appearances → filled square - patch = mpatches.Patch(facecolor=color_key, edgecolor='none') + patch = mpatches.Patch(facecolor=(*rgb_key, 1.0), edgecolor='none') unique_handles.append(patch) else: unique_handles.append(entry['first_handle']) @@ -479,22 +485,35 @@ def _to_rgba(color): return None if isinstance(handle, mlines.Line2D): - return _to_rgba(handle.get_color()) + return _apply_artist_alpha(_to_rgba(handle.get_color()), handle) if isinstance(handle, mpatches.Patch): fc = handle.get_facecolor() - return tuple(fc) if len(fc) == 4 else _to_rgba(fc) + rgba = tuple(fc) if len(fc) == 4 else _to_rgba(fc) + return _apply_artist_alpha(rgba, handle) if isinstance(handle, PathCollection): fc = handle.get_facecolor() if len(fc): - return tuple(fc[0]) + return _apply_artist_alpha(tuple(fc[0]), handle) return None if isinstance(handle, ErrorbarContainer): - return _to_rgba(handle[0].get_color()) + line = handle[0] + return _apply_artist_alpha(_to_rgba(line.get_color()), line) if isinstance(handle, BarContainer): - return tuple(handle.patches[0].get_facecolor()) + patch = handle.patches[0] + return _apply_artist_alpha(tuple(patch.get_facecolor()), patch) return None +def _apply_artist_alpha(rgba, artist): + """Override an RGBA tuple's alpha channel with the artist's alpha if set.""" + if rgba is None: + return None + alpha = getattr(artist, 'get_alpha', lambda: None)() + if alpha is not None: + return (*rgba[:3], alpha) + return rgba + + def _legend_handle_key(handle): """Return a hashable visual key for a legend handle.""" if isinstance(handle, mlines.Line2D): diff --git a/tests/test_pyplot.py b/tests/test_pyplot.py index 485f8c8..e0c6ba0 100644 --- a/tests/test_pyplot.py +++ b/tests/test_pyplot.py @@ -211,6 +211,37 @@ def test_legend_dedup_same_color_different_patches(): plt.close(fig) +def test_legend_dedup_same_color_different_alpha(): + """Same label + same color differing only in alpha → single filled square.""" + prettypyplot.use_style() + fig, ax = plt.subplots() + ax.plot([0, 1], [0, 1], color=(0.1, 0.2, 0.7, 1.0), label='data') + ax.plot([0, 1], [0.5, 0.5], color=(0.1, 0.2, 0.7, 0.3), label='data') + + leg = prettypyplot.legend(ax=ax) + assert len(leg.get_texts()) == 1 + handle = leg.legend_handles[0] + assert isinstance(handle, mpatches.Patch) + # the merged box is opaque and keeps the shared rgb + fc = handle.get_facecolor() + assert fc[:3] == pytest.approx((0.1, 0.2, 0.7)) + assert fc[3] == pytest.approx(1.0) + plt.close(fig) + + +def test_legend_dedup_alpha_via_kwarg(): + """Alpha set via the separate ``alpha`` kwarg is also detected and merged.""" + prettypyplot.use_style() + fig, ax = plt.subplots() + ax.plot([0, 1], [0, 1], color='C0', label='data') + ax.plot([0, 1], [0.5, 0.5], color='C0', alpha=0.3, label='data') + + leg = prettypyplot.legend(ax=ax) + assert len(leg.get_texts()) == 1 + assert isinstance(leg.legend_handles[0], mpatches.Patch) + plt.close(fig) + + def test_legend_handle_color_line2d(): """_legend_handle_color returns RGBA tuple for Line2D.""" fig, ax = plt.subplots()