Source code for local2global_embedding.run.umap_plot

#  Copyright (c) 2021. Lucas G. S. Jeub
#
#  Permission is hereby granted, free of charge, to any person obtaining a copy
#  of this software and associated documentation files (the "Software"), to deal
#  in the Software without restriction, including without limitation the rights
#  to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
#  copies of the Software, and to permit persons to whom the Software is
#  furnished to do so, subject to the following conditions:
#
#  The above copyright notice and this permission notice shall be included in all
#  copies or substantial portions of the Software.
#
#  THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
#  IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
#  FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
#  AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
#  LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
#  OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
#  SOFTWARE.
from pathlib import Path
from typing import Optional
from datetime import datetime
from functools import partial

import umap
import datashader as ds
import datashader.transfer_functions as tf
from datashader.mpl_ext import dsshow, alpha_colormap

import pandas as pd
import seaborn as sns
import numpy as np
import matplotlib.pyplot as plt
import torch

from local2global_embedding.run.utils import ScriptParser, load_classification_problem


rng = np.random.default_rng()


[docs] def get_ax_size(ax): fig = ax.figure bbox = ax.get_window_extent().transformed(fig.dpi_scale_trans.inverted()) width, height = bbox.width, bbox.height width *= fig.dpi height *= fig.dpi return width, height
[docs] def plot_embedding(filename, name, mmap_mode: Optional[str] = None, max_points=500000, restrict_lcc=False, pointsize=5, size=2.0, dpi=1200, data_root='/tmp', min_dist=0.0, metric='euclidean', verbose=True): filename = Path(filename) print(f'loading data started at {datetime.now()}') cl = load_classification_problem(name, restrict_lcc=restrict_lcc, root=data_root) print(f'classificaton problem loaded at {datetime.now()}') fig = plt.figure(figsize=(size, size), dpi=dpi) ax = fig.add_axes([0, 0, 1, 1]) ax_size = size*dpi pad = 2*pointsize / ax_size y = np.asanyarray(cl.y) nodes = np.flatnonzero(y >= 0) if len(nodes) > max_points: nodes = rng.choice(nodes, size=(max_points,), replace=False) if filename.suffix == '.pt': coords = np.asanyarray(torch.load(filename, map_location='cpu')) else: coords = np.load(filename, mmap_mode=mmap_mode)[nodes] print(f'embedding loaded at {datetime.now()}') vc = umap.UMAP(min_dist=min_dist, metric=metric, verbose=verbose).fit_transform(coords) min_range = vc.min(axis=0) max_range = vc.max(axis=0) pad = (max_range-min_range) * pad x_range = (min_range[0]-pad[0], max_range[0]+pad[0]) y_range = (min_range[1]-pad[1], max_range[1]+pad[1]) df = pd.DataFrame(vc, columns=['x', 'y']) df['label'] = y[nodes] df['label'] = df['label'].astype('category') colors = sns.color_palette('husl', cl.num_labels) colors = {i: tuple(int(vi * 255) for vi in v) for i, v in enumerate(colors)} dsshow(df, ds.Point('x', 'y'), ds.count_cat('label'), ax=ax, norm='eq_hist', color_key=colors, shade_hook=partial(tf.dynspread, threshold=0.99, max_px=pointsize, shape='circle'), alpha_range=(55, 255), x_range=x_range, y_range=y_range) ax.set_axis_off() plt.margins(0.01, 0.01) plt.savefig(filename.with_suffix('.png'), dpi=dpi)
if __name__ == '__main__': ScriptParser(plot_embedding).run()