Source code for astrohack.core.combine

from copy import deepcopy

import numpy as np
import xarray as xr

import toolviper.utils.logger as logger

from astrohack.io.image_mds import AstrohackImageFile
from astrohack.utils.text import create_dataset_label
from astrohack.utils.constants import clight
from scipy.interpolate import griddata
from astrohack.utils.text import param_to_list


def process_combine_chunk(combine_chunk_params: dict, output_mds: AstrohackImageFile):
    """
    Process a combine chunk

    Combines the amplitude and corrected phase images of the selected DDIs of
    one antenna into a single dataset and stores it in ``output_mds``.

    The DDI with the lowest representative frequency is used as the reference:
    its dataset is deep-copied to form the output and its (u_prime, v_prime)
    grid is the destination grid. DDIs whose grid differs from the reference
    one are linearly regridded onto it before being accumulated.

    Behaviour by number of selected DDIs:
        - none: a warning is logged and nothing is written;
        - one: the input dataset is stored unchanged under its own DDI key
          (or a warning is logged if it is absent from the input);
        - several: amplitudes are averaged over the DDIs actually present,
          phases are either averaged or amplitude-weighted depending on
          ``combine_chunk_params["weighted"]``, and the result is stored
          under the ``ddi_99`` key.

    Args:
        combine_chunk_params: Param dictionary for combine chunk. The keys read
            here are ``this_ant``, ``xdt_data``, ``ddi`` and, when more than one
            DDI is selected, ``weighted``.
        output_mds: output mds file that contains combined data

    Raises:
        RuntimeError: If the reference dataset has more than one channel or
            more than one mapping (first axis of ``CORRECTED_PHASE``).
    """
    ant_key = combine_chunk_params["this_ant"]
    ant_xdt = combine_chunk_params["xdt_data"]
    user_ddi_sel = combine_chunk_params["ddi"]

    ddi_list = param_to_list(user_ddi_sel, ant_xdt, "ddi")
    dataset_label = create_dataset_label(ant_key, None)

    nddi = len(ddi_list)
    if nddi == 0:
        logger.warning(f"Nothing to process for {ant_key}")
        return

    ddi_in_xdt_list = list(ant_xdt.keys())

    if nddi == 1:
        ddi_key = ddi_list[0]
        if ddi_key in ddi_in_xdt_list:
            logger.info(
                f"{dataset_label} has a single ddi to be combined, data copied from input file"
            )
            output_mds.add_node(ant_xdt[ddi_key].dataset, [ant_key, ddi_key])
        else:
            logger.warning(
                f"{dataset_label} has no {ddi_key}, nothing to process for this antenna"
            )
        return

    ddi_present_list = [ddi_key in ddi_in_xdt_list for ddi_key in ddi_list]
    if not any(ddi_present_list):
        logger.warning(
            f"{dataset_label} has no valid DDI in user selection (ddi = {user_ddi_sel})"
        )
        return

    # Only read once we know a real combination takes place, so that the
    # single-DDI paths above keep working without this key.
    weighted = combine_chunk_params["weighted"]

    summary_dict = {
        ddi_key: ant_xdt[ddi_key].attrs["summary"]
        for ddi_key, present in zip(ddi_list, ddi_present_list)
        if present
    }
    # Reference DDI: lowest representative frequency (first one wins on ties).
    ddi_ref_key = min(
        summary_dict,
        key=lambda key: summary_dict[key]["spectral"]["rep. frequency"],
    )

    out_xds = deepcopy(ant_xdt[ddi_ref_key].dataset)
    shape = list(out_xds["CORRECTED_PHASE"].values.shape)
    if out_xds.sizes["chan"] != 1:
        msg = "Only single channel holographies supported"
        logger.error(msg)
        raise RuntimeError(msg)
    nmap = shape[0]
    if nmap != 1:
        msg = "Only single mapping holographies supported"
        logger.error(msg)
        raise RuntimeError(msg)

    npol = shape[2]
    npoints = shape[3] * shape[4]
    amp_sum = np.zeros((npol, npoints))
    pha_sum = np.zeros((npol, npoints))
    dest_u_axis, dest_v_axis = _flattened_uv_axes(out_xds)

    for ddi_key, present in zip(ddi_list, ddi_present_list):
        this_dataset_label = create_dataset_label(ant_key, ddi_key)
        if not present:
            logger.warning(
                f"{this_dataset_label} does not exist in input mds, skipping"
            )
            continue

        this_xds = ant_xdt[ddi_key].dataset
        loca_u_axis, loca_v_axis = _flattened_uv_axes(this_xds)

        same_grid = (
            loca_u_axis.shape[0] == dest_u_axis.shape[0]
            and np.allclose(loca_u_axis, dest_u_axis, rtol=1e-6)
            and np.allclose(loca_v_axis, dest_v_axis, rtol=1e-6)
        )
        if same_grid:
            logger.info(
                f"{this_dataset_label} already has the proper sampling, simple addition"
            )
        else:
            logger.info(f"Regridding {this_dataset_label}")

        pha_cube = this_xds["CORRECTED_PHASE"].values
        amp_cube = this_xds["AMPLITUDE"].values
        for ipol in range(npol):
            thispha = pha_cube[0, 0, ipol, :, :].ravel()
            thisamp = amp_cube[0, 0, ipol, :, :].ravel()
            if not same_grid:
                # NOTE(review): per the scipy documentation, linear griddata
                # fills points outside the convex hull of the input grid with
                # NaN; such pixels stay NaN in the combined image.
                thispha = griddata(
                    (loca_u_axis, loca_v_axis),
                    thispha,
                    (dest_u_axis, dest_v_axis),
                    method="linear",
                )
                thisamp = griddata(
                    (loca_u_axis, loca_v_axis),
                    thisamp,
                    (dest_u_axis, dest_v_axis),
                    method="linear",
                )
            amp_sum[ipol, :] += thisamp
            if weighted:
                pha_sum[ipol, :] += thispha * thisamp
            else:
                pha_sum[ipol, :] += thispha

    n_used_ddi = sum(ddi_present_list)
    if weighted:
        # Amplitude-weighted mean phase; pixels with zero summed amplitude
        # yield NaN/inf here, as no masking is applied.
        phase = pha_sum / amp_sum
    else:
        phase = pha_sum / n_used_ddi
    amplitude = amp_sum / n_used_ddi

    image_dims = ["time", "chan", "pol", "u_prime", "v_prime"]
    out_xds["AMPLITUDE"] = xr.DataArray(amplitude.reshape(shape), dims=image_dims)
    out_xds["CORRECTED_PHASE"] = xr.DataArray(phase.reshape(shape), dims=image_dims)

    out_ddi_key = "ddi_99"
    out_xds.attrs["ddi"] = out_ddi_key
    out_xds.attrs["summary"] = _merge_summary_dict(summary_dict, ddi_ref_key)
    output_mds.add_node(out_xds, [ant_key, out_ddi_key])


def _flattened_uv_axes(xds):
    """
    Return the u_prime and v_prime coordinate of every image pixel as 1D arrays.

    The meshes are built with ``indexing="ij"`` so that, once raveled, element
    ``k`` of the returned arrays is the coordinate of element ``k`` of a raveled
    ``(u_prime, v_prime)`` image plane. The default ``"xy"`` indexing would yield
    ``(v_prime, u_prime)``-ordered meshes, which only pair up correctly with the
    data when both axes are identical.

    Args:
        xds: Dataset exposing ``u_prime`` and ``v_prime`` coordinates.

    Returns:
        Tuple ``(u_flat, v_flat)`` of 1D numpy arrays of length n_u * n_v.
    """
    u_mesh, v_mesh = np.meshgrid(
        xds.u_prime.values, xds.v_prime.values, indexing="ij"
    )
    return u_mesh.ravel(), v_mesh.ravel()
def _merge_summary_dict(summary_dict, ddi_ref_key):
    """
    Build the summary of a combined dataset from the summaries of its DDIs.

    The result starts as a deep copy of the reference DDI summary; the fields
    below are then recomputed over every entry of ``summary_dict``:
        - ``spectral/frequency range``: lowest start and highest end;
        - ``spectral/rep. frequency``: arithmetic mean;
        - ``spectral/channel width``: sum;
        - ``spectral/rep. wavelength``: ``clight`` divided by the mean frequency;
        - ``aperture/resolution``: largest value along each of the two axes.
    All other fields keep the reference DDI values.

    Args:
        summary_dict: Mapping from DDI key to that DDI's summary dictionary.
        ddi_ref_key: Key in ``summary_dict`` of the reference DDI.

    Returns:
        The merged summary dictionary; the inputs are not modified.
    """
    merged = deepcopy(summary_dict[ddi_ref_key])
    spectral_out = merged["spectral"]
    resolution = merged["aperture"]["resolution"]
    freq_span = spectral_out["frequency range"]

    total_width = 0.0
    freq_accum = 0.0
    for entry in summary_dict.values():
        spectral_in = entry["spectral"]
        total_width += spectral_in["channel width"]
        freq_accum += spectral_in["rep. frequency"]

        low, high = spectral_in["frequency range"][0], spectral_in["frequency range"][1]
        freq_span[0] = float(np.min([freq_span[0], low]))
        freq_span[1] = float(np.max([freq_span[1], high]))

        entry_resolution = entry["aperture"]["resolution"]
        for axis in (0, 1):
            resolution[axis] = float(
                np.max([resolution[axis], entry_resolution[axis]])
            )

    mean_freq = freq_accum / len(summary_dict)

    merged["aperture"]["resolution"] = resolution
    spectral_out["frequency range"] = freq_span
    spectral_out["rep. frequency"] = mean_freq
    spectral_out["channel width"] = total_width
    spectral_out["rep. wavelength"] = clight / mean_freq
    return merged