AnimeGAN2 image transformer

Inspired by https://twitter.com/marlene_zw/status/1509882018968047618, I would like to try out creating a small data app for images using Panel.

"""Inspired by https://twitter.com/marlene_zw/status/1509882018968047618"""
from io import BytesIO
from typing import List

import holoviews as hv
import numpy as np
import panel as pn
import PIL
import torch
from PIL import Image

pn.extension(sizing_mode="stretch_width", template="fast")

ACCENT = "#974794"

# Brand the "fast" template selected in pn.extension above.
pn.state.template.param.update(
    site="Awesome Panel",
    title="bryandlee/animegan2-pytorch",
    accent_base_color=ACCENT,
    header_background=ACCENT,
)

# Run the model on the GPU when one is available, otherwise fall back to the CPU.
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

# Source: https://github.com/bryandlee/animegan2-pytorch
# `model` is the generator switched to inference mode via .eval(); `face2paint` is the
# repo's helper that transform_and_show calls as face2paint(model, image, ...).
model = torch.hub.load("bryandlee/animegan2-pytorch:main", "generator", device=DEVICE).eval()
face2paint = torch.hub.load("bryandlee/animegan2-pytorch:main", "face2paint", device=DEVICE)

# File-input filters must be extensions with a leading "." (or MIME types);
# bare "png" is not a valid filter specifier.
file_update = pn.widgets.FileInput(accept=".png,.jpg,.jpeg", multiple=True)
# Hidden until transform_and_show toggles it on while processing.
loading_indicator = pn.indicators.LoadingSpinner(
    visible=False, value=True, height=25, width=25, sizing_mode="fixed"
)
# Filled with one input/output pair per uploaded image.
output = pn.Column()

layout = pn.Column(
    "## Input: .png or .jpg portrait",
    file_update,
    "## Output: Transformed portraits",
    output,
    loading_indicator,
).servable()


def bokeh_hook(plot, element):  # pylint: disable=unused-argument
    """Hide the tick marks and tick labels on both axes of a rendered plot.

    Intended as a HoloViews ``hooks`` entry.

    Args:
        plot: The rendered plot; its ``state`` exposes ``xaxis`` and ``yaxis``.
        element: The element being rendered (not used).
    """
    figure = plot.state
    for axis in (figure.xaxis, figure.yaxis):
        # No line colour removes the tick marks; a zero font size hides the tick labels.
        axis.major_tick_line_color = None
        axis.minor_tick_line_color = None
        axis.major_label_text_font_size = "0pt"


def to_hv_rgb(image: PIL.Image.Image) -> hv.RGB:
    """Wrap a PIL image in a HoloViews ``RGB`` element for display.

    The element has no axis labels, and ``bokeh_hook`` hides the ticks. It uses the
    toolbar tools listed below and stretches responsively with a 500px minimum height.

    Args:
        image: The image to show; it is converted to a NumPy array with ``np.array``.

    Returns:
        The configured ``hv.RGB`` element.
    """
    image_array = np.array(image)
    return hv.RGB(image_array).opts(
        labelled=[],  # request no axis labels
        hooks=[bokeh_hook],  # strip tick marks / tick labels
        default_tools=["save", "pan", "wheel_zoom", "box_zoom", "reset"],
        min_height=500,
        responsive=True,
    )


def transform_and_show(uploaded: List[bytes]):
    """Transform every uploaded image with AnimeGAN2 and show the input next to the output.

    Each call replaces whatever ``output`` previously contained. The loading spinner is
    shown while images are processed. It is always hidden again, even if an image fails
    to decode or the model raises.

    Args:
        uploaded: Raw contents of the uploaded files, one ``bytes`` object per file.
            ``None`` (no upload) is treated like an empty list and just clears the output.
    """
    loading_indicator.visible = True
    try:
        output.clear()
        for bytes_in in uploaded or []:
            # Close the lazily-decoded source image once we have an RGB copy of it.
            with Image.open(BytesIO(bytes_in)) as source:
                im_in = source.convert("RGB")
            im_out = face2paint(model, im_in, side_by_side=False)

            # `+` lays the two HoloViews elements out side by side.
            output.append(to_hv_rgb(im_in) + to_hv_rgb(im_out))
    finally:
        # Without this, any exception above would leave the spinner visible forever.
        loading_indicator.visible = False


# Re-run the transformation whenever the FileInput's value changes.
pn.bind(transform_and_show, uploaded=file_update.param.value, watch=True)
1 Like

@philippjfr told me that Panel can display PIL images directly without using hv.RGB, so here is a simplified version. Note, though, that it does not resize or stretch the images.

"""Inspired by https://twitter.com/marlene_zw/status/1509882018968047618"""
from io import BytesIO
from typing import List

import panel as pn
import torch
from PIL import Image

pn.extension(sizing_mode="stretch_width", template="fast")

ACCENT = "#974794"

# Brand the "fast" template selected in pn.extension above.
pn.state.template.param.update(
    site="Awesome Panel",
    title="bryandlee/animegan2-pytorch",
    accent_base_color=ACCENT,
    header_background=ACCENT,
)

# Run the model on the GPU when one is available, otherwise fall back to the CPU.
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

# Source: https://github.com/bryandlee/animegan2-pytorch
# `model` is the generator switched to inference mode via .eval(); `face2paint` is the
# repo's helper that transform_and_show calls as face2paint(model, image, ...).
model = torch.hub.load("bryandlee/animegan2-pytorch:main", "generator", device=DEVICE).eval()
face2paint = torch.hub.load("bryandlee/animegan2-pytorch:main", "face2paint", device=DEVICE)

# File-input filters must be extensions with a leading "." (or MIME types);
# bare "png" is not a valid filter specifier.
file_update = pn.widgets.FileInput(accept=".png,.jpg,.jpeg", multiple=True)
# Hidden until transform_and_show toggles it on while processing.
loading_indicator = pn.indicators.LoadingSpinner(
    visible=False, value=True, height=25, width=25, sizing_mode="fixed"
)
# Filled with one input/output row per uploaded image.
output = pn.Column()

layout = pn.Column(
    "## Input: .png or .jpg portrait",
    file_update,
    "## Output: Transformed portraits",
    output,
    loading_indicator,
).servable()


def transform_and_show(uploaded: List[bytes]):
    """Transform every uploaded image with AnimeGAN2 and show the input next to the output.

    Each call replaces whatever ``output`` previously contained. Each pair of PIL images
    is appended as a ``pn.Row``. The loading spinner is shown while images are processed.
    It is always hidden again, even if an image fails to decode or the model raises.

    Args:
        uploaded: Raw contents of the uploaded files, one ``bytes`` object per file.
            ``None`` (no upload) is treated like an empty list and just clears the output.
    """
    loading_indicator.visible = True
    try:
        output.clear()
        for bytes_in in uploaded or []:
            # Close the lazily-decoded source image once we have an RGB copy of it.
            with Image.open(BytesIO(bytes_in)) as source:
                im_in = source.convert("RGB")
            im_out = face2paint(model, im_in, side_by_side=False)

            output.append(pn.Row(im_in, im_out))
    finally:
        # Without this, any exception above would leave the spinner visible forever.
        loading_indicator.visible = False


# Re-run the transformation whenever the FileInput's value changes.
pn.bind(transform_and_show, uploaded=file_update.param.value, watch=True)
5 Likes

Super cool!

2 Likes