# Copyright (c) 2021. Lucas G. S. Jeub
#
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
# in the Software without restriction, including without limitation the rights
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
# copies of the Software, and to permit persons to whom the Software is
# furnished to do so, subject to the following conditions:
#
# The above copyright notice and this permission notice shall be included in all
# copies or substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
# SOFTWARE.
from pathlib import Path
from typing import Optional
from datetime import datetime
from functools import partial
import umap
import datashader as ds
import datashader.transfer_functions as tf
from datashader.mpl_ext import dsshow, alpha_colormap
import pandas as pd
import seaborn as sns
import numpy as np
import matplotlib.pyplot as plt
import torch
from local2global_embedding.run.utils import ScriptParser, load_classification_problem
# Module-wide random generator used for subsampling nodes. It is deliberately
# left unseeded (fresh OS entropy), so repeated runs may pick different samples.
rng = np.random.default_rng(seed=None)
[docs]
def get_ax_size(ax):
    """Return the ``(width, height)`` of the axes ``ax`` in pixels.

    The window extent of ``ax`` is mapped through the inverse of the owning
    figure's ``dpi_scale_trans`` and the resulting width and height are then
    multiplied by the figure's ``dpi``.

    Args:
        ax: axes object exposing ``figure`` and ``get_window_extent()``.

    Returns:
        Tuple ``(width, height)``.
    """
    figure = ax.figure
    window = ax.get_window_extent()
    # Undo the DPI scaling first, then re-apply the figure's dpi explicitly.
    extent = window.transformed(figure.dpi_scale_trans.inverted())
    return extent.width * figure.dpi, extent.height * figure.dpi
[docs]
def plot_embedding(filename, name, mmap_mode: Optional[str] = None, max_points=500000, restrict_lcc=False,
                   pointsize=5,
                   size=2.0, dpi=1200, data_root='/tmp', min_dist=0.0, metric='euclidean', verbose=True):
    """Plot a 2D UMAP projection of an embedding, coloured by class label.

    The classification problem ``name`` is loaded, nodes with a non-negative
    label are selected (randomly subsampled without replacement when there are
    more than ``max_points``), their embedding rows are projected with UMAP and
    the result is rasterised with datashader. The image is saved next to
    ``filename`` with the suffix replaced by ``.png``.

    Args:
        filename: path of the embedding. A ``.pt`` suffix is read with
            ``torch.load`` (mapped to CPU); anything else with ``np.load``.
        name: dataset name passed to ``load_classification_problem``.
        mmap_mode: forwarded to ``np.load``; not used for ``.pt`` files.
        max_points: maximum number of nodes to plot.
        restrict_lcc: forwarded to ``load_classification_problem``.
        pointsize: maximum spread (``max_px``) in pixels applied to each point;
            also determines the padding around the data.
        size: width and height of the (square) figure in inches.
        dpi: resolution of the figure and of the saved image.
        data_root: forwarded to ``load_classification_problem`` as ``root``.
        min_dist: UMAP ``min_dist`` parameter.
        metric: UMAP ``metric`` parameter.
        verbose: UMAP ``verbose`` flag.

    Raises:
        ValueError: if no node has a non-negative label.
    """
    filename = Path(filename)
    print(f'loading data started at {datetime.now()}')
    cl = load_classification_problem(name, restrict_lcc=restrict_lcc, root=data_root)
    print(f'classificaton problem loaded at {datetime.now()}')

    # Only nodes with a non-negative label are plotted
    # (presumably negative values mark unlabelled nodes -- confirm with the loader).
    y = np.asanyarray(cl.y)
    nodes = np.flatnonzero(y >= 0)
    if len(nodes) == 0:
        raise ValueError(f'no nodes with non-negative labels found for {name!r}')
    if len(nodes) > max_points:
        nodes = rng.choice(nodes, size=(max_points,), replace=False)

    # Both loaders must be restricted to `nodes`, otherwise the rows of
    # `coords` do not line up with the labels `y[nodes]` assigned below.
    if filename.suffix == '.pt':
        coords = np.asanyarray(torch.load(filename, map_location='cpu'))[nodes]
    else:
        coords = np.load(filename, mmap_mode=mmap_mode)[nodes]
    print(f'embedding loaded at {datetime.now()}')

    vc = umap.UMAP(min_dist=min_dist, metric=metric, verbose=verbose).fit_transform(coords)

    # Pad the data range by two point sizes (expressed as a fraction of the
    # axes size in pixels) so that spread points at the border are not clipped.
    ax_size = size*dpi
    pad_fraction = 2*pointsize / ax_size
    min_range = vc.min(axis=0)
    max_range = vc.max(axis=0)
    pad = (max_range-min_range) * pad_fraction
    x_range = (min_range[0]-pad[0], max_range[0]+pad[0])
    y_range = (min_range[1]-pad[1], max_range[1]+pad[1])

    df = pd.DataFrame(vc, columns=['x', 'y'])
    df['label'] = y[nodes]
    df['label'] = df['label'].astype('category')

    # One 0-255 integer colour tuple per label index.
    colors = sns.color_palette('husl', cl.num_labels)
    colors = {i: tuple(int(vi * 255) for vi in v) for i, v in enumerate(colors)}

    # The figure is created only once the data is ready, so a failure while
    # loading or projecting does not leave an unused figure open.
    fig = plt.figure(figsize=(size, size), dpi=dpi)
    ax = fig.add_axes([0, 0, 1, 1])
    dsshow(df, ds.Point('x', 'y'), ds.count_cat('label'), ax=ax, norm='eq_hist', color_key=colors,
           shade_hook=partial(tf.dynspread, threshold=0.99, max_px=pointsize, shape='circle'), alpha_range=(55, 255),
           x_range=x_range, y_range=y_range)
    ax.set_axis_off()
    plt.margins(0.01, 0.01)
    plt.savefig(filename.with_suffix('.png'), dpi=dpi)
if __name__ == '__main__':
    # Script entry point: hand `plot_embedding` to ScriptParser and run it.
    parser = ScriptParser(plot_embedding)
    parser.run()