import numpy as np
import xarray as xr
from copy import deepcopy
from astrohack.io.image_mds import AstrohackImageFile
from astrohack.utils.text import format_angular_distance
from astrohack.antenna.telescope import get_proper_telescope, RingedCassegrain
from astrohack.utils.text import create_dataset_label
from astrohack.utils.conversion import convert_5d_grid_to_stokes
from astrohack.utils.algorithms import phase_wrapping
from astrohack.utils.zernike_aperture_fitting import fit_zernike_coefficients
from astrohack.utils.imaging import (
calculate_far_field_aperture,
calculate_near_field_aperture,
)
from astrohack.utils.gridding import grid_beam
from astrohack.utils.imaging import parallactic_derotation
from astrohack.utils.phase_fitting import (
clic_like_phase_fitting,
skip_phase_fitting,
aips_like_phase_fitting,
)
import toolviper.utils.logger as logger
[docs]
def process_holog_chunk(holog_chunk_params: dict, output_mds: AstrohackImageFile):
    """Process one antenna/DDI chunk of holography data and store the resulting products.

    The chunk's beam is gridded with ``grid_beam``. Far-field data is then
    derotated with ``parallactic_derotation``, and the beam is optionally averaged
    over the time axis. It is converted to an aperture with
    ``calculate_near_field_aperture`` or ``calculate_far_field_aperture`` and fitted
    with Zernike polynomials. Beam and aperture are optionally converted to Stokes.
    The aperture is cropped and split into amplitude/phase, and the phase is
    corrected with the requested engine. Everything is then written to
    ``output_mds``.

    Args:
        holog_chunk_params (dict): Holography parameters for this chunk. Keys read here:
            ``this_ant``, ``this_ddi``, ``xdt_data``, ``to_stokes``, ``grid_size``,
            ``cell_size``, ``chan_average``, ``chan_tolerance_factor``,
            ``grid_interpolation_mode``, ``scan_average``, ``padding_factor``,
            ``zernike_n_order``, ``phase_fit_engine``, plus ``alma_osf_pad``
            (near-field data only) and ``phase_fit_control`` (perturbations engine only).
        output_mds (AstrohackImageFile): Output mds object that receives the dataset.

    Raises:
        TypeError: If ``grid_size`` or ``cell_size`` has an unsupported type.
        ValueError: If ``phase_fit_engine`` is not None, "none", "perturbations" or "zernike".
    """
    ant_key = holog_chunk_params["this_ant"]
    ddi_key = holog_chunk_params["this_ddi"]
    ant_ddi_xdt = holog_chunk_params["xdt_data"]
    label = create_dataset_label(ant_key, ddi_key, separator=",")
    logger.info(f"Processing {label}")

    # Work on a copy so the input dataset's summary attribute is never mutated.
    summary = deepcopy(ant_ddi_xdt["map_0"].attrs["summary"])
    convert_to_stokes = holog_chunk_params["to_stokes"]

    grid_size = _resolve_grid_size(holog_chunk_params["grid_size"], summary)
    logger.info(
        f"{label}: Using a grid of {grid_size[0]} by {grid_size[1]} pixels for the beam"
    )

    cell_size = _resolve_cell_size(holog_chunk_params["cell_size"], summary)
    logger.info(
        f"{label}: Using a cell size of {format_angular_distance(cell_size[0])} by "
        f"{format_angular_distance(cell_size[1])} for the beam"
    )

    telescope = get_proper_telescope(
        summary["general"]["telescope name"], summary["general"]["antenna name"]
    )
    # Datasets without a "near_field" attribute are treated as far-field data.
    is_near_field = ant_ddi_xdt["map_0"].attrs.get("near_field", False)

    (
        beam_grid,
        time_centroid,
        freq_axis,
        pol_axis,
        l_axis,
        m_axis,
        grid_corr,
        summary,
    ) = grid_beam(
        ant_ddi_xdt=ant_ddi_xdt,
        grid_size=grid_size,
        sky_cell_size=cell_size,
        avg_chan=holog_chunk_params["chan_average"],
        chan_tol_fac=holog_chunk_params["chan_tolerance_factor"],
        telescope=telescope,
        grid_interpolation_mode=holog_chunk_params["grid_interpolation_mode"],
        observation_summary=summary,
        label=label,
    )

    if not is_near_field:
        beam_grid = parallactic_derotation(
            data=beam_grid, parallactic_angle_dict=ant_ddi_xdt
        )

    if holog_chunk_params["scan_average"]:
        # Collapse the leading (time) axis but keep it as a length-1 dimension.
        beam_grid = np.mean(beam_grid, axis=0)[None, ...]
        time_centroid = np.mean(np.array(time_centroid))

    # Current bottleneck
    if is_near_field:
        distance, focus_offset = telescope.station_distance_dict[
            holog_chunk_params["alma_osf_pad"]
        ]
        # Use the resolved cell size (the same one the beam was gridded with),
        # not the raw user parameter, which may be None or an unsigned scalar.
        aperture_grid, u_axis, v_axis, _, used_wavelength = (
            calculate_near_field_aperture(
                grid=beam_grid,
                sky_cell_size=cell_size,
                distance=distance,
                freq=freq_axis,
                padding_factor=holog_chunk_params["padding_factor"],
                focus_offset=focus_offset,
                telescope=telescope,
                apply_grid_correction=grid_corr,
                label=label,
            )
        )
    else:
        focus_offset = 0
        aperture_grid, u_axis, v_axis, _, used_wavelength = (
            calculate_far_field_aperture(
                grid=beam_grid,
                padding_factor=holog_chunk_params["padding_factor"],
                freq=freq_axis,
                telescope=telescope,
                sky_cell_size=cell_size,
                apply_grid_correction=grid_corr,
                label=label,
            )
        )

    zernike_n_order = holog_chunk_params["zernike_n_order"]
    # The Zernike fit is done on the aperture in its original polarization basis.
    zernike_coeffs, zernike_model, zernike_rms, osa_coeff_list = (
        fit_zernike_coefficients(
            aperture_grid, u_axis, v_axis, zernike_n_order, telescope
        )
    )

    orig_pol_axis = pol_axis
    if convert_to_stokes:
        beam_grid = convert_5d_grid_to_stokes(beam_grid, pol_axis)
        aperture_grid = convert_5d_grid_to_stokes(aperture_grid, pol_axis)
        pol_axis = ["I", "Q", "U", "V"]

    amplitude, phase, u_prime, v_prime = _crop_and_split_aperture(
        aperture_grid, u_axis, v_axis, telescope
    )

    phase_fit_engine = holog_chunk_params["phase_fit_engine"]
    if phase_fit_engine == "perturbations" and not isinstance(
        telescope, RingedCassegrain
    ):
        logger.warning(
            f"Perturbation phase fitting is not supported for {telescope.name}, changing phase fit engine to"
            f" zernike"
        )
        phase_fit_engine = "zernike"

    if phase_fit_engine is None or phase_fit_engine == "none":
        phase_corrected_angle, phase_fit_results = skip_phase_fitting(label, phase)
    elif is_near_field:
        # Near-field data always uses the CLIC-like fit, whatever engine was requested.
        phase_corrected_angle, phase_fit_results = clic_like_phase_fitting(
            phase, freq_axis, telescope, focus_offset, u_prime, v_prime, label
        )
    elif phase_fit_engine == "perturbations":
        phase_corrected_angle, phase_fit_results = aips_like_phase_fitting(
            amplitude,
            phase,
            pol_axis,
            freq_axis,
            telescope,
            u_axis,
            v_axis,
            holog_chunk_params["phase_fit_control"],
            label,
        )
    elif phase_fit_engine == "zernike":
        if zernike_n_order > 4:
            logger.warning(
                "Using a Zernike order > 4 for phase fitting may result in overfitting"
            )
        # The model was fitted in the original basis; bring it to the same basis
        # as the aperture before cropping and subtracting its phase.
        if convert_to_stokes:
            zernike_grid = convert_5d_grid_to_stokes(zernike_model, orig_pol_axis)
        else:
            zernike_grid = zernike_model.copy()
        _, zernike_phase, _, _ = _crop_and_split_aperture(
            zernike_grid, u_axis, v_axis, telescope
        )
        # Only subtract the model where it is finite; elsewhere keep the raw phase.
        phase_corrected_angle = phase_wrapping(
            np.where(np.isfinite(zernike_phase), phase - zernike_phase, phase)
        )
        phase_fit_results = None
    else:
        logger.error(f"Unsupported phase fitting engine: {phase_fit_engine}")
        raise ValueError(f"Unsupported phase fitting engine: {phase_fit_engine}")

    summary["aperture"] = _get_aperture_summary(
        u_axis, v_axis, _compute_aperture_resolution(l_axis, m_axis, used_wavelength)
    )

    _export_to_xds(
        beam_grid,
        aperture_grid,
        amplitude,
        phase_corrected_angle,
        ant_key,
        time_centroid,
        ddi_key,
        phase_fit_results,
        pol_axis,
        freq_axis,
        l_axis,
        m_axis,
        u_axis,
        v_axis,
        u_prime,
        v_prime,
        orig_pol_axis,
        osa_coeff_list,
        zernike_coeffs,
        zernike_model,
        zernike_rms,
        zernike_n_order,
        summary,
        output_mds,
    )
    logger.info(f"Finished processing {label}")


def _resolve_grid_size(user_grid_size, summary):
    """Return the beam grid size as a 2-element numpy array.

    ``None`` falls back to ``summary["beam"]["grid size"]``. A single integer
    (Python or numpy) is used for both axes. A list, tuple or array is used as
    given.

    Raises:
        TypeError: For any other type.
    """
    if user_grid_size is None:
        return np.array(summary["beam"]["grid size"])
    if isinstance(user_grid_size, (int, np.integer)):
        return np.array([user_grid_size, user_grid_size])
    if isinstance(user_grid_size, (list, tuple, np.ndarray)):
        return np.array(user_grid_size)
    raise TypeError(
        f"Don't know what to do with grid size of type {type(user_grid_size)}"
    )


def _resolve_cell_size(user_cell_size, summary):
    """Return the beam cell size as a 2-element numpy array.

    ``None`` falls back to ``summary["beam"]["cell size"]``. Both the fallback and
    a single number (Python or numpy) are expanded to ``[-cell, +cell]``, which
    gives the first axis a negative increment. A list, tuple or array is used as
    given.

    Raises:
        TypeError: For any other type.
    """
    if user_cell_size is None:
        beam_cell = summary["beam"]["cell size"]
        return np.array([-beam_cell, beam_cell])
    if isinstance(user_cell_size, (int, float, np.integer, np.floating)):
        return np.array([-user_cell_size, user_cell_size])
    if isinstance(user_cell_size, (list, tuple, np.ndarray)):
        return np.array(user_cell_size)
    raise TypeError(
        f"Don't know what to do with cell size of type {type(user_cell_size)}"
    )
def _crop_and_split_aperture(aperture_grid, u_axis, v_axis, telescope, scaling=1.5):
    """Crop the aperture grid to a square window around its centre and split it.

    Each axis gets its own half-width in pixels. It runs from the central pixel
    (``shape // 2`` of the image ``aperture_grid[0, 0, 0]``) to the last sample
    whose absolute coordinate is below ``0.5 * telescope.diameter * scaling``.
    The larger of the two half-widths is used for both axes, and the window spans
    ``[centre - half_width, centre + half_width)`` on each axis.

    The default ``scaling`` of 1.5 leaves room to analyse the noise around the
    aperture. For most apertures this means little or no cropping, but it matters
    when the dish appears too small in the aperture.

    Args:
        aperture_grid: Complex grid whose last two axes are u and v.
        u_axis: Coordinates along the u axis.
        v_axis: Coordinates along the v axis.
        telescope: Object exposing a ``diameter`` attribute.
        scaling: Multiplier on the dish radius that sets the crop limit.

    Returns:
        tuple: ``(amplitude, phase, cropped_u_axis, cropped_v_axis)``.
    """
    half_dish = 0.5 * telescope.diameter
    centre = np.array(aperture_grid[0, 0, 0, ...].shape[0:2]) // 2

    def _half_width(axis, centre_index):
        # Distance (in pixels) from the centre to the outermost in-limit sample.
        in_limit = np.where(np.abs(axis) < half_dish * scaling)[0]
        return int(in_limit.max() - centre_index)

    half_width = max(_half_width(u_axis, centre[0]), _half_width(v_axis, centre[1]))

    low = centre - half_width
    high = centre + half_width
    window = aperture_grid[..., low[0] : high[0], low[1] : high[1]]

    return (
        np.absolute(window),
        np.angle(window),
        u_axis[low[0] : high[0]],
        v_axis[low[1] : high[1]],
    )
def _compute_aperture_resolution(l_axis, m_axis, wavelength):
    """Return the aperture-plane resolution along the two aperture axes.

    Implements Equation 7 of EVLA memo 212
    (https://library.nrao.edu/public/memos/evla/EVLAM_212.pdf): for each axis the
    resolution is ``1.27 * wavelength / extent``. Here ``extent`` is the
    peak-to-peak span (max - min) of the corresponding beam-plane axis.

    Args:
        l_axis: Beam-plane coordinates along l.
        m_axis: Beam-plane coordinates along m.
        wavelength: Wavelength used for the aperture.

    Returns:
        numpy.ndarray: Two-element array ``[resolution_from_l, resolution_from_m]``.
    """
    extents = [np.max(axis) - np.min(axis) for axis in (l_axis, m_axis)]
    resolution = np.array([1 / extent for extent in extents])
    resolution *= 1.27 * wavelength
    return resolution
def _export_to_xds(
    beam_grid,
    aperture_grid,
    amplitude,
    phase_corrected_angle,
    ant_key,
    time_centroid,
    ddi_key,
    phase_fit_results,
    pol_axis,
    freq_axis,
    l_axis,
    m_axis,
    u_axis,
    v_axis,
    u_prime,
    v_prime,
    orig_pol_axis,
    osa_coeff_list,
    zernike_coeffs,
    zernike_model,
    zernike_rms,
    zernike_n_order,
    summary,
    output_mds: AstrohackImageFile,
):
    """Package one antenna/DDI result as an xarray Dataset and add it to ``output_mds``.

    Data variables (name: dimensions):
      - BEAM: time, chan, pol, l, m
      - APERTURE: time, chan, pol, u, v
      - AMPLITUDE and CORRECTED_PHASE: time, chan, pol, u_prime, v_prime
      - ZERNIKE_COEFFICIENTS: time, chan, orig_pol, osa
      - ZERNIKE_MODEL: time, chan, orig_pol, u, v
      - ZERNIKE_FIT_RMS: time, chan, orig_pol

    Attributes: ``time_centroid`` (as a numpy array), ``phase_fitting``,
    ``zernike_N_order``, ``summary`` and ``ddi``.

    Coordinates are orig_pol, pol, l, m, u, v, u_prime, v_prime, chan (from
    ``freq_axis``) and osa (from ``osa_coeff_list``). No coordinate is assigned
    to ``time``.

    The dataset is stored under the node ``[ant_key, ddi_key]``.
    """
    # TODO: Add parallactic angle as a non-dimension coordinate dependent on time.
    variable_layout = (
        ("BEAM", beam_grid, ["time", "chan", "pol", "l", "m"]),
        ("APERTURE", aperture_grid, ["time", "chan", "pol", "u", "v"]),
        ("AMPLITUDE", amplitude, ["time", "chan", "pol", "u_prime", "v_prime"]),
        (
            "CORRECTED_PHASE",
            phase_corrected_angle,
            ["time", "chan", "pol", "u_prime", "v_prime"],
        ),
        ("ZERNIKE_COEFFICIENTS", zernike_coeffs, ["time", "chan", "orig_pol", "osa"]),
        ("ZERNIKE_MODEL", zernike_model, ["time", "chan", "orig_pol", "u", "v"]),
        ("ZERNIKE_FIT_RMS", zernike_rms, ["time", "chan", "orig_pol"]),
    )

    dataset = xr.Dataset()
    for var_name, values, dims in variable_layout:
        dataset[var_name] = xr.DataArray(values, dims=dims)

    dataset.attrs.update(
        {
            "time_centroid": np.array(time_centroid),
            "phase_fitting": phase_fit_results,
            "zernike_N_order": zernike_n_order,
            "summary": summary,
            "ddi": ddi_key,
        }
    )

    dataset = dataset.assign_coords(
        {
            "orig_pol": orig_pol_axis,
            "pol": pol_axis,
            "l": l_axis,
            "m": m_axis,
            "u": u_axis,
            "v": v_axis,
            "u_prime": u_prime,
            "v_prime": v_prime,
            "chan": freq_axis,
            "osa": osa_coeff_list,
        }
    )
    output_mds.add_node(dataset, [ant_key, ddi_key])
def _get_aperture_summary(u_axis, v_axis, aperture_resolution):
    """Build the ``"aperture"`` entry of the observation summary.

    Args:
        u_axis: Aperture u-axis coordinates (at least two samples).
        v_axis: Aperture v-axis coordinates (at least two samples).
        aperture_resolution: Array-like with a ``tolist`` method, e.g. the output of
            the aperture-resolution computation.

    Returns:
        dict: ``"grid size"`` is the number of samples per axis, ``"cell size"`` is
        the spacing between the first two samples of each axis, and
        ``"resolution"`` is ``aperture_resolution`` as a list.
    """
    axes = (u_axis, v_axis)
    return {
        "grid size": [axis.shape[0] for axis in axes],
        "cell size": [axis[1] - axis[0] for axis in axes],
        "resolution": aperture_resolution.tolist(),
    }