Source code for local2global_embedding.run.plot

#  Copyright (c) 2021. Lucas G. S. Jeub
#
#  Permission is hereby granted, free of charge, to any person obtaining a copy
#  of this software and associated documentation files (the "Software"), to deal
#  in the Software without restriction, including without limitation the rights
#  to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
#  copies of the Software, and to permit persons to whom the Software is
#  furnished to do so, subject to the following conditions:
#
#  The above copyright notice and this permission notice shall be included in all
#  copies or substantial portions of the Software.
#
#  THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
#  IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
#  FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
#  AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
#  LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
#  OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
#  SOFTWARE.

import json
from statistics import mean, stdev

import numpy as np
import matplotlib.pyplot as plt
from pathlib import Path
from local2global_embedding.run.utils import ScriptParser, ResultsDict, load_data
from local2global_embedding.utils import flatten


def _extract_error(data, key):
    err = None
    if key == 'acc_mean':
        err = data.get('acc_std')
    if err is not None:
        err = np.asarray(err).flatten()
    return err


def _normalise_data(data):
    data = np.asarray(data)
    data = data.flatten()
    return data


def mean_and_deviation(data):
    data = [flatten(v) for v in data]
    data_mean = [mean(v) for v in data]
    data_std = [stdev(v) for v in data]
    return data_mean, data_std
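
# Worked example (illustrative numbers, not taken from any results file): each entry
# of `data` holds the repeated-run scores for one embedding dimension; `flatten`
# collapses any extra nesting before the per-dimension mean and sample standard
# deviation are computed.
#
#     mean_and_deviation([[0.81, 0.83, 0.82], [0.89, 0.91, 0.90]])
#     # -> ([0.82, 0.90], [~0.01, ~0.01])
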
def plot_with_errorbars(x, y_mean, y_err, fmt='-', **kwargs):
    opts = dict(elinewidth=0.5, capthick=0.5, capsize=3)
    opts['fmt'] = fmt
    opts.update(kwargs)
    _, cap, ebar = plt.errorbar(_normalise_data(x), _normalise_data(y_mean),
                                yerr=_normalise_data(y_err), **opts)
    if ebar:
        cap[0].set_alpha(0.5)
        cap[1].set_alpha(0.5)
        ebar[0].set_alpha(0.5)
        ebar[0].set_linestyle(fmt)
def plot(data, key, baseline_data=None, nt_data=None, rotate_data=None, translate_data=None):
    fig = plt.figure()
    if baseline_data is not None and key in baseline_data:
        d_mean, d_err = mean_and_deviation(baseline_data[key])
        plot_with_errorbars(baseline_data['dims'], d_mean, d_err, label='full', marker='o',
                            color='tab:blue', zorder=4)
    d_mean, d_err = mean_and_deviation(data[key])
    plot_with_errorbars(data['dims'], d_mean, d_err, fmt='-', label='l2g', marker='>',
                        color='tab:red', zorder=5)
    if rotate_data is not None and key in rotate_data:
        d_mean, d_err = mean_and_deviation(rotate_data[key])
        plot_with_errorbars(rotate_data['dims'], d_mean, d_err, fmt='--', marker='s', markersize=3,
                            label='rotate-only', color='tab:orange', linewidth=0.5, zorder=3)
    if translate_data is not None and key in translate_data:
        d_mean, d_err = mean_and_deviation(translate_data[key])
        plot_with_errorbars(translate_data['dims'], d_mean, d_err, fmt='-.', marker='d', markersize=3,
                            label='translate-only', color='tab:purple', linewidth=0.5, zorder=2)
    if nt_data is not None and key in nt_data:
        d_mean, d_err = mean_and_deviation(nt_data[key])
        plot_with_errorbars(nt_data['dims'], d_mean, d_err, fmt=':', label='no-l2g',
                            color='tab:pink', linewidth=0.5, zorder=1)
    plt.xscale('log')
    plt.xticks(data['dims'], data['dims'])
    plt.minorticks_off()
    if key == 'auc':
        plt.ylim(0.48, 1.02)
    plt.xlabel('embedding dimension')
    if key == 'auc':
        plt.ylabel('AUC')
    elif key == 'acc':
        plt.ylabel('classification accuracy')
    plt.legend(ncol=3, frameon=False)
    return fig
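
# Usage sketch for calling `plot` outside of `plot_all` (the path below is purely
# illustrative; only the 'dims' and 'auc' keys are assumed, matching how `plot_all`
# below reads the *_l2g_scale_eval.json files):
#
#     with open('results/example_l2g_scale_eval.json') as f:
#         data = json.load(f)
#     fig = plot(data, 'auc')
#     fig.savefig('example_auc.pdf')
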
def plot_all(folder=None):
    """
    Plot results

    Args:
        folder: results folder (default: CWD)
    """
    if folder is None:
        folder = Path.cwd()
    else:
        folder = Path(folder)

    for file in folder.glob('**/**/*_l2g_scale_eval.json'):
        print(file)
        with open(file) as f:
            data = json.load(f)

        # strip the '_hc…' component (if present) from the file name
        base_name_parts = file.name.split('_hc', 1)
        if len(base_name_parts) > 1:
            base_name = base_name_parts[0] + '_' + base_name_parts[1].split('_', 1)[1]
        else:
            base_name = base_name_parts[0]

        # baseline: evaluation of the full (non-patched) embedding
        baseline = folder / file.parent.name / base_name.replace('_l2g_', '_full_').replace('_scale', '')
        if baseline.is_file():
            baseline_data = ResultsDict(baseline)
            baseline_data.reduce_to_dims(data['dims'])
        else:
            baseline_data = None

        # no-l2g: patch embeddings without rotation or translation alignment
        nt = file.with_name(base_name.replace('_scale_', '_norotate_notranslate_'))
        if nt.is_file():
            with open(nt) as f:
                nt_data = json.load(f)
        else:
            nt_data = None

        # rotate-only: alignment without translation
        rotate = file.with_name(base_name.replace('_scale_', '_notranslate_'))
        if rotate.is_file():
            with open(rotate) as f:
                rotate_data = json.load(f)
        else:
            rotate_data = None

        # translate-only: alignment without rotation
        translate = file.with_name(base_name.replace('_scale_', '_norotate_'))
        if translate.is_file():
            with open(translate) as f:
                translate_data = json.load(f)
        else:
            translate_data = None

        # compute the patch-edge oversampling ratio for the figure title
        name = file.stem.split('_', 1)[0]
        network_data = load_data(name)
        all_edges = network_data.num_edges
        patch_files = list(file.parents[1].glob('patch*_index.npy'))
        patch_edges = sum(network_data.subgraph(np.load(patch_file)).num_edges
                          for patch_file in patch_files)
        oversampling_ratio = patch_edges / all_edges
        num_labels = network_data.y.max().item() + 1
        title = f"oversampling ratio: {oversampling_ratio:.2}, #patches: {len(patch_files)}"

        if 'auc' in data:
            fig = plot(data, 'auc', baseline_data, nt_data, rotate_data, translate_data)
            ax = fig.gca()
            ax.set_title(title)
            ax.set_ylim(0.48, 1.02)
            fig.savefig(file.with_name(file.name.replace('.json', '_auc.pdf')))

        if 'acc' in data:
            fig = plot(data, 'acc', baseline_data, nt_data, rotate_data, translate_data)
            ax = fig.gca()
            ax.set_title(title)
            ax.set_ylim(0.98 / num_labels, 1.02)
            fig.savefig(file.with_name(file.name.replace('.json', '_cl.pdf')))
if __name__ == '__main__':
    ScriptParser(plot_all).run()
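
# Example invocation (a sketch: the exact command-line mapping depends on how
# ScriptParser exposes the `folder` argument, and the path shown is illustrative):
#
#     python -m local2global_embedding.run.plot results
#
# This scans the results folder for *_l2g_scale_eval.json files and writes the
# corresponding *_auc.pdf and *_cl.pdf figures next to them.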