Panel app performance question

I have a Docker image that serves multiple Panel apps, and I am trying to make them as responsive as they are when run interactively in JupyterLab. I profiled with pyinstrument and have tried several approaches to multiprocessing (ProcessPoolExecutor, asyncio, concurrent.futures), all of which run into well-documented issues (a Panel issue pointing to a Bokeh issue).
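
For reference, the startup was timed in JupyterLab roughly like this (a simplified sketch rather than my exact notebook code; the app class itself is shown further down):

from pyinstrument import Profiler

profiler = Profiler()
profiler.start()
app = GeopotentialHeightAndRelativeHumidity()  # builds widgets, loads the dataset, fills the figure cache
profiler.stop()
print(profiler.output_text(unicode=True))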

My application loads an xarray.Dataset and populates an in-memory cache dict with a Matplotlib figure for each time step along the Dataset's time dimension. This takes 30-60 seconds even after switching the Matplotlib backend to Agg, and it can drop to roughly 8 seconds in JupyterLab when I use multiprocessing. Even if I could get multiprocessing working under Panel, it still seems that a large portion of the time before the app is usable is spent in these “hidden layers” of panel/base/template.
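
The multiprocessing variant that gets down to ~8 seconds in JupyterLab looks roughly like this (a simplified sketch; render_forecast is a hypothetical stand-in for the plotting code in _matplotlib below, returning rendered PNG bytes so results can be pickled back from the workers):

from concurrent.futures import ProcessPoolExecutor
import io

import matplotlib
matplotlib.use("Agg")  # non-interactive backend for the worker processes
import matplotlib.pyplot as plt


def render_forecast(forecast_hour):
    # Stand-in for _matplotlib: build the figure for one forecast hour
    fig, ax = plt.subplots(figsize=(24, 24))
    # ... contourf / coastlines / title as in _matplotlib ...
    buf = io.BytesIO()
    fig.savefig(buf, format="png")
    plt.close(fig)
    return forecast_hour, buf.getvalue()


def populate_cache_parallel(forecast_hours, max_workers=8):
    cache = {}
    with ProcessPoolExecutor(max_workers=max_workers) as pool:
        for hour, png in pool.map(render_forecast, forecast_hours):
            cache[hour] = png
    return cache

Under panel serve this is the pattern that runs into the Panel/Bokeh issues linked above.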

Here is what the profiler shows me:

The Panel server is running on the following versions:

Python: 3.11.11
Panel: 1.5.5
Bokeh: 3.6.2
Param: 2.2.0

Here is my docker-compose.yml:

version: '0.1'

services:
  panel-server:
    container_name: panel-verif-apps
    build: '.'
    restart: always
    ports:
      - "12900:12900"
    volumes:
      - /home:/home 
      - /fs:/fs
    environment:
      - AFSISIO=/home/smco502  
      - ESMFMKFILE=/opt/conda/lib/esmf.mk
    entrypoint: bash
    command: -c '
      panel serve
        --port 12900
        --ico-path path/favicon.ico
        --index /apps/index.html
        --static-dirs assets=./apps/static
        --log-level trace
        --admin
        --log-file path/panel.log
        --reuse-sessions
        --global-loading-spinner
        --profiler "pyinstrument"
        --num-threads 4
        --num-procs 5
        --unused-session-lifetime 60000
        --allow-websocket-origin="*"
        /apps/*.ipynb'

and here is my Panel app class:

class GeopotentialHeightAndRelativeHumidity(pn.viewable.Viewer):
    def __init__(self):
        self.run_options = self.extract_unique_sorted_model_runs(
            Path(CONFIG["GraphCast13Lvl_InitGDPS"]["model_base_path"] + "/pres")
        )
        self.load_plot_button = pn.widgets.Button(
            name="Load Dataset", button_type="success"
        )

        self.model_sel = pn.widgets.RadioButtonGroup(
            name="Model",
            options=sorted(MODELS.keys()),
            value="GraphCast13Lvl_InitGDPS",
            orientation="vertical",
            button_style="outline",
            button_type="success",
        )

        self.domain_sel = pn.widgets.RadioButtonGroup(
            name="Region",
            options=DOMAINS,
            # value=['Canada'],
            orientation="vertical",
            button_style="outline",
            button_type="success",
        )

        self.level_sel = pn.widgets.RadioButtonGroup(
            name="Level",
            options=CONFIG["levels"],
            value="1000",
            orientation="vertical",
            button_style="outline",
            button_type="success",
        )

        self.run_sel = pn.widgets.Select(
            name="Model Run",
            options=self.run_options,
            value=self.run_options[-1],
        )

        self.load_time = None

        self.ds = self._init_ds_loader(self)

        self.crs = get_crs(self.ds)

        self.forecast_widg = pn.widgets.DiscretePlayer(
            name="Forecast",
            value=self.closest_forecast_in_dataset(),
            options=list(self.ds.forecast.values),
            interval=1500,
            show_loop_controls=False,
            loop_policy="loop",
            # show_value=False
        )

        self.cache = {}
        self.populate_cache(self.ds)  # blocking: renders one figure per forecast hour (the 30-60 s step)
        self.col = pn.Column("Init")
        self.col[0] = pn.pane.Matplotlib(
            self.cache[self.forecast_widg.value],
            format="svg",
            tight=True,
            fixed_aspect=True,
            sizing_mode="stretch_both"
        )
        self._display(self)

        self.forecast_widg.param.watch(self._display, "value")
        self.load_plot_button.on_click(self._load_model)

    @staticmethod
    def extract_unique_sorted_model_runs(directory: Path) -> List[str]:
        if not directory.is_dir():
            raise ValueError(f"The provided path '{directory}' is not a directory.")

        # Dictionary to store counts of files for each model run
        model_run_file_counts = defaultdict(int)

        # Populate the dictionary with counts
        for file in directory.glob("2*"):
            model_run = file.name[:10]
            model_run_file_counts[model_run] += 1

        # Extract unique model runs
        unique_model_runs = sorted(model_run_file_counts.keys())

        # Ensure all model runs have the same number of files
        file_counts = set(model_run_file_counts.values())
        if len(file_counts) > 1:
            raise ValueError(
                f"Inconsistent file counts across model runs: {model_run_file_counts}"
            )

        return unique_model_runs

    @staticmethod
    def get_key_from_value(dictionary: dict, value):
        values = list(dictionary.values())
        keys = list(dictionary.keys())
        return keys[values.index(value)]

    def populate_cache(self, ds: xr.Dataset):
        # pn.state.busy = True
        self.cache = {}
        
        start = time.time()
        pn.state.notifications.info(f"Plotting and caching", duration=2000)
        for forecast_hour in ds.forecast.values:
            # pn.state.notifications.info(f"Plotting and caching forecast hour {int((forecast_hour / np.timedelta64(1, 'h')).item())}", duration=1000)
            self.cache[forecast_hour] = self._matplotlib(
                self.domain_sel.value, forecast_hour
            )
        pn.state.notifications.info(f"Finished plotting and caching", duration=2000)
        end = time.time()
        self.load_time = end - start
        # pn.state.busy = False
        
    def closest_forecast_in_dataset(self) -> np.timedelta64:
        if "time" not in self.ds.coords or "forecast" not in self.ds.coords:
            raise ValueError("Dataset must contain 'time' and 'forecast' dimensions.")

        ref_time = self.ds["time"].values
        if not isinstance(ref_time, np.datetime64):
            raise ValueError("'time' coordinate must be of type np.datetime64.")

        forecast_hours = self.ds["forecast"].values
        if not np.issubdtype(forecast_hours.dtype, np.timedelta64):
            raise ValueError("'forecast' coordinate must be of type np.timedelta64.")

        current_time = np.datetime64(datetime.now(timezone.utc))

        forecast_abs_times = ref_time + forecast_hours
        differences = np.abs(forecast_abs_times - current_time)

        closest_index = np.argmin(differences)

        return forecast_hours[closest_index]

    def _init_ds_loader(self, event: str):
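        # Same loading logic as _load_model below, but returns the dataset; called once from __init__.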
        files = utils.get_files_to_process(
            CONFIG[self.model_sel.value]["model_base_path"] + "/pres",
            self.run_sel.value,
        )

        # level = utils.convert_ai_model_pres_ip1s(level, self.model_sel.value)
        ds = fstd2nc.Buffer(
            files,
            vars=["HR", "GZ"],
            filter=[f"ip1=={int(self.level_sel.value)}"],
            forecast_axis=True,
        ).to_xarray()

        self.crs = get_crs(ds)

        if self.model_sel.value == "GDPS" and "forecast1" in ds.coords:
            forecasts = np.intersect1d(ds.forecast1.values, ds.forecast2.values)
            ds = ds.sel(forecast1=forecasts, forecast2=forecasts)
            hr_ds = ds.HR.copy(deep=True).rename({"forecast2": "forecast"})
            gz_ds = ds.GZ.copy(deep=True).rename({"forecast1": "forecast"})
            ds = xr.merge([hr_ds, gz_ds])
            ds = ds.assign_coords(lon=(ds.lon - 180))
            return ds.squeeze()
        else:
            ds = ds.assign_coords(lon=(ds.lon - 180))
            return ds.squeeze()

    def _load_model(self, event: str):
        pn.state.notifications.info(f"Loading {self.model_sel.value}", duration=2000)
        files = utils.get_files_to_process(
            CONFIG[self.model_sel.value]["model_base_path"] + "/pres",
            self.run_sel.value,
        )

        # level = utils.convert_ai_model_pres_ip1s(level, self.model_sel.value)
        ds = fstd2nc.Buffer(
            files,
            vars=["HR", "GZ"],
            filter=[f"ip1=={int(self.level_sel.value)}"],
            forecast_axis=True,
        ).to_xarray()

        self.crs = get_crs(ds)

        if self.model_sel.value == "GDPS" and "forecast1" in ds.coords:
            forecasts = np.intersect1d(ds.forecast1.values, ds.forecast2.values)
            ds = ds.sel(forecast1=forecasts, forecast2=forecasts)
            hr_ds = ds.HR.copy(deep=True).rename({"forecast2": "forecast"})
            gz_ds = ds.GZ.copy(deep=True).rename({"forecast1": "forecast"})
            ds = xr.merge([hr_ds, gz_ds])
            ds = ds.assign_coords(lon=(ds.lon - 180))
            self.ds = ds.squeeze()

        elif self.model_sel.value == "GDPS_SN":
            ds = ds.roll(lon=1200, roll_coords=True)
            ds = ds.assign_coords(lon=(ds.lon - 180))
            self.ds = ds.squeeze()
            self._display(self)
        else:
            ds = ds.assign_coords(lon=(ds.lon - 180))
            self.ds = ds.squeeze()
        pn.state.notifications.success(
            f"Finished Loading\n{self.model_sel.value}", duration=4000
        )
        self.populate_cache(self.ds)
        self.forecast_widg.options = list(self.ds.forecast.values)
        self.forecast_widg.value = self.closest_forecast_in_dataset()
        self._display(self)

    def _add_features(self, ax: plt.Axes):
        ax.add_feature(cf.COASTLINE, linewidths=0.1)
        ax.add_feature(cf.LAKES, edgecolor="black", linewidths=0.1)
        ax.add_feature(cf.OCEAN, edgecolor="black", linewidths=0.1)
        ax.add_feature(cf.BORDERS, linewidths=0.1)
        ax.add_feature(cf.STATES, linewidths=0.1)
        gl = ax.gridlines(
            linewidth=0.1, color="black", draw_labels=True, linestyle="--"
        )
        gl.top_labels = gl.right_labels = False
        gl.xformatter = gl.yformatter = plt.FormatStrFormatter("%d°")
        gl.xlabel_style = gl.ylabel_style = {"size": 6}

    def _matplotlib(
        self, location: List[float], forecast_value, **kwargs
    ) -> plt.Figure:
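        # Builds the figure for a single forecast hour; populate_cache calls this once per time step.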
        temp_ds = self.ds.sel(lon=slice(*location[:2]), lat=slice(*location[-2:]))
        temp_ds.HR.values = temp_ds.HR.values * 100
        fig, ax = plt.subplots(figsize=(24, 24), subplot_kw=dict(projection=self.crs))
        utils.set_limits(ax, self.crs, location)
        self._add_features(ax)

        lower_bound = 65
        if self.level_sel.value == "250":
            lower_bound = 30
        if self.level_sel.value == "500":
            lower_bound = 50

        hr = (
            temp_ds.HR.where(temp_ds.HR > lower_bound)
            .sel(forecast=forecast_value)
            .plot.contourf(
                ax=ax,
                colors=["green", "cyan", "blue"],
                transform=self.crs,
                levels=utils.get_hr_scale(self.level_sel.value),
                linewidth=2,
                zorder=2,
                alpha=0.7,
            )
        )
        gz = utils.plot_gz(
            ax=ax,
            ds=temp_ds,
            crs=self.crs,
            forecast_index=temp_ds.forecast.values.tolist().index(forecast_value),
            pressure_level=int(self.level_sel.value),
            zorder=3,
        )

        title = utils.create_title(
            self.model_sel.value,
            self.get_key_from_value(DOMAINS, location),
            f"Humidité Relative (%) + Hauteur Geopotentielle {self.level_sel.value}mb:",
            f"Relative Humidity (%) + Geopotential Height {self.level_sel.value}mb:",
            self.run_sel.value,
            np.timedelta64(forecast_value, "h"),
        )
        plt.title(title, fontsize=15)

        if self.domain_sel.value == "West" or self.domain_sel.value == "East":
            colorbar = plt.colorbar(
                hr,
                ax=ax,
                fraction=0.05,  # Adjust the height of the colorbar relative to the plot
                pad=0.02,  # Add a small padding between the plot and colorbar
                anchor=(
                    1.0,
                    0.55,
                ),  # Adjust this to shift the colorbar slightly above the bottom
                location="right",  # Ensures the colorbar stays aligned with the plot's right edge
            )
            colorbar.set_label("Relative Humidity (%)", fontsize=20)
        else:
            colorbar = plt.colorbar(
                hr,
                ax=ax,
                fraction=0.025,  # Adjust the height of the colorbar relative to the plot
                pad=0.02,  # Add a small padding between the plot and colorbar
                anchor=(
                    1.0,
                    0.55,
                ),  # Adjust this to shift the colorbar slightly above the bottom
                location="right",  # Ensures the colorbar stays aligned with the plot's right edge
            )
            colorbar.set_label("Relative Humidity (%)", fontsize=12)
        colorbar.ax.tick_params(labelsize=12)

        fig.delaxes(fig.axes[1])

        plt.close(fig)
        return fig

    # def get_load_time(self):
    #     return pn.pane.Str(self.load_time)

    def _display(self, event: str):
        self.col[0].object = self.cache[self.forecast_widg.value]


app = GeopotentialHeightAndRelativeHumidity()
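
For completeness, the last notebook cell marks the instance servable so panel serve picks it up (this assumes the Viewer subclass also defines __panel__, which is omitted from the snippet above):

app.servable()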