I have a Docker image that serves multiple Panel apps, and I am trying to speed them up to match how they run interactively in JupyterLab. I used the pyinstrument profiler and have tried multiple approaches to multiprocessing (ProcessPoolExecutor, asyncio, concurrent.futures), all of which have well-documented issues (see the Panel issue pointing to a Bokeh issue).
My application loads an xarray.Dataset
and populates a cache dict in memory with Matplotlib figures, one for each time step in the Dataset’s time dimension. This takes 30-60 seconds even after switching the Matplotlib backend to AGG,
and can go as low as 8 seconds in JupyterLab when I use multiprocessing. Even if I could make multiprocessing work on the server, it still seems that a large portion of my time to action is spent in these “hidden layers” of panel/base/template.
Here is what the profiler shows me
Panel Server running on following versions:
Python: 3.11.11
Panel: 1.5.5
Bokeh: 3.6.2
Param: 2.2.0
Here is my docker-compose.yml
# The top-level `version` key is obsolete in the Compose Specification and
# '0.1' was never a valid schema version, so it is intentionally omitted.
services:
  panel-server:
    container_name: panel-verif-apps
    build: '.'
    restart: always
    ports:
      - "12900:12900"
    volumes:
      - /home:/home
      - /fs:/fs
    environment:
      - AFSISIO=/home/smco502
      - ESMFMKFILE=/opt/conda/lib/esmf.mk
    entrypoint: bash
    # NOTE(review): `--profiler`, `--admin` and `--log-level trace` are
    # diagnostic switches; they presumably add overhead of their own, so
    # consider dropping them when comparing timings against JupyterLab.
    command: -c '
      panel serve
      --port 12900
      --ico-path path/favicon.ico
      --index /apps/index.html
      --static-dirs assets=./apps/static
      --log-level trace
      --admin
      --log-file path/panel.log
      --reuse-sessions
      --global-loading-spinner
      --profiler "pyinstrument"
      --num-threads 4
      --num-procs 5
      --unused-session-lifetime 60000
      --allow-websocket-origin="*"
      /apps/*.ipynb'
and my panel app class
class GeopotentialHeightAndRelativeHumidity(pn.viewable.Viewer):
    """Panel viewer for relative-humidity / geopotential-height forecast maps.

    On construction (and whenever the "Load Dataset" button is clicked) the
    selected model run is opened as an ``xarray.Dataset`` and one Matplotlib
    figure per forecast step is rendered into ``self.cache``.  Moving the
    forecast player afterwards only swaps the cached figure shown in
    ``self.col``.
    """

    # Model selected when the app starts; its directory also provides the
    # list of available model runs.
    _DEFAULT_MODEL = "GraphCast13Lvl_InitGDPS"

    # Relative-humidity values at or below this threshold are masked before
    # contouring; pressure levels not listed here use the default.
    _HR_LOWER_BOUND_BY_LEVEL = {"250": 30, "500": 50}
    _DEFAULT_HR_LOWER_BOUND = 65

    # Domain names that get the wider colorbar with the larger label.
    _WIDE_COLORBAR_DOMAINS = ("West", "East")

    def __init__(self, **params):
        # Viewer is a param.Parameterized subclass; it must be initialised
        # before the instance is used.
        super().__init__(**params)

        self.run_options = self.extract_unique_sorted_model_runs(
            Path(CONFIG[self._DEFAULT_MODEL]["model_base_path"] + "/pres")
        )
        if not self.run_options:
            # Fail with a clear message instead of an IndexError on [-1] below.
            raise ValueError(
                f"No model runs found for '{self._DEFAULT_MODEL}'."
            )

        self.load_plot_button = pn.widgets.Button(
            name="Load Dataset", button_type="success"
        )
        self.model_sel = pn.widgets.RadioButtonGroup(
            name="Model",
            options=sorted(MODELS.keys()),
            value=self._DEFAULT_MODEL,
            orientation="vertical",
            button_style="outline",
            button_type="success",
        )
        self.domain_sel = pn.widgets.RadioButtonGroup(
            name="Region",
            options=DOMAINS,
            orientation="vertical",
            button_style="outline",
            button_type="success",
        )
        self.level_sel = pn.widgets.RadioButtonGroup(
            name="Level",
            options=CONFIG["levels"],
            value="1000",
            orientation="vertical",
            button_style="outline",
            button_type="success",
        )
        self.run_sel = pn.widgets.Select(
            name="Model Run",
            options=self.run_options,
            value=self.run_options[-1],
        )

        # Seconds spent in the most recent populate_cache() call.
        self.load_time = None

        self.ds = self._init_ds_loader(self)
        # At start-up the CRS is derived from the processed dataset.
        self.crs = get_crs(self.ds)

        self.forecast_widg = pn.widgets.DiscretePlayer(
            name="Forecast",
            value=self.closest_forecast_in_dataset(),
            options=list(self.ds.forecast.values),
            interval=1500,
            show_loop_controls=False,
            loop_policy="loop",
        )

        self.cache = {}
        self.populate_cache(self.ds)

        self.col = pn.Column(
            pn.pane.Matplotlib(
                self.cache[self.forecast_widg.value],
                format="svg",
                tight=True,
                fixed_aspect=True,
                sizing_mode="stretch_both",
            )
        )

        self.forecast_widg.param.watch(self._display, "value")
        self.load_plot_button.on_click(self._load_model)

    @staticmethod
    def extract_unique_sorted_model_runs(directory: Path) -> List[str]:
        """Return the sorted, unique model-run identifiers found in *directory*.

        A model run is identified by the first 10 characters of every entry
        whose name starts with ``"2"``.

        Args:
            directory: Directory holding the model output files.

        Returns:
            Sorted list of unique model-run identifiers (possibly empty).

        Raises:
            ValueError: If *directory* is not a directory, or if the runs do
                not all have the same number of files.
        """
        if not directory.is_dir():
            raise ValueError(f"The provided path '{directory}' is not a directory.")

        model_run_file_counts = defaultdict(int)
        for file in directory.glob("2*"):
            model_run_file_counts[file.name[:10]] += 1

        # Every run is expected to contribute the same number of files.
        if len(set(model_run_file_counts.values())) > 1:
            raise ValueError(
                f"Inconsistent file counts across model runs: {model_run_file_counts}"
            )
        return sorted(model_run_file_counts)

    @staticmethod
    def get_key_from_value(dictionnary: dict, value):
        """Return the first key of *dictionnary* whose value equals *value*.

        Raises:
            ValueError: If no entry has that value.
        """
        for key, candidate in dictionnary.items():
            if candidate is value or candidate == value:
                return key
        raise ValueError(f"{value!r} is not in the dictionary values")

    @staticmethod
    def _notify(message: str, duration: int, kind: str = "info") -> None:
        """Show a toast notification, doing nothing if notifications are off."""
        notifications = pn.state.notifications
        if notifications is None:
            return
        getattr(notifications, kind)(message, duration=duration)

    def populate_cache(self, ds: xr.Dataset):
        """Render one figure per forecast step of *ds* into ``self.cache``.

        The new cache replaces the old one only once every figure has been
        rendered.  The elapsed wall time is stored in ``self.load_time``.

        Args:
            ds: Dataset whose ``forecast`` coordinate lists the steps to plot.
                The figures themselves are drawn from ``self.ds``.
        """
        start = time.perf_counter()
        self._notify("Plotting and caching", duration=2000)
        location = self.domain_sel.value
        self.cache = {
            forecast_hour: self._matplotlib(location, forecast_hour)
            for forecast_hour in ds.forecast.values
        }
        self._notify("Finished plotting and caching", duration=2000)
        self.load_time = time.perf_counter() - start

    def closest_forecast_in_dataset(self) -> np.timedelta64:
        """Return the forecast lead time whose valid time is closest to now (UTC).

        Raises:
            ValueError: If the dataset lacks the ``time`` / ``forecast``
                coordinates or they do not have the expected types.
        """
        if "time" not in self.ds.coords or "forecast" not in self.ds.coords:
            raise ValueError("Dataset must contain 'time' and 'forecast' dimensions.")

        ref_time = self.ds["time"].values
        if not isinstance(ref_time, np.datetime64):
            raise ValueError("'time' coordinate must be of type np.datetime64.")

        forecast_hours = self.ds["forecast"].values
        if not np.issubdtype(forecast_hours.dtype, np.timedelta64):
            raise ValueError("'forecast' coordinate must be of type np.timedelta64.")

        # np.datetime64 has no timezone support: take the current UTC time
        # and drop tzinfo instead of passing a timezone-aware datetime.
        current_time = np.datetime64(
            datetime.now(timezone.utc).replace(tzinfo=None)
        )
        differences = np.abs(ref_time + forecast_hours - current_time)
        return forecast_hours[np.argmin(differences)]

    def _open_dataset(self) -> xr.Dataset:
        """Open the selected model / run / level and normalise it for plotting.

        Also sets ``self.crs`` from the dataset as read, before any
        coordinate adjustment.

        Returns:
            The squeezed dataset with longitudes shifted by -180.
        """
        model = self.model_sel.value
        files = utils.get_files_to_process(
            CONFIG[model]["model_base_path"] + "/pres",
            self.run_sel.value,
        )
        ds = fstd2nc.Buffer(
            files,
            vars=["HR", "GZ"],
            filter=[f"ip1=={int(self.level_sel.value)}"],
            forecast_axis=True,
        ).to_xarray()
        self.crs = get_crs(ds)

        if model == "GDPS" and "forecast1" in ds.coords:
            # HR and GZ sit on separate forecast axes here: keep the steps
            # common to both and merge them onto a single "forecast" axis.
            forecasts = np.intersect1d(ds.forecast1.values, ds.forecast2.values)
            ds = ds.sel(forecast1=forecasts, forecast2=forecasts)
            hr_ds = ds.HR.copy(deep=True).rename({"forecast2": "forecast"})
            gz_ds = ds.GZ.copy(deep=True).rename({"forecast1": "forecast"})
            ds = xr.merge([hr_ds, gz_ds])
        elif model == "GDPS_SN":
            ds = ds.roll(lon=1200, roll_coords=True)

        ds = ds.assign_coords(lon=(ds.lon - 180))
        return ds.squeeze()

    def _init_ds_loader(self, event: object = None):
        """Return the dataset for the current widget selection.

        Args:
            event: Unused; kept so the method can be called like a callback.
        """
        return self._open_dataset()

    def _load_model(self, event: object = None):
        """Button callback: reload the dataset, rebuild the cache, refresh the view.

        Args:
            event: Click event supplied by Panel (unused).
        """
        self._notify(f"Loading {self.model_sel.value}", duration=2000)
        self.ds = self._open_dataset()
        self._notify(
            f"Finished Loading\n{self.model_sel.value}",
            duration=4000,
            kind="success",
        )
        # Rebuild the figures before touching the player so that any watcher
        # fired by the widget updates finds the new forecast steps cached.
        self.populate_cache(self.ds)
        self.forecast_widg.options = list(self.ds.forecast.values)
        self.forecast_widg.value = self.closest_forecast_in_dataset()
        # Refresh explicitly in case the player value did not change.
        self._display(self)

    def _add_features(self, ax: plt.Axes):
        """Draw the map features and labelled gridlines on *ax*."""
        ax.add_feature(cf.COASTLINE, linewidths=0.1)
        for outlined in (cf.LAKES, cf.OCEAN):
            ax.add_feature(outlined, edgecolor="black", linewidths=0.1)
        for boundary in (cf.BORDERS, cf.STATES):
            ax.add_feature(boundary, linewidths=0.1)

        gl = ax.gridlines(
            linewidth=0.1, color="black", draw_labels=True, linestyle="--"
        )
        gl.top_labels = gl.right_labels = False
        gl.xformatter = gl.yformatter = plt.FormatStrFormatter("%d°")
        gl.xlabel_style = gl.ylabel_style = {"size": 6}

    def _matplotlib(
        self, location: List[float], forecast_value, **kwargs
    ) -> plt.Figure:
        """Render the HR + GZ map of one forecast step over one domain.

        Args:
            location: Domain bounds; the first two entries slice ``lon`` and
                the last two slice ``lat``.  Must be one of the values of
                ``DOMAINS``.
            forecast_value: Entry of the dataset's ``forecast`` coordinate.
            **kwargs: Accepted for interface compatibility; unused.

        Returns:
            The figure, already closed in pyplot so it is not kept alive by
            pyplot's figure registry.
        """
        level = self.level_sel.value
        domain_name = self.get_key_from_value(DOMAINS, location)

        temp_ds = self.ds.sel(lon=slice(*location[:2]), lat=slice(*location[-2:]))
        # Scale HR by 100 (presumably fraction -> percent, matching the "%"
        # in the title; confirm against the data source).
        temp_ds.HR.values = temp_ds.HR.values * 100

        fig, ax = plt.subplots(figsize=(24, 24), subplot_kw=dict(projection=self.crs))
        utils.set_limits(ax, self.crs, location)
        self._add_features(ax)

        lower_bound = self._HR_LOWER_BOUND_BY_LEVEL.get(
            level, self._DEFAULT_HR_LOWER_BOUND
        )
        hr = (
            temp_ds.HR.where(temp_ds.HR > lower_bound)
            .sel(forecast=forecast_value)
            .plot.contourf(
                ax=ax,
                colors=["green", "cyan", "blue"],
                transform=self.crs,
                levels=utils.get_hr_scale(level),
                zorder=2,
                alpha=0.7,
            )
        )
        utils.plot_gz(
            ax=ax,
            ds=temp_ds,
            crs=self.crs,
            forecast_index=temp_ds.forecast.values.tolist().index(forecast_value),
            pressure_level=int(level),
            zorder=3,
        )

        title = utils.create_title(
            self.model_sel.value,
            domain_name,
            f"Humidité Relative (%) + Hauteur Geopotentielle {level}mb:",
            f"Relative Humidity (%) + Geopotential Height {level}mb:",
            self.run_sel.value,
            np.timedelta64(forecast_value, "h"),
        )
        plt.title(title, fontsize=15)

        # The wide layout is chosen by domain *name*; the widget value is the
        # bounds sequence, so comparing it with a string never matches.
        wide = domain_name in self._WIDE_COLORBAR_DOMAINS
        colorbar = plt.colorbar(
            hr,
            ax=ax,
            fraction=0.05 if wide else 0.025,  # colorbar size relative to the plot
            pad=0.02,  # small gap between the plot and the colorbar
            anchor=(1.0, 0.55),  # shift the colorbar slightly above the bottom
            location="right",  # keep it aligned with the plot's right edge
        )
        colorbar.set_label("Relative Humidity (%)", fontsize=20 if wide else 12)
        colorbar.ax.tick_params(labelsize=12)

        # NOTE(review): fig.axes[1] is presumably the colorbar xarray adds
        # automatically in contourf; it is dropped in favour of the custom
        # one created above -- confirm.
        fig.delaxes(fig.axes[1])
        plt.close(fig)
        return fig

    def _display(self, event: object = None):
        """Show the cached figure for the player's current forecast step.

        Args:
            event: Watcher event supplied by Panel (unused).
        """
        self.col[0].object = self.cache[self.forecast_widg.value]
# Instantiate the app at module level; the constructor loads the dataset and
# pre-renders the per-forecast figure cache.
app = GeopotentialHeightAndRelativeHumidity()