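"""Source code for astrohack.utils.graph.

Helpers that build and execute processing graphs over nested data
dictionaries, either serially or in parallel with Dask.
"""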

import shutil
import dask
import xarray as xr
import toolviper.utils.logger as logger
import copy
import pathlib

from astrohack.utils.text import approve_prefix
from astrohack.utils.text import param_to_list


def _white_list_creation(key_prefix, looping_dict, param_dict):
    """Return the keys selected for one level of the looping dictionary.

    The user selection ``param_dict[key_prefix]`` is turned into a list by
    ``param_to_list``, which also receives ``looping_dict`` (presumably to
    expand selections such as "all" into its keys; see ``param_to_list``).
    Keys rejected by ``approve_prefix`` are dropped. The order is preserved.

    Args:
        key_prefix: Name of the level, also the key of the selection in ``param_dict``.
        looping_dict: Dictionary whose keys are the candidates for this level.
        param_dict: Parameter dictionary holding the user selection.

    Returns:
        List of approved keys.
    """
    candidates = param_to_list(param_dict[key_prefix], looping_dict, key_prefix)
    return list(filter(approve_prefix, candidates))


def _add_exec_data_to_param_dict(data_for_exec, param_dict):
    """Store the data to process in ``param_dict`` under a type-specific key.

    The key is chosen from the type of ``data_for_exec``:
    ``xr.DataTree`` -> ``"xdt_data"``, ``xr.Dataset`` -> ``"xds_data"``,
    ``dict`` -> ``"dic_data"``, anything else -> ``"unk_data"``.

    Note: ``param_dict`` is modified in place; nothing is returned.

    Args:
        data_for_exec: Leaf data to hand to a chunk function.
        param_dict: Parameter dictionary that receives the data.
    """
    # Checked in this order; the first matching type wins.
    type_to_key = (
        (xr.DataTree, "xdt_data"),
        (xr.Dataset, "xds_data"),
        (dict, "dic_data"),
    )
    target_key = "unk_data"
    for data_type, candidate_key in type_to_key:
        if isinstance(data_for_exec, data_type):
            target_key = candidate_key
            break
    param_dict[target_key] = data_for_exec


def _factorized_graph_execution_return(status, ret_list, fetch_ret):
    if fetch_ret:
        if status:
            return status, ret_list
        else:
            return status, None
    else:
        return status


def _construct_general_graph_recursively(
    looping_dict,
    chunk_function,
    param_dict,
    delayed_list,
    key_order,
    output_mds,
    parallel,
    oneup=None,
):
    """Walk ``looping_dict`` along ``key_order`` and queue one job per leaf.

    At each level, the keys selected for ``key_order[0]`` (via
    ``_white_list_creation``) are visited; keys absent from ``looping_dict``
    are logged and skipped. Each visited key gets its own deep copy of
    ``param_dict`` with ``this_<prefix>`` set to that key. When ``key_order``
    is exhausted, ``looping_dict`` is the leaf data: it is stored into
    ``param_dict`` (in place) and one job is appended to ``delayed_list``.

    Args:
        looping_dict: Current (sub-)dictionary, or the leaf data at the bottom.
        chunk_function: Function to run for each leaf.
        param_dict: Parameters for this branch; mutated in place at the leaf.
        delayed_list: Output list. Receives ``dask.delayed`` objects when
            ``parallel`` is True, else ``(chunk_function, args)`` tuples.
        key_order: Remaining level prefixes, outermost first.
        output_mds: If not None, passed as second argument to ``chunk_function``.
        parallel: Selects which kind of job is appended.
        oneup: Key chosen at the enclosing level; only used in warnings.
    """
    if len(key_order) == 0:
        # Bottom of the tree: looping_dict now holds the data to process.
        _add_exec_data_to_param_dict(looping_dict, param_dict)
        chunk_args = [param_dict] if output_mds is None else [param_dict, output_mds]
        job = (
            dask.delayed(chunk_function)(*chunk_args)
            if parallel
            else (chunk_function, chunk_args)
        )
        delayed_list.append(job)
        return

    level_prefix = key_order[0]
    deeper_levels = key_order[1:]
    selected_keys = _white_list_creation(level_prefix, looping_dict, param_dict)

    for key in selected_keys:
        if key not in looping_dict:
            if oneup is None:
                logger.warning(f"{key} is not present in looping dict")
            else:
                logger.warning(f"{key} is not present for {oneup}")
            continue

        # Each branch gets an independent copy so siblings do not share state.
        branch_params = copy.deepcopy(param_dict)
        branch_params[f"this_{level_prefix}"] = key

        _construct_general_graph_recursively(
            looping_dict=looping_dict[key],
            chunk_function=chunk_function,
            param_dict=branch_params,
            delayed_list=delayed_list,
            key_order=deeper_levels,
            output_mds=output_mds,
            parallel=parallel,
            oneup=key,
        )


def create_and_execute_graph_from_dict(
    looping_dict,
    chunk_function,
    param_dict,
    key_order,
    output_mds=None,
    fetch_returns=False,
):
    """Build and execute one chunk-function call per leaf of a nested dictionary.

    ``looping_dict`` is walked level by level following ``key_order``. At each
    level, only the keys selected in ``param_dict[<prefix>]`` (and approved by
    ``approve_prefix``) are visited. Every leaf produces one call to
    ``chunk_function``. That call receives a copy of ``param_dict`` holding:

    * ``this_<prefix>`` for each key visited on the way down, and
    * the leaf data under ``xdt_data``, ``xds_data``, ``dic_data`` or
      ``unk_data``, depending on its type.

    The calls run through ``dask.compute`` when ``param_dict["parallel"]`` is
    True, and serially otherwise.

    Args:
        looping_dict: Nested dictionary to loop over, or an object exposing a
            ``root`` attribute (presumably an astrohack MDS), in which case
            ``looping_dict.root`` is used.
        chunk_function: Called as ``chunk_function(param_dict)``, or as
            ``chunk_function(param_dict, output_mds)`` when ``output_mds`` is
            given.
        param_dict: Parameters for the chunk function. Must contain
            ``"parallel"`` and one selection entry per prefix in ``key_order``.
        key_order: Level prefixes of ``looping_dict``, outermost first.
        output_mds: Optional output object. ``output_mds.write(mode="a")`` is
            called before execution and ``output_mds.consolidate(key_order)``
            after. If it then has no keys, ``output_mds.filename`` is removed
            from disk.
        fetch_returns: If True, also return the chunk-function returns.

    Returns:
        ``status`` if ``fetch_returns`` is False, otherwise
        ``(status, return_list)``. ``return_list`` is None when ``status`` is
        False. ``status`` is False when there was nothing to process or the
        output object ended up empty.
    """
    parallel = param_dict["parallel"]

    # MDS-like objects keep their nested dictionary under ``root``.
    if hasattr(looping_dict, "root"):
        looping_dict = looping_dict.root

    if output_mds is not None:
        output_mds.write(mode="a")

    # List created here to avoid complicated returns due to recursion.
    delayed_list = []
    _construct_general_graph_recursively(
        looping_dict=looping_dict,
        chunk_function=chunk_function,
        param_dict=param_dict,
        delayed_list=delayed_list,
        key_order=key_order,
        output_mds=output_mds,
        parallel=parallel,
    )

    if len(delayed_list) == 0:
        # NOTE(review): output_mds (if any) has already been written at this
        # point and, unlike the empty-result branch below, is not cleaned up
        # here -- confirm whether that is intended.
        logger.warning("List of delayed processing jobs is empty: No data to process")
        return _factorized_graph_execution_return(False, [], fetch_returns)

    if parallel:
        # dask.compute returns a tuple with one entry per argument.
        return_list = dask.compute(delayed_list)[0]
    else:
        return_list = [function(*args) for function, args in delayed_list]

    if output_mds is not None:
        output_mds.consolidate(key_order)
        if len(output_mds.keys()) == 0:
            logger.warning("Processing did not yield any data")
            shutil.rmtree(output_mds.filename)
            return _factorized_graph_execution_return(False, return_list, fetch_returns)

    return _factorized_graph_execution_return(True, return_list, fetch_returns)
def _sub_graph_execution_for_plots(
    looping_dict,
    chunk_function,
    param_dict,
    key_order,
    _selections=(),
):
    """Serially run ``chunk_function`` on every selected leaf below ``looping_dict``.

    ``looping_dict`` is walked along ``key_order``, outermost level first.
    Any number of levels (``len(key_order) >= 1``) is handled. At each level:

    * only keys returned by ``_white_list_creation`` are considered;
    * keys missing from the dictionary are logged and skipped.

    For every leaf, ``param_dict`` is deep-copied and the leaf data is added
    with ``_add_exec_data_to_param_dict``. Then ``this_<prefix>`` is set for
    every level visited, in order, and ``chunk_function`` is called with the
    copy. ``param_dict`` itself is not modified.

    Args:
        looping_dict: Dictionary for the current level.
        chunk_function: Function called once per leaf with its parameter dict.
        param_dict: Base parameter dictionary; also provides the key selections.
        key_order: Remaining level prefixes, outermost first.
        _selections: Internal. ``(prefix, key)`` pairs chosen at the enclosing
            levels of this function's recursion.

    Returns:
        List of the chunk-function returns, in traversal order.
    """
    key_prefix = key_order[0]
    remaining_keys = key_order[1:]
    white_list = _white_list_creation(key_prefix, looping_dict, param_dict)

    result_list = []
    for item in white_list:
        if item not in looping_dict:
            if _selections:
                parent_item = _selections[-1][1]
                logger.warning(
                    f"{item} is not present for {parent_item} in looping_dict"
                )
            else:
                logger.warning(f"{item} is not present in looping dict")
            continue

        selections = _selections + ((key_prefix, item),)

        if remaining_keys:
            # Descend one level; the leaf builds its own parameter copy.
            result_list.extend(
                _sub_graph_execution_for_plots(
                    looping_dict[item],
                    chunk_function,
                    param_dict,
                    remaining_keys,
                    selections,
                )
            )
        else:
            this_param_dict = copy.deepcopy(param_dict)
            _add_exec_data_to_param_dict(looping_dict[item], this_param_dict)
            for prefix, selected_item in selections:
                this_param_dict[f"this_{prefix}"] = selected_item
            result_list.append(chunk_function(this_param_dict))

    return result_list
def create_and_execute_graphs_for_outputs(
    mds_object,
    chunk_function,
    param_dict,
    key_order,
    fetch_returns=False,
):
    """Execute an output/plot chunk function over an MDS, parallel at the first level.

    Dask parallelization exclusively for exports. Only the first level of
    ``key_order`` (e.g. antenna) is parallelized, to decrease graph size and
    optimize plot creation. Deeper levels run serially inside each first-level
    job. With a single level, execution is delegated to
    ``create_and_execute_graph_from_dict``.

    If ``param_dict`` contains ``"destination"``, that folder (and any missing
    parents) is created. If ``"display"`` is truthy in parallel mode, it is
    forced to False with a warning.

    Args:
        mds_object: Astrohack MDS object from which to plot; its ``root``
            attribute is looped over.
        chunk_function: Plotting chunk function.
        param_dict: The chunk function parameters; must contain ``"parallel"``.
        key_order: Order in which to execute keys (level prefixes, outermost first).
        fetch_returns: If True, also return the chunk function returns.

    Returns:
        ``status`` if ``fetch_returns`` is False, otherwise
        ``(status, return_list)``.
    """
    parallel = param_dict["parallel"]

    # Observation summary creation has no destination folder, hence the check.
    if "destination" in param_dict:
        pathlib.Path(param_dict["destination"]).mkdir(parents=True, exist_ok=True)

    # Chunks run under Dask in parallel mode on both paths below, so the display
    # restriction must be applied before delegating.
    if param_dict.get("display") and parallel:
        logger.warning("Display cannot be True in parallel mode, setting it to False")
        param_dict["display"] = False

    if len(key_order) == 1:
        # This can be the case for position_mds where the key depth is only known at runtime.
        # In this case of 1 level then the usual executioner is to be used.
        return create_and_execute_graph_from_dict(
            looping_dict=mds_object,
            chunk_function=chunk_function,
            param_dict=param_dict,
            key_order=key_order,
            fetch_returns=fetch_returns,
        )

    # Here only the first level of the tree is parallelized.
    looping_dict = mds_object.root
    first_key_prefix = key_order[0]
    sub_key_order = key_order[1:]
    white_list = _white_list_creation(first_key_prefix, looping_dict, param_dict)

    delayed_list = []
    for item in white_list:
        if item not in looping_dict:
            logger.warning(f"{item} is not present in looping dict")
            continue

        this_param_dict = copy.deepcopy(param_dict)
        this_param_dict[f"this_{first_key_prefix}"] = item
        sub_args = (looping_dict[item], chunk_function, this_param_dict, sub_key_order)

        if parallel:
            delayed_list.append(dask.delayed(_sub_graph_execution_for_plots)(*sub_args))
        else:
            delayed_list.append((_sub_graph_execution_for_plots, sub_args))

    if parallel:
        result_list = dask.compute(delayed_list)[0]
    else:
        result_list = [sub_graph(*args) for sub_graph, args in delayed_list]

    # Each first-level job returns a list; flatten into a single list.
    return_list = [item for sublist in result_list for item in sublist]
    return _factorized_graph_execution_return(True, return_list, fetch_returns)
def compute_graph_from_lists(
    param_dict,
    chunk_function,
    looping_key_list,
):
    """Creates and executes a graph based on entries in a parameter dictionary that are lists.

    For iteration ``i``, a deep copy of ``param_dict`` is made. Then, for every
    ``key`` in ``looping_key_list``, ``this_<key>`` is set to
    ``param_dict[key][i]``. The number of iterations is the length of
    ``param_dict[looping_key_list[0]]``; the other lists must be at least that
    long. ``param_dict`` itself is not modified.

    Args:
        param_dict: The parameter dictionary; must contain ``"parallel"``.
        chunk_function: The function for the operation chunk, called with the
            per-iteration parameter dictionary.
        looping_key_list: The keys of list entries in the parameter dictionary
            over which to loop.

    Returns:
        A list containing the returns of the calls to the chunk function, in
        iteration order, in both serial and parallel mode.
    """
    parallel = param_dict["parallel"]
    n_iter = len(param_dict[looping_key_list[0]])

    delayed_list = []
    result_list = []
    for i_iter in range(n_iter):
        this_param = copy.deepcopy(param_dict)
        for key in looping_key_list:
            this_param[f"this_{key}"] = param_dict[key][i_iter]

        if parallel:
            delayed_list.append(dask.delayed(chunk_function)(dask.delayed(this_param)))
        else:
            result_list.append(chunk_function(this_param))

    if parallel:
        # dask.compute returns a tuple with one entry per argument; unwrap it so
        # the parallel path returns a list like the serial path does.
        result_list = dask.compute(delayed_list)[0]
    return result_list