Inspired by https://twitter.com/marlene_zw/status/1509882018968047618, I would like to try creating a small data app for images using Panel.
"""Inspired by https://twitter.com/marlene_zw/status/1509882018968047618"""
from io import BytesIO
from typing import List
import holoviews as hv
import numpy as np
import panel as pn
import PIL
import torch
from PIL import Image
# Initialise Panel: components stretch to the available width by default and the
# app is rendered inside the "fast" template.
pn.extension(sizing_mode="stretch_width", template="fast")
# Brand colour, reused for both the template accent and its header background.
ACCENT = "#974794"
# Configure the app's template (the one selected via ``template="fast"`` above).
# param.update sets all of these parameters in a single call.
pn.state.template.param.update(
    site="Awesome Panel",
    title="bryandlee/animegan2-pytorch",
    accent_base_color=ACCENT,
    header_background=ACCENT,
)
# Run inference on the GPU when torch can see one; otherwise fall back to the CPU.
if torch.cuda.is_available():
    DEVICE = "cuda"
else:
    DEVICE = "cpu"
# Source: https://github.com/bryandlee/animegan2-pytorch
# Load the "generator" entry point from the hub repo and put the network into
# evaluation mode (eval() mutates the module in place and returns it).
model = torch.hub.load("bryandlee/animegan2-pytorch:main", "generator", device=DEVICE)
model.eval()
# The "face2paint" entry point yields a helper that is later invoked as
# face2paint(model, image, side_by_side=...) to stylise a single PIL image.
face2paint = torch.hub.load("bryandlee/animegan2-pytorch:main", "face2paint", device=DEVICE)
# File picker for one or more portraits. The HTML ``accept`` attribute expects
# MIME types or extensions with a leading dot; bare names such as "png" are not
# valid tokens, so the browser would not filter the file dialog at all.
file_update = pn.widgets.FileInput(accept=".png,.jpg,.jpeg", multiple=True)
# Spinner (value=True means "spinning") that stays hidden until processing starts.
loading_indicator = pn.indicators.LoadingSpinner(
    visible=False, value=True, height=25, width=25, sizing_mode="fixed"
)
# Container that receives one input/output comparison per uploaded image.
output = pn.Column()
layout = pn.Column(
    "## Input: .png or .jpg portrait",
    file_update,
    "## Output: Transformed portraits",
    output,
    loading_indicator,
).servable()
def bokeh_hook(plot, element):  # pylint: disable=unused-argument
    """Hide tick marks and tick labels on both axes of the rendered plot.

    Meant to be passed through HoloViews' ``hooks`` plot option. ``plot.state``
    is the underlying Bokeh figure; ``element`` is part of the hook signature
    but is not needed here.
    """
    figure = plot.state
    for axis in (figure.xaxis, figure.yaxis):
        axis.major_tick_line_color = None
        axis.minor_tick_line_color = None
        # A zero-size font effectively hides the tick labels.
        axis.major_label_text_font_size = "0pt"
def to_hv_rgb(image: Image.Image) -> hv.RGB:
    """Wrap a PIL image in a HoloViews RGB element styled for this app.

    Args:
        image: Image to display. It is converted with ``np.array`` and the
            resulting array is handed straight to ``hv.RGB``.

    Returns:
        An ``hv.RGB`` element with axis labels and ticks hidden, a reduced
        toolbar, a minimum height of 500 and responsive sizing.
    """
    # Annotation fixed: ``PIL.Image`` is the module; the image class is ``Image.Image``.
    image_array = np.array(image)
    return hv.RGB(image_array).opts(
        labelled=[],  # empty list: label neither the x nor the y axis
        hooks=[bokeh_hook],  # strip tick marks and tick labels from the Bokeh axes
        default_tools=["save", "pan", "wheel_zoom", "box_zoom", "reset"],
        min_height=500,
        responsive=True,
    )
def transform_and_show(uploaded: List[bytes]):
    """Stylise every uploaded image and show each original next to its result.

    Clears ``output``, then appends one HoloViews layout per uploaded file
    containing the input image and its ``face2paint`` output. The loading
    spinner is shown while this runs and is always hidden again afterwards,
    even if decoding or inference raises.

    Args:
        uploaded: Raw file contents from the ``FileInput`` widget. ``None``
            (no files, e.g. after the widget is cleared) just clears the
            output; a single ``bytes`` object is treated as a one-file upload.
    """
    loading_indicator.visible = True
    try:
        output.clear()
        if uploaded is None:
            return
        if isinstance(uploaded, bytes):
            # Non-multiple FileInput yields bytes; iterating it would give ints.
            uploaded = [uploaded]
        for bytes_in in uploaded:
            # Normalise to 3-channel RGB (presumably what the model expects).
            im_in = Image.open(BytesIO(bytes_in)).convert("RGB")
            im_out = face2paint(model, im_in, side_by_side=False)
            # HoloViews ``+`` builds a Layout, placing input and output side by side.
            output.append(to_hv_rgb(im_in) + to_hv_rgb(im_out))
    finally:
        # Without this, any exception above would leave the spinner visible forever.
        loading_indicator.visible = False
# Call transform_and_show with the new file_update.value whenever the upload
# changes. watch=True makes the binding run eagerly as a side effect instead of
# producing a reactive function that would need to be placed in a layout.
pn.bind(transform_and_show, file_update.param.value, watch=True)