# Copyright (c) 2021. Lucas G. S. Jeub
#
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
# in the Software without restriction, including without limitation the rights
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
# copies of the Software, and to permit persons to whom the Software is
# furnished to do so, subject to the following conditions:
#
# The above copyright notice and this permission notice shall be included in all
# copies or substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
# SOFTWARE.
from pathlib import Path
from dask.distributed import Client, LocalCluster, as_completed
from dask import delayed
from runpy import run_path
from local2global_embedding.run.scripts import functions as func
from local2global_embedding.run.utils import load_data, ScriptParser, watch_progress
[docs]
def run(name='LANL', data_root=None, data_opts={'protocol': 'TCP'}, dims=(2,),
output='.', scale=True, alignment_type='temporal', alignment_window=14, use_median=True, cluster_init=True):
model = 'SVD'
cluster_init_path = Path().home() / '.config' / 'dask' / 'cluster_init.py'
if cluster_init and cluster_init_path.is_file():
kwargs = run_path(cluster_init_path)
cluster = kwargs['cluster']
else:
cluster = LocalCluster()
with Client(cluster) as client:
print(client.dashboard_link)
output = Path(output)
data = load_data(name, root=data_root, **data_opts)
n_patches = len(data.timesteps)
patch_folder_name = '_'.join([name, model] + [f'{key}={value}' for key, value in data_opts.items()])
patch_folder = output / patch_folder_name
patch_folder.mkdir(parents=True, exist_ok=True)
patches_s = []
patches_t = []
all_tasks = []
for d in dims:
for index in range(n_patches):
patches = delayed(func.svd_patches)(data=data, index=index, output_folder=patch_folder, dim=d).persist()
patches_s.append(patches[0].persist())
patches_t.append(patches[1].persist())
if alignment_type == 'temporal':
error_s = client.submit(func.temporal_align_errors, patches=patches_s,
scale=scale,
output_file=patch_folder / f'source_temporal_alignment_errors.npy')
all_tasks.append(error_s)
error_t = client.submit(func.temporal_align_errors, patches=patches_t,
scale=scale,
output_file=patch_folder / f'dest_temporal_alignment_errors.npy')
all_tasks.append(error_t)
elif alignment_type == 'windowed':
error_s = client.submit(func.windowed_align_errors, patches=patches_s,
window=alignment_window, scale=scale, use_median=use_median,
output_file=patch_folder / f'source_alignment_errors_window={alignment_window}.npy')
all_tasks.append(error_s)
error_t = client.submit(func.windowed_align_errors, patches=patches_t,
window=alignment_window, scale=scale, use_median=use_median,
output_file=patch_folder / f'dest_alignment_errors_window={alignment_window}.npy')
all_tasks.append(error_t)
elif alignment_type == 'global':
error_s = client.submit(func.global_align_errors, patches=patches_s, window=alignment_window, scale=scale,
output_file=patch_folder / f'source_global_alignment_errors_window={alignment_window}.npy')
all_tasks.append(error_s)
error_t = client.submit(func.global_align_errors, patches=patches_t, window=alignment_window, scale=scale,
output_file=patch_folder / f'dest_global_alignment_errors_window={alignment_window}.npy')
all_tasks.append(error_t)
all_tasks.append(client.submit(func.leave_out_z_score_errors, error_file=error_s))
all_tasks.append(client.submit(func.leave_out_z_score_errors, error_file=error_t))
del error_t
del error_s
all_tasks = as_completed(all_tasks)
watch_progress(all_tasks)
if __name__ == '__main__':
ScriptParser(run).run()