# Copyright (c) 2021. Lucas G. S. Jeub
#
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
# in the Software without restriction, including without limitation the rights
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
# copies of the Software, and to permit persons to whom the Software is
# furnished to do so, subject to the following conditions:
#
# The above copyright notice and this permission notice shall be included in all
# copies or substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
# SOFTWARE.
import json
from statistics import mean, stdev
import numpy as np
import matplotlib.pyplot as plt
from pathlib import Path
from local2global_embedding.run.utils import ScriptParser, ResultsDict, load_data
from local2global_embedding.utils import flatten
def _extract_error(data, key):
err = None
if key == 'acc_mean':
err = data.get('acc_std')
if err is not None:
err = np.asarray(err).flatten()
return err
def _normalise_data(data):
data = np.asarray(data)
data = data.flatten()
return data
def mean_and_deviation(data):
    """Compute the mean and sample standard deviation of each entry of ``data``.

    Every entry is passed through ``flatten`` before the statistics are
    computed, so nested result lists are treated as one pool of samples.

    Args:
        data: iterable of (possibly nested) collections of numbers, one entry
            per plotted point

    Returns:
        Tuple ``(data_mean, data_std)`` of two lists with one value per entry
        of ``data``. Entries with a single sample get a deviation of ``0.0``.

    Raises:
        statistics.StatisticsError: if an entry contains no samples.
    """
    data_mean = []
    data_std = []
    for entry in data:
        # Materialise once so the samples can be measured and reused safely.
        samples = list(flatten(entry))
        data_mean.append(mean(samples))
        # ``stdev`` needs at least two samples; results from a single run have
        # no spread to report, so plot them with a zero-length error bar
        # instead of failing the whole figure.
        data_std.append(stdev(samples) if len(samples) > 1 else 0.0)
    return data_mean, data_std
def plot_with_errorbars(x, y_mean, y_err, fmt='-', **kwargs):
    """Draw a line with semi-transparent error bars via ``plt.errorbar``.

    Args:
        x: x-coordinates (flattened before plotting)
        y_mean: y-coordinates (flattened before plotting)
        y_err: error-bar sizes passed as ``yerr`` (flattened before plotting)
        fmt: format string for ``plt.errorbar``; it is also applied as the
            line style of the error-bar lines
        **kwargs: additional options for ``plt.errorbar``; these override the
            defaults ``elinewidth=0.5``, ``capthick=0.5`` and ``capsize=3``
    """
    options = {'elinewidth': 0.5, 'capthick': 0.5, 'capsize': 3, 'fmt': fmt, **kwargs}
    container = plt.errorbar(_normalise_data(x), _normalise_data(y_mean),
                             yerr=_normalise_data(y_err), **options)
    _, caps, bars = container
    if not bars:
        return

    # Fade the caps and bar lines so the error bars do not dominate the plot.
    faded = 0.5
    caps[0].set_alpha(faded)
    caps[1].set_alpha(faded)
    bar_lines = bars[0]
    bar_lines.set_alpha(faded)
    bar_lines.set_linestyle(fmt)
def plot(data, key, baseline_data=None, nt_data=None, rotate_data=None, translate_data=None):
    """Create a figure of ``key`` against embedding dimension.

    The main results in ``data`` are always drawn. Each optional result set is
    only drawn when it is not ``None`` and contains ``key``.

    Args:
        data: main results; must contain ``key`` and ``'dims'`` (label 'l2g')
        key: name of the quantity to plot; ``'auc'`` and ``'acc'`` get a
            dedicated y-axis label
        baseline_data: optional results labelled 'full'
        nt_data: optional results labelled 'no-l2g'
        rotate_data: optional results labelled 'rotate-only'
        translate_data: optional results labelled 'translate-only'

    Returns:
        The newly created matplotlib figure.
    """
    fig = plt.figure()

    if baseline_data is not None and key in baseline_data:
        avg, dev = mean_and_deviation(baseline_data[key])
        plot_with_errorbars(baseline_data['dims'], avg, dev,
                            label='full', marker='o', color='tab:blue', zorder=4)

    avg, dev = mean_and_deviation(data[key])
    plot_with_errorbars(data['dims'], avg, dev, fmt='-',
                        label='l2g', marker='>', color='tab:red', zorder=5)

    # Optional ablation curves, listed in drawing (and therefore legend) order.
    ablations = (
        (rotate_data, dict(fmt='--', marker='s', markersize=3, label='rotate-only',
                           color='tab:orange', linewidth=0.5, zorder=3)),
        (translate_data, dict(fmt='-.', marker='d', markersize=3, label='translate-only',
                              color='tab:purple', linewidth=0.5, zorder=2)),
        (nt_data, dict(fmt=':', label='no-l2g', color='tab:pink', linewidth=0.5, zorder=1)),
    )
    for results, style in ablations:
        if results is None or key not in results:
            continue
        avg, dev = mean_and_deviation(results[key])
        plot_with_errorbars(results['dims'], avg, dev, **style)

    # Logarithmic x-axis with one labelled tick per evaluated dimension.
    dims = data['dims']
    plt.xscale('log')
    plt.xticks(dims, dims)
    plt.minorticks_off()
    plt.xlabel('embedding dimension')

    if key == 'auc':
        plt.ylim(0.48, 1.02)
        plt.ylabel('AUC')
    elif key == 'acc':
        plt.ylabel('classification accuracy')

    plt.legend(ncol=3, frameon=False)
    return fig
def _load_json_if_present(path):
    """Return the parsed JSON content of ``path``, or ``None`` if it is not a file."""
    if not path.is_file():
        return None
    with open(path) as f:
        return json.load(f)


def plot_all(folder=None):
    """
    Plot results

    For every ``*_l2g_scale_eval.json`` file found below ``folder``, load the
    results together with any matching baseline and ablation result files, and
    save one PDF per available metric next to the results file
    (``*_auc.pdf`` for ``'auc'``, ``*_cl.pdf`` for ``'acc'``). Each figure is
    titled with the oversampling ratio (total number of edges in the patch
    subgraphs divided by the number of edges in the full network) and the
    number of patches.

    Args:
        folder: results folder (default: CWD)
    """
    folder = Path.cwd() if folder is None else Path(folder)

    for file in folder.glob('**/**/*_l2g_scale_eval.json'):
        print(file)
        with open(file) as f:
            data = json.load(f)

        # Strip the '_hc<...>' component from the file name (if present) to get
        # the name shared with the baseline and ablation result files.
        base_name_parts = file.name.split('_hc', 1)
        if len(base_name_parts) > 1:
            base_name = base_name_parts[0] + '_' + base_name_parts[1].split('_', 1)[1]
        else:
            base_name = base_name_parts[0]

        baseline = folder / file.parent.name / base_name.replace('_l2g_', '_full_').replace('_scale', '')
        if baseline.is_file():
            baseline_data = ResultsDict(baseline)
            baseline_data.reduce_to_dims(data['dims'])
        else:
            baseline_data = None

        nt_data = _load_json_if_present(
            file.with_name(base_name.replace('_scale_', '_norotate_notranslate_')))
        rotate_data = _load_json_if_present(
            file.with_name(base_name.replace('_scale_', '_notranslate_')))
        translate_data = _load_json_if_present(
            file.with_name(base_name.replace('_scale_', "_norotate_")))

        # The part of the file name before the first '_' is passed to load_data.
        name = file.stem.split('_', 1)[0]
        network_data = load_data(name)
        all_edges = network_data.num_edges
        patch_files = list(file.parents[1].glob('patch*_index.npy'))
        patch_edges = sum(network_data.subgraph(np.load(patch_file)).num_edges
                          for patch_file in patch_files)
        oversampling_ratio = patch_edges / all_edges
        num_labels = network_data.y.max().item() + 1
        title = f"oversampling ratio: {oversampling_ratio:.2}, #patches: {len(patch_files)}"

        if 'auc' in data:
            fig = plot(data, 'auc', baseline_data, nt_data, rotate_data, translate_data)
            ax = fig.gca()
            ax.set_title(title)
            ax.set_ylim(0.48, 1.02)
            fig.savefig(file.with_name(file.name.replace('.json', '_auc.pdf')))
            # pyplot keeps every figure alive until it is closed explicitly;
            # without this the loop accumulates one open figure per plot.
            plt.close(fig)

        if 'acc' in data:
            fig = plot(data, 'acc', baseline_data, nt_data, rotate_data, translate_data)
            ax = fig.gca()
            ax.set_title(title)
            ax.set_ylim(0.98/num_labels, 1.02)
            fig.savefig(file.with_name(file.name.replace('.json', '_cl.pdf')))
            plt.close(fig)
if __name__ == '__main__':
    # Script entry point: hand plot_all to the project's ScriptParser and run it.
    # NOTE(review): presumably this maps command-line arguments onto
    # plot_all's parameters — confirm against ScriptParser.
    parser = ScriptParser(plot_all)
    parser.run()