Memory leack when plotting xarray dataset

Hi,

I just added a comment on a GitHub issue:

as I am experiencing a similar behavior, memory is not released even when the session is discarded.
I am posting the issue here as well as the problem may be different, probably on wrong usage of the on_session_destroyed() method.

It looks a bit hackish, but I tried to reduce the code a MRE, attached below:


import hvplot.xarray
import xarray as xr
import panel as pn
import holoviews as hv
import sys
from bokeh.layouts import column, Spacer
import gc

pn.param.ParamMethod.loading_indicator = True
ds = None

def on_server_loaded():
    print("server loaded")
    print("")
    sys.stdout.flush()
    
def on_session_created(session_context):
    print("session created")
    print("")
    sys.stdout.flush()

def on_session_destroyed(session_context):
    print("session destroyed")
    print("")
    print(dir(session_context))
    try:
        del ds
        gc.collect()
    except UnboundLocalError:
        pass
    try:
        del plot_widget
        gc.collect()
    except UnboundLocalError:
        pass
    plot_widget = None
    gc.collect()
    sys.stdout.flush()
    
    
def load_data(url):
    try:
        del ds
        gc.collect()
    except UnboundLocalError:
        pass
    ds = None
    try:
        # attempt to load the dataset via xarray with decode_times=True
        ds = xr.open_dataset(str(url).strip())
    except ValueError as e:
        # attempt to load the dataset via xarray with decode_times=False
        ds = xr.open_dataset(str(url).strip(), decode_times=False)
        print(e)
    except OSError as e:
        print(e)
    if ds and not ds.coords:
        del ds
        gc.collect()
        # the following hack is used when the dataset is served as a tabledap via erdap
        erdapp_uglyness = list(dict(xr.open_dataset(url).dims).keys())[0]
        renamed_vars = {i:i.replace(erdapp_uglyness+".", "") for i in list(xr.open_dataset(url).variables.keys())}
        new_nc_url = url+'?'+'time,'+','.join(list(xr.open_dataset(url).variables)).replace(f"{erdapp_uglyness}.", "").replace(f"time,", "")
        ds = xr.open_dataset(new_nc_url)
        ds = ds.set_coords(f"{erdapp_uglyness}.time")
        ds = ds.swap_dims(s=f"time")
        ds = ds.set_xindex(f"{erdapp_uglyness}.time")
        ds = ds.rename_vars(renamed_vars)
    return ds


        
def plot(var, title=None):
    try:
        del plot_widget
        gc.collect()
    except UnboundLocalError:
        pass
    plot_widget = None
    try:
        del ds
        gc.collect()
    except UnboundLocalError:
        pass
    
    gc.collect()
    ds = load_data(url)
    var = var[0]
    print(f'plotting var: {var}')
    if not title:
        try:
            title = f"{ds[var].attrs['long_name']}"
        except KeyError:
            title = f"{var}"
    else:
        title=title
    if 'featureType' in ds.attrs:
        featureType = ds.attrs['featureType'].lower()
    elif 'cdm_data_type' in ds[var].attrs:
        featureType = ds[var].attrs['cdm_data_type'].lower()
    else:
        featureType = None
    is_monotonic = False
    if featureType == 'timeseries':
        axis_arguments = {'grid':True, 'title': title, 'responsive': True}
        plot_widget =  ds[var].where(ds[var] != 9.96921e36).hvplot.line(**axis_arguments)
        return plot_widget
    if featureType != "timeseries":
        axis_arguments = {'x': ds[var], 'grid':True, 'title': title, 'widget_location': 'bottom', 'responsive': True}
        try:
            plot_widget =  ds[var].where(ds[var] != 9.96921e36).hvplot.line(**axis_arguments)
        except TypeError:
            axis_arguments = {'grid':True, 'title': title, 'widget_location': 'bottom', 'responsive': True}
            plot_widget =  ds[var].where(ds[var] != 9.96921e36).hvplot.line(**axis_arguments)
        except ValueError:
            axis_arguments = {'x': var, 'grid':True, 'title': title, 'widget_location': 'bottom', 'responsive': True}
            plot_widget =  ds[var].where(ds[var] != 9.96921e36).hvplot.line(**axis_arguments)
        return plot_widget        


def on_var_select(event):
    var = event.obj.value
    result = [key for key, value in mapping_var_names.items() if value == var]
    with pn.param.set_values(main_app, loading=True):
        plot_container[-1] = plot(var=result, title=var)
        print(f'selected {result}')


def safe_check(var):
    try:
        ds[var].values
        return var
    except Exception as e:
        # Handle the exception (e.g., log it, return False, etc.)
        print(f"Error processing {var}: {e}")
        return False
    

pn.state.onload(callback=on_server_loaded)
# pn.state.on_session_created(callback=on_session_created)
pn.state.on_session_destroyed(callback=on_session_destroyed)
url = pn.state.session_args.get('url')[0].decode("utf8")

print("++++++++++++++++++++++++ LOADING ++++++++++++++++++++++++++++++++++++")
print(str(url))
print("++++++++++++++++++++++++ +++++++ ++++++++++++++++++++++++++++++++++++")
ds = load_data(url)

if ds:
    plottable_vars = [j for j in ds if len([value for value in list(ds[j].coords) if value in list(ds.dims)]) >= 1]
    plottable_vars = [i for i in plottable_vars if safe_check(i)]
    print("plottable_vars:", plottable_vars )
    mapping_var_names = {}
    for i in plottable_vars:
        if int(len(list(ds[i].coords)) != 0):
            try:
                title = f"{ds[i].attrs['long_name']} [{i}]"
            except KeyError:
                title = f"{i}"
            mapping_var_names[i] = title
            
    # add a select widget for variables, uses long names
    variables_selector = pn.widgets.Select(options=list(mapping_var_names.values()), name='Data Variable')
    variables_selector.param.watch(on_var_select, parameter_names=['value'])
    selected_var = [key for key, value in mapping_var_names.items() if value == variables_selector.value]
    plot_plot = plot(selected_var, title=variables_selector.value)
    plot_container = pn.Column(pn.Row(variables_selector), plot_plot, sizing_mode='scale_both') # , sizing_mode='scale_both'
    main_app = pn.Row(plot_container, Spacer(width=10))
    main_app.servable()

Thank you so much for any help!