How to change the Bokeh theme of a HoloViews plot on the fly



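The key is to rebuild the pane: `pn.bind` re-runs `update_plot` whenever the selected theme changes, and each call returns a new `pn.pane.HoloViews` with that theme applied.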

import holoviews as hv
import numpy as np
import panel as pn
from bokeh.themes import built_in_themes

# Initialise Panel and the Bokeh backend of HoloViews.
pn.extension()
hv.extension("bokeh")

# Names of Bokeh's built-in themes, in the order the registry lists them.
THEMES = [*built_in_themes]

def update_plot(theme):
    """Return a brand-new HoloViews pane that renders ``curve`` with *theme*.

    A fresh pane is built on every call; rebuilding the pane is what makes a
    theme change take effect. ``curve`` is the module-level element that is
    defined further down, before this function is first called.
    """
    return pn.pane.HoloViews(curve, theme=theme, sizing_mode="stretch_width")

x = np.linspace(0, 10, 200)
curve = hv.Curve((x, np.sin(x)), "x", "sin(x)").opts(
    height=350, responsive=True, line_width=2,
)

# Theme picker, pre-set to the second built-in theme name.
theme_select = pn.widgets.Select(name="Bokeh theme", options=THEMES, value=THEMES[1])

# pn.bind calls update_plot again with the new value whenever the selection changes.
plot_pane = pn.bind(update_plot, theme_select)
layout = pn.Column(theme_select, plot_pane)
layout.show()

Theme builder (work in progress): edit a Bokeh theme as JSON and preview it live on a grid of HoloViews charts.

import ast
import copy
import json
import re
from io import StringIO

import holoviews as hv
import numpy as np
import pandas as pd
import panel as pn
import panel_material_ui as pmui
import param
from bokeh.model import Model
from bokeh.themes import built_in_themes
from bokeh.themes.theme import Theme

# Panel with throttled widget updates and toast notifications (the app reports
# errors through pn.state.notifications), plus the Bokeh backend for HoloViews.
pn.extension(throttled=True, notifications=True)
hv.extension("bokeh")

# Names of Bokeh's built-in themes, used as the options of the theme selector.
THEMES = [*built_in_themes]

# Bokeh model name -> model class. "figure" is added as an alias for the
# "Plot" class so the theme helpers below can resolve a "figure" key.
_MODEL_MAP = {
    **Model.model_class_reverse_map,
    "figure": Model.model_class_reverse_map["Plot"],
}

# Models whose themeable properties full_theme_template() spells out.
THEME_MODELS = [
    "figure", "Grid", "Axis", "Legend", "BaseColorBar", "Title", "VBar", "HBar",
]

# A property name counts as themeable when it starts with one of these
# prefixes (checked by _is_themeable) ...
_THEMEABLE_PREFIXES = tuple(
    stem + "_"
    for stem in (
        "text", "label", "title", "axis", "major", "minor", "grid", "band",
        "fill", "line", "hatch", "background", "border", "outline", "bar",
        "inactive", "item_background",
    )
)
# ... or when it is exactly one of these names.
_THEMEABLE_EXACT = {
    "align", "spacing", "padding", "margin", "offset", "standoff",
    "orientation", "location", "visible", "height", "width",
    "glyph_height", "glyph_width", "label_standoff", "click_policy",
    "border_radius", "toolbar_location",
}

# Matches from a "#" to the end of the line, but only when the remainder of
# the line holds an even number of quote characters (' and " are counted
# together, escapes are ignored): a heuristic for "this # is not inside a
# string literal".
_COMMENT_RE = re.compile(r"""#(?=(?:[^"']*["'][^"']*["'])*[^"']*$).*""")


def _is_themeable(name):
    """Return True if the Bokeh property *name* should appear in a theme template.

    A name qualifies when it is listed verbatim in ``_THEMEABLE_EXACT`` or
    starts with any prefix in ``_THEMEABLE_PREFIXES``.
    """
    if name in _THEMEABLE_EXACT:
        return True
    # str.startswith accepts a tuple and tests every prefix in one call.
    return name.startswith(_THEMEABLE_PREFIXES)


def strip_strings(obj):
    """Return a copy of *obj* with surrounding whitespace removed from every string.

    Dicts and lists are rebuilt recursively (only dict values are processed,
    keys are kept as they are); strings are ``.strip()``-ed; every other value,
    including tuples, is returned unchanged.
    """
    if isinstance(obj, str):
        return obj.strip()
    if isinstance(obj, list):
        return list(map(strip_strings, obj))
    if isinstance(obj, dict):
        return {key: strip_strings(val) for key, val in obj.items()}
    return obj


def validate_theme_json(theme_json):
    """Check a theme dict against the Bokeh models it styles.

    The expected layout is ``{"attrs": {model_name: {prop_name: value}}}``.
    The input usually comes straight from a user-editable JSON editor, so its
    shape is checked before it is walked. Model names not present in
    ``_MODEL_MAP`` are skipped without an error.

    Args:
        theme_json: Parsed theme of untrusted shape.

    Returns:
        A list of human-readable error strings; empty when nothing is wrong.
    """
    if not isinstance(theme_json, dict):
        return [f"theme must be a JSON object, got {type(theme_json).__name__}"]
    attrs = theme_json.get("attrs", {})
    if not isinstance(attrs, dict):
        return [f"attrs: expected an object mapping model names to properties, got {attrs!r}"]
    errors = []
    for model_name, props in attrs.items():
        # Bokeh's Theme rejects any non-dict entry under "attrs", whatever the
        # model name, so report it before looking the model up.
        if not isinstance(props, dict):
            errors.append(f"{model_name}: expected an object of properties, got {props!r}")
            continue
        cls = _MODEL_MAP.get(model_name)
        if cls is None:
            continue
        for prop_name, value in props.items():
            descriptor = cls.lookup(prop_name, raises=False)
            if descriptor is None:
                errors.append(f"{model_name}.{prop_name}: unknown property")
                continue
            if not descriptor.property.is_valid(value):
                errors.append(f"{model_name}.{prop_name}: invalid value {value!r}")
    return errors


def full_theme_template(base_json=None):
    """Return *base_json* expanded with every themeable property spelled out.

    For each model in ``THEME_MODELS`` the result lists every property
    accepted by ``_is_themeable`` that has a JSON-friendly class default
    (str/int/float/bool/list/tuple/None). The values already present in
    *base_json* are laid on top of those defaults.

    Nothing from *base_json* is discarded:
    - override keys outside the themeable filter are kept;
    - models not in ``THEME_MODELS`` are carried over unchanged;
    - top-level keys other than ``"attrs"`` are carried over unchanged.

    When a ``THEME_MODELS`` name is an alias (``"figure"`` resolves to the
    ``Plot`` class), entries stored under the class name are folded in as
    well. Otherwise the expanded ``figure`` defaults would shadow the base
    theme's ``Plot`` values.

    Args:
        base_json: Existing theme dict (``{"attrs": {...}, ...}``) or None.
            Non-dict values are treated as an empty theme.

    Returns:
        A new theme dict; *base_json* itself is not modified.
    """
    base = base_json if isinstance(base_json, dict) else {}
    base_attrs = base.get("attrs")
    if not isinstance(base_attrs, dict):
        base_attrs = {}
    json_friendly = (str, int, float, bool, list, tuple, type(None))
    attrs = {}
    for model_name in THEME_MODELS:
        cls = _MODEL_MAP.get(model_name)
        if cls is None:
            continue
        # Class-name entry first, alias entry last so the alias wins.
        overrides = {}
        for key in dict.fromkeys((cls.__name__, model_name)):
            entry = base_attrs.get(key)
            if isinstance(entry, dict):
                overrides.update(entry)
        props = {}
        for name in cls.properties():
            if name in overrides or not _is_themeable(name):
                continue
            try:
                default = cls.lookup(name).class_default(cls)
            except Exception:
                # Deliberate best-effort: a default that cannot be computed
                # on its own is simply left out of the template.
                continue
            if isinstance(default, json_friendly):
                props[name] = default
        props.update(overrides)
        attrs[model_name] = dict(sorted(props.items()))
    # Keep models the template does not expand (e.g. "Plot"), unchanged.
    for model_name, entry in base_attrs.items():
        attrs.setdefault(model_name, entry)
    return {**base, "attrs": attrs}


# --- Sample data ---
# The global NumPy RNG is seeded once so the frames below are reproducible.
# The random calls below consume that global stream in statement order, so
# reordering them (or adding a random call in between) changes every value
# generated after that point.
np.random.seed(42)
N = 200  # number of rows in ``df``
x = np.linspace(0, 10, N)
categories = ["A", "B", "C", "D", "E"]
# Row-level data. "noise", "group", "sin" and "cos" are not referenced by the
# chart builders in this snippet; "x", "category" and "value" are.
df = pd.DataFrame(
    {
        "x": x,
        "sin": np.sin(x),
        "cos": np.cos(x),
        "noise": np.random.randn(N) * 0.3,
        "category": np.random.choice(categories, N),
        "value": np.random.rand(N) * 100,
        "group": np.random.choice(["G1", "G2"], N),
    }
)
# Mean "value" per category, one row per category (used by the bar-style charts).
df_agg = df.groupby("category")["value"].mean().reset_index()
# 500 standard-normal samples for the Histogram and Distribution charts.
df_hist = pd.DataFrame({"vals": np.random.randn(500)})
# 10x10 uniform random matrix with labelled columns (C0..C9) and rows (R0..R9).
df_heatmap = pd.DataFrame(
    np.random.rand(10, 10),
    columns=[f"C{i}" for i in range(10)],
    index=[f"R{i}" for i in range(10)],
)

# Plot options shared by every preview chart: fill the grid cell with a
# minimum height and expose only the hover tool.
COMMON_OPTS = dict(responsive=True, min_height=300, tools=["hover"], default_tools=[])


def _rng(seed):
    """Return a freshly seeded NumPy Generator.

    The builders below are called again on every re-render. Drawing random
    sample data from a new seeded generator, rather than the global
    ``np.random`` stream, keeps those charts identical between renders, so a
    theme edit changes only the styling and not the data.
    """
    return np.random.default_rng(seed)


def _wave_image():
    """Return the 100x100 sin(x)*cos(y) test image shared by Image and Contours."""
    return hv.Image(
        np.sin(np.linspace(0, 10, 100).reshape(1, -1))
        * np.cos(np.linspace(0, 10, 100).reshape(-1, 1)),
        bounds=(-1, -1, 1, 1),
    )


def _vector_field():
    """Build the VectorField chart: a 10x10 grid of seeded random arrows.

    Each arrow gets an angle in [0, 2*pi) and a length in [0, 1).
    """
    rng = _rng(3)
    return hv.VectorField(
        (
            np.tile(np.linspace(0, 1, 10), 10),
            np.repeat(np.linspace(0, 1, 10), 10),
            rng.random(100) * 2 * np.pi,
            rng.random(100),
        ),
    ).opts(title="VectorField", **COMMON_OPTS)


# Chart name -> zero-argument callable that builds a fresh HoloViews object.
CHART_BUILDERS = {
    "Curve": lambda: hv.Curve(
        (x, np.sin(x)),
        "x",
        "y",
    ).opts(title="Curve", **COMMON_OPTS),
    "Overlay": lambda: (
        hv.Curve((x, np.sin(x)), "x", "y", label="sin(x)").opts(line_width=2)
        * hv.Curve((x, np.cos(x)), "x", "y", label="cos(x)").opts(
            line_width=2, line_dash="dashed"
        )
    ).opts(title="Overlay", legend_position="top_right", **COMMON_OPTS),
    "Scatter": lambda: hv.Scatter(
        (df["x"], df["value"]),
        "x",
        "value",
    ).opts(title="Scatter", size=5, **COMMON_OPTS),
    "Points": lambda: hv.Points(
        df,
        ["x", "value"],
        "category",
    ).opts(
        title="Points (by category)",
        color="category",
        size=5,
        legend_position="top_right",
        **COMMON_OPTS,
    ),
    "Bars": lambda: hv.Bars(
        df_agg,
        "category",
        "value",
    ).opts(title="Bars", **COMMON_OPTS),
    "Area": lambda: hv.Area(
        (x, np.sin(x)),
        "x",
        "y",
    ).opts(title="Area", fill_alpha=0.4, **COMMON_OPTS),
    "Spread": lambda: hv.Spread(
        (x, np.sin(x), np.full(N, 0.3)),
        "x",
        ["y", "y2"],
    ).opts(title="Spread", fill_alpha=0.3, **COMMON_OPTS),
    "Histogram": lambda: hv.Histogram(
        np.histogram(df_hist["vals"], bins=30),
    ).opts(title="Histogram", **COMMON_OPTS),
    "Distribution": lambda: hv.Distribution(
        df_hist["vals"],
    ).opts(title="Distribution (KDE)", filled=True, fill_alpha=0.4, **COMMON_OPTS),
    "HeatMap": lambda: hv.HeatMap(
        [
            (c, r, df_heatmap.loc[r, c])
            for r in df_heatmap.index
            for c in df_heatmap.columns
        ],
    ).opts(title="HeatMap", colorbar=True, **COMMON_OPTS),
    "Image": lambda: _wave_image().opts(title="Image", colorbar=True, **COMMON_OPTS),
    "HexTiles": lambda: hv.HexTiles(
        # Two seeded standard-normal columns of 500 samples each.
        tuple(_rng(1).standard_normal((2, 500))),
    ).opts(title="HexTiles", colorbar=True, **COMMON_OPTS),
    "BoxWhisker": lambda: hv.BoxWhisker(
        df,
        "category",
        "value",
    ).opts(title="BoxWhisker", **COMMON_OPTS),
    "Violin": lambda: hv.Violin(
        df,
        "category",
        "value",
    ).opts(title="Violin", **COMMON_OPTS),
    "ErrorBars": lambda: (
        hv.Curve((x[::20], np.sin(x[::20])), "x", "y")
        * hv.ErrorBars(
            {"x": x[::20], "y": np.sin(x[::20]), "yerr": np.full(10, 0.2)},
            "x",
            vdims=["y", "yerr"],
        )
    ).opts(title="ErrorBars", **COMMON_OPTS),
    "Spikes": lambda: hv.Spikes(
        _rng(2).standard_normal(50),
    ).opts(title="Spikes", **COMMON_OPTS),
    "Labels": lambda: (
        hv.Scatter((df_agg["category"], df_agg["value"]), "category", "value")
        * hv.Labels(
            {
                "x": df_agg["category"],
                "y": df_agg["value"],
                "text": [f"{v:.0f}" for v in df_agg["value"]],
            },
            ["x", "y"],
            "text",
        )
    ).opts(title="Labels", **COMMON_OPTS),
    "Annotations": lambda: (
        hv.Curve((x, np.sin(x)), "x", "y")
        * hv.HLine(0).opts(line_dash="dashed", line_color="red")
        * hv.VLine(np.pi).opts(line_dash="dotted", line_color="green")
        * hv.Slope(0.1, 0).opts(line_dash="dashdot", line_color="orange")
        * hv.HSpan(0.5, 1.0).opts(fill_alpha=0.1, fill_color="red")
        * hv.VSpan(4, 6).opts(fill_alpha=0.1, fill_color="blue")
    ).opts(title="Annotations", **COMMON_OPTS),
    "Contours": lambda: hv.operation.contours(
        _wave_image(),
        levels=8,
    ).opts(title="Contours", colorbar=True, **COMMON_OPTS),
    "VectorField": _vector_field,
    "HBars": lambda: hv.Bars(
        df_agg, "category", "value",
    ).opts(title="HBars", invert_axes=True, **COMMON_OPTS),
}

# Every chart name, in definition order (default selection of the app).
ALL_CHARTS = list(CHART_BUILDERS)


class ThemeSwitcher(pn.viewable.Viewer):
    """Interactive Bokeh theme editor with a live HoloViews preview.

    The sidebar can:
    - load a built-in Bokeh theme into a JSON editor;
    - accept a theme pasted as JSON or as a Python dict literal;
    - expand the theme to every themeable property (``full_theme_template``);
    - download the result as ``theme.json``.

    Every edit that passes ``validate_theme_json`` becomes a
    ``bokeh.themes.theme.Theme``, which is applied to a freshly built layout
    of the selected ``CHART_BUILDERS`` charts.
    """

    # Built-in theme that the "Load theme" button copies into the editor.
    theme_name = param.Selector(default=THEMES[1], objects=THEMES, doc="Bokeh theme")
    # Names of CHART_BUILDERS entries shown in the preview grid.
    visible_charts = param.ListSelector(default=ALL_CHARTS, objects=ALL_CHARTS)
    # Display mode forwarded to the JSONEditor widget.
    json_mode = param.Selector(
        default="tree", objects=["tree", "text"], doc="JSON editor mode"
    )

    def __init__(self, **params):
        # The @param.depends watchers below use these attributes, so they are
        # created before super().__init__().
        self._chart_placeholder = pn.pane.Placeholder()
        self._json_editor = pn.widgets.JSONEditor(
            value={},
            sizing_mode="stretch_width",
        )
        # Every change to the editor value is validated and, if valid, rendered.
        self._json_editor.param.watch(self._on_json_edit, "value")
        self._text_input = pn.widgets.TextAreaInput(
            placeholder="Paste JSON or Python dict here...",
            sizing_mode="stretch_width",
            height=150,
            auto_grow=True,
        )
        self._text_input.param.watch(self._on_text_paste, "value")
        # Most recent Theme built from an edit that passed validation.
        self._last_valid_theme = None
        super().__init__(**params)
        with pn.config.set(sizing_mode="stretch_width"):
            self._json_button_group = pmui.RadioButtonGroup.from_param(
                self.param.json_mode,
            )
            self._theme_widget = pmui.Select.from_param(
                self.param.theme_name,
                label="Bokeh theme",
            )
            self._load_btn = pmui.Button(
                label="Load theme",
                color="primary",
                variant="outlined",
                height=55,
            )
            self._load_btn.on_click(self._load_theme)
            self._expand_btn = pmui.Button(
                label="Expand all keys",
                color="secondary",
                variant="outlined",
            )
            self._expand_btn.on_click(self._expand_keys)
            self._download_btn = pmui.FileDownload(
                callback=self._download_theme,
                filename="theme.json",
                label="Download theme",
                color="success",
                variant="outlined",
                height=55,
            )
            self._charts_widget = pmui.MultiSelect.from_param(
                self.param.visible_charts,
                label="Charts",
                size=min(len(ALL_CHARTS), 10),
            )
        # Fill the editor; the editor watcher then renders the first preview.
        self._load_theme(None)

    @param.depends("json_mode", watch=True)
    def _on_json_mode_changed(self):
        """Forward the selected ``json_mode`` to the JSON editor widget."""
        self._json_editor.mode = self.json_mode

    def _load_theme(self, event):
        """Copy the selected built-in theme's JSON into the editor.

        *event* is the button click (unused; ``None`` when called from
        ``__init__``). The JSON is deep-copied so later edits never mutate the
        shared built-in Theme. ``_json`` is a private Bokeh attribute.
        """
        self._json_editor.value = copy.deepcopy(built_in_themes[self.theme_name]._json)

    def _expand_keys(self, event):
        """Replace the editor value with ``full_theme_template`` of itself."""
        self._json_editor.value = full_theme_template(self._json_editor.value)

    def _download_theme(self):
        """Return the current theme, strings stripped, as an indented JSON stream."""
        cleaned = strip_strings(self._json_editor.value)
        sio = StringIO()
        json.dump(cleaned, sio, indent=2)
        sio.seek(0)
        return sio

    def _notify_error(self, message):
        """Show *message* as a 5-second error notification with a warning prefix."""
        pn.state.notifications.error("⚠️ " + message, duration=5000)

    @staticmethod
    def _parse_pasted_text(raw):
        """Parse *raw* as JSON, a Python literal, or loosely formatted JSON.

        Returns:
            The parsed object (of any type; callers check it is a dict).

        Raises:
            ValueError: (including ``json.JSONDecodeError``) if every strategy fails.
        """
        # 1. Strict JSON.
        try:
            return json.loads(raw)
        except json.JSONDecodeError:
            pass
        # 2. Python literal. The Python tokenizer already ignores "#" comments
        #    and accepts None/True/False and trailing commas, so the text is
        #    evaluated as-is. A regex pre-pass could cut through string values
        #    such as '#fff'. literal_eval only evaluates literals, never code.
        try:
            return ast.literal_eval(raw)
        except (ValueError, SyntaxError, TypeError, RecursionError):
            pass
        # 3. JSON with Python constants, full-line comments and trailing commas.
        cleaned = re.sub(r"\bNone\b", "null", raw)
        cleaned = re.sub(r"\bTrue\b", "true", cleaned)
        cleaned = re.sub(r"\bFalse\b", "false", cleaned)
        # Drop full-line comments first so a trailing comma that is followed
        # by a comment line is still adjacent to its closing bracket.
        cleaned = re.sub(r"^\s*#.*$", "", cleaned, flags=re.MULTILINE)
        cleaned = re.sub(r",\s*([}\]])", r"\1", cleaned)
        return json.loads(cleaned)

    def _on_text_paste(self, event):
        """Load pasted text into the JSON editor and clear the input on success.

        Parsing is kept separate from the editor update. Exceptions raised
        while the new value is rendered are therefore not reported as
        "could not parse".
        """
        raw = (event.new or "").strip()
        if not raw:
            # Also reached when this handler clears the input below.
            return
        try:
            parsed = self._parse_pasted_text(raw)
        except (ValueError, RecursionError):
            self._notify_error("Could not parse as JSON or Python dict")
            return
        if not isinstance(parsed, dict):
            self._notify_error(
                f"Pasted value must be a dict / JSON object, got {type(parsed).__name__}"
            )
            return
        self._json_editor.value = parsed
        self._text_input.value = ""

    def _on_json_edit(self, event):
        """Validate the edited theme and, if usable, re-render the preview.

        An invalid edit only produces an error notification. The preview keeps
        the last valid theme.
        """
        cleaned = strip_strings(event.new)
        if not isinstance(cleaned, dict):
            self._notify_error("Theme must be a JSON object")
            return
        errors = validate_theme_json(cleaned)
        if errors:
            self._notify_error("; ".join(errors))
            return
        try:
            # Bokeh's Theme performs its own structural checks (e.g. each
            # "attrs" entry must be a dict) and raises ValueError on failure.
            theme = Theme(json=cleaned)
        except ValueError as exc:
            self._notify_error(str(exc))
            return
        self._last_valid_theme = theme
        self._render()

    @param.depends("visible_charts", watch=True)
    def _on_charts_changed(self):
        """Re-render when the chart selection changes, once a valid theme exists."""
        if self._last_valid_theme is not None:
            self._render()

    def _render(self):
        """Rebuild the selected charts and show them with the last valid theme.

        Each call creates new chart objects and a new HoloViews pane that
        carries the theme, then swaps it into the placeholder.
        """
        plots = [
            CHART_BUILDERS[name]()
            for name in self.visible_charts
            if name in CHART_BUILDERS
        ]
        if not plots:
            self._chart_placeholder.update(pn.pane.Markdown("*No charts selected*"))
            return
        layout = hv.Layout(plots).cols(4).opts(shared_axes=False, toolbar=None)
        self._chart_placeholder.update(
            pn.pane.HoloViews(
                layout,
                theme=self._last_valid_theme,
                sizing_mode="stretch_both",
            )
        )

    def __panel__(self):
        """Lay out the controls sidebar (max 450 px wide) next to the chart preview."""
        return pmui.Row(
            pmui.Column(
                pmui.Row(
                    self._theme_widget,
                    self._load_btn,
                    styles={"align-items": "stretch"},
                ),
                self._json_button_group,
                self._charts_widget,
                self._text_input,
                self._expand_btn,
                self._json_editor,
                self._download_btn,
                max_width=450,
                sizing_mode="stretch_both",
            ),
            pmui.Column(self._chart_placeholder, sizing_mode="stretch_both"),
        )


# Build the theme builder and serve it with Panel's .show().
app = ThemeSwitcher()
app.show()
1 Like