"""Utility functions for plotting dimension reduced embeddings using plotly."""
from __future__ import annotations
import logging
from typing import TYPE_CHECKING
import pandas as pd
import plotly.graph_objects as go
from plotly.subplots import make_subplots
if TYPE_CHECKING:
from pathlib import Path
# Module-level logger used for per-subplot progress messages and the save-location message.
logger = logging.getLogger(__name__)
[docs]
def update_layout(
    fig: go.Figure,
    title: str,
    num_row: int = 6,
    num_col: int = 3,
    width: float = 1200,
    height: float = 1800,
) -> go.Figure:
    """Apply the shared visual style to a subplot-grid figure.

    Each cell of the ``num_row`` x ``num_col`` grid gets hidden tick labels and
    a thin black mirrored frame on both axes. The figure itself then receives a
    centred title, a fixed size, a white background and a horizontal legend
    placed just below the plotting area.

    Args:
        fig: Figure whose axes are addressed by ``row``/``col``; presumably
            built with ``plotly.subplots.make_subplots``.
        title: Text of the figure title.
        num_row: Number of grid rows whose axes are styled.
        num_col: Number of grid columns whose axes are styled.
        width: Figure width passed to ``fig.update_layout``.
        height: Figure height passed to ``fig.update_layout``.

    Returns:
        The same figure object after the in-place updates.
    """
    # x and y axes share one frame style: no tick labels, black line mirrored
    # onto the opposite side of each subplot.
    frame_style = {
        "showticklabels": False,
        "linecolor": "black",
        "showline": True,
        "linewidth": 1,
        "mirror": True,
    }
    grid_cells = ((r, c) for r in range(1, num_row + 1) for c in range(1, num_col + 1))
    for r, c in grid_cells:
        fig.update_xaxes(**frame_style, row=r, col=c)
        fig.update_yaxes(**frame_style, row=r, col=c)

    white = "rgba(255,255,255,1)"
    legend_style = {
        "orientation": "h",
        "yanchor": "bottom",
        "xanchor": "center",
        "x": 0.5,
        "y": -0.04,
        "font": {"size": 30},
    }
    fig.update_layout(
        title=title,
        title_x=0.5,
        title_font_size=30,
        width=width,
        height=height,
        margin={"l": 10, "r": 10, "t": 80, "b": 50},
        paper_bgcolor=white,
        plot_bgcolor=white,
        legend=legend_style,
    )
    return fig
[docs]
def plot_reducers_embeddings(
    df_label: pd.DataFrame,
    reducers: list[str],
    embedding_names: list[str],
    embedding_dir: Path,
    save_path: Path | None = None,
    symbol: str = "circle",
    title: str = "Embedding Visualization",
) -> go.Figure:
    """Plot a grid of 2-D dimension-reduced embeddings.

    The grid has one row per entry of ``embedding_names`` and one column per
    entry of ``reducers``. For each pair, the pickled DataFrame
    ``embedding_dir / f"{reducer}_{embedding_name}.pkl"`` is loaded. Its two
    columns are renamed to ``x``/``y`` and it is joined with ``df_label`` on
    the index. Points are coloured by the ``label`` column and show the
    ``formula`` column on hover. A single shared legend lists every label
    category.

    Args:
        df_label: Labels indexed like the embedding pickles. Must provide a
            ``label`` column (coloured when it is one of "unlikely",
            "interesting", "missing", "standard") and a ``formula`` column.
        reducers: Reducer names; one subplot column each.
        embedding_names: Embedding names; one subplot row each.
        embedding_dir: Directory containing the ``<reducer>_<embedding>.pkl``
            files. Each file is expected to hold exactly two columns.
        save_path: Optional output path. A ``.html``/``.htm`` suffix
            (case-insensitive) writes interactive HTML. Any other suffix
            writes a static image at ``scale=6``. ``None`` skips saving.
        symbol: Plotly marker symbol used for every point.
        title: Figure title.

    Returns:
        The assembled plotly figure.
    """
    # Size the grid from the inputs (at least 1x1 so make_subplots accepts empty lists).
    n_rows = max(len(embedding_names), 1)
    n_cols = max(len(reducers), 1)
    fig = make_subplots(
        rows=n_rows,
        cols=n_cols,
        # Titles are consumed row by row, matching row=embedding, col=reducer below.
        subplot_titles=[f"{reducer} - {embedding_name}" for embedding_name in embedding_names for reducer in reducers],
        vertical_spacing=0.02,
        horizontal_spacing=0.02,
    )
    # update the font size of subplot titles
    for annotation in fig.layout.annotations:  # type: ignore[union-attr] # plotly stubs incomplete
        annotation.update(font={"size": 25})

    legend_colors = {
        "unlikely": "#D9D9D9",
        "interesting": "#22E000",
        "missing": "#FF1201",
        "standard": "#002FFF",
    }
    # Marker style shared by the data traces and the legend-only traces.
    base_marker = {
        "size": 8,
        "opacity": 0.8,
        "symbol": symbol,
        "line": {"width": 0.5, "color": "DarkSlateGrey"},
    }

    for i, embedding_name in enumerate(embedding_names):
        for j, reducer in enumerate(reducers):
            logger.info("processing %d %d...", i, j)
            # NOTE: unpickling can execute arbitrary code; embedding_dir must be trusted.
            embedding_data = pd.read_pickle(  # noqa: S301
                embedding_dir / f"{reducer}_{embedding_name}.pkl",
            )
            embedding_data.columns = ["x", "y"]
            df_plot = embedding_data.join(df_label)
            # Shuffle with a fixed seed so no label class is systematically drawn on top.
            df_plot = df_plot.sample(frac=1, random_state=42)
            colors = df_plot["label"].map(legend_colors)  # type: ignore[arg-type] # pandas stubs don't accept Mapping
            n_unmapped = int(colors.isna().sum())
            if n_unmapped:
                logger.warning(
                    "%s - %s: %d points have a label outside %s (or no label after the join)",
                    reducer,
                    embedding_name,
                    n_unmapped,
                    sorted(legend_colors),
                )
            fig.add_trace(
                go.Scatter(
                    x=df_plot["x"],
                    y=df_plot["y"],
                    mode="markers",
                    marker={**base_marker, "color": colors.tolist()},
                    showlegend=False,
                    text=df_plot["formula"],
                    hovertemplate="<b>%{text}</b><br><br>",
                ),
                row=i + 1,
                col=j + 1,
            )

    # Legend-only traces: one empty point per label so the shared legend always
    # lists every category, independent of which subplots contain it.
    for label, color in legend_colors.items():
        fig.add_trace(
            go.Scatter(
                x=[None],
                y=[None],
                mode="markers",
                marker={**base_marker, "color": color},
                # make only first letter capital
                name=label.capitalize(),
                showlegend=True,
            ),
            row=1,
            col=1,
        )

    # Style every cell of the actual grid. Each cell gets 400 x 300 px,
    # which reproduces the 1200 x 1800 canvas for the 6 x 3 case.
    fig = update_layout(
        fig,
        title=title,
        num_row=n_rows,
        num_col=n_cols,
        width=400 * n_cols,
        height=300 * n_rows,
    )

    if save_path is not None:
        if save_path.suffix.lower() in {".html", ".htm"}:
            fig.write_html(save_path)
        else:
            # Static export needs plotly's image engine (e.g. kaleido) installed.
            fig.write_image(save_path, scale=6)
        logger.info("Save to %s", save_path)
    return fig