import xarray as xr
from matplotlib import patches
import toolviper.utils.logger as logger
from astrohack.antenna.telescope import get_proper_telescope
from astrohack.utils.algorithms import (
data_statistics,
phase_wrapping,
)
from astrohack.utils.constants import *
from astrohack.utils.conversion import to_db
from astrohack.utils.conversion import convert_unit
from astrohack.utils.text import (
add_prefix,
bool_to_str,
format_frequency,
format_value_unit,
string_to_ascii_file,
create_dataset_label,
statistics_to_text,
lnbr,
)
from astrohack.visualization.plot_tools import (
create_figure_and_axes,
close_figure,
simple_imshow_map_plot,
get_proper_color_map,
well_positioned_colorbar,
compute_extent,
)
from astrohack.utils.fits import (
write_fits,
put_resolution_in_fits_header,
put_axis_in_fits_header,
)
[docs]
# Polarization state labels declared as supported by this module.
# NOTE(review): nothing in the visible part of this file reads this list; the
# holography reader checks `pol_state` against the dataset's own "pol"
# coordinate instead -- confirm where this constant is consumed.
SUPPORTED_POL_STATES = ["I", "RR", "LL", "XX", "YY"]
[docs]
class AntennaSurface:
def __init__(
    self,
    inputxds,
    clip_type="sigma",
    clip_level=3,
    pmodel="rigid",
    nan_out_of_bounds=True,
    panel_margins=0.05,
    reread=False,
    pol_state="I",
    patch_phase=False,
    use_detailed_mask=True,
):
    """
    Antenna Surface description capable of computing RMS, Gains, and fitting the surface to obtain screw adjustments

    Args:
        inputxds: Input xarray dataset
        clip_type: Type of clipping to be applied to amplitude (None/"none", "relative", "absolute", "sigma"
            or "noise_threshold")
        clip_level: Level of clipping
        pmodel: model of panel surface fitting, forwarded to the telescope's panel builder
        nan_out_of_bounds: Should the region outside the dish be replaced with NaNs?
        panel_margins: Margin to be ignored at edges of panels when fitting, forwarded to the telescope's
            panel builder
        reread: Read a previously processed holography
        pol_state: Polarization state to select
        patch_phase: Phase data from inputxds is passed through phase_wrapping before use
        use_detailed_mask: use detailed mask (arms shadows, ngvla outer panels)
    """
    # State consulted by the readers and by later methods (gains, get_rms,
    # correct_surface, plot_*) must exist before anything else runs; the
    # panel reader overrides solved/fitted/residuals when rereading.
    self.reread = reread
    self.solved = False
    self.fitted = False
    self.residuals = None
    self.corrections = None
    self.phase_corrections = None
    self.phase_residuals = None
    self.pol_state = pol_state

    # Read the data in an XDS
    self._read_xds(inputxds)
    if patch_phase:
        self.phase = phase_wrapping(self.phase)

    self._create_aperture_mask(clip_type, clip_level, use_detailed_mask)
    self.deviation = self.telescope.phase_to_deviation(
        self.u_axis, self.v_axis, self.mask, self.phase, self.wavelength
    )
    self.panels = self.telescope.build_panel_list(pmodel, panel_margins)

    if not self.reread:
        # NOTE(review): indentation was lost in the copy under review; the
        # fresh-processing steps below (through the out-of-bounds NaN fill)
        # are assumed to belong to this branch -- confirm against upstream.
        self.panelmodel = pmodel
        self.panel_margins = panel_margins
        self.panel_distribution = self.telescope.attribute_pixels_to_panels(
            self.panels,
            self.u_axis,
            self.v_axis,
            self.rad,
            self.phi,
            self.deviation,
            self.mask,
        )
        if nan_out_of_bounds:
            self._nan_out_of_bounds()
def _read_holog_xds(self, inputxds):
    """
    Read amplitude, phase and axes from a holography (image) dataset

    The wavelength comes from the single channel frequency when the dataset has a "chan" dimension,
    otherwise from the "wavelength" attribute.

    Args:
        inputxds: X array dataset with AMPLITUDE and CORRECTED_PHASE variables

    Raises:
        RuntimeError: if the dataset has more than one channel
        ValueError: if the requested polarization state is not in the "pol" coordinate
    """
    has_chan = "chan" in inputxds.dims
    if has_chan:
        if inputxds.sizes["chan"] != 1:
            raise RuntimeError("Only single channel holographies supported")
        self.wavelength = clight / inputxds.chan.values[0]
    else:
        self.wavelength = inputxds.attrs["wavelength"]

    if self.pol_state not in inputxds.coords["pol"]:
        msg = (
            f"Polarization state {self.pol_state} is not present in the data (available states: "
            f'{inputxds.coords["pol"]})'
        )
        logger.error(msg)
        raise ValueError(msg)

    # Only index "chan" when it exists: isel raises on a missing dimension,
    # which made the channel-less branch above unusable.
    plane = {"time": 0}
    if has_chan:
        plane["chan"] = 0

    self.amplitude = (
        inputxds["AMPLITUDE"].sel(pol=self.pol_state).isel(**plane).values
    )
    self.phase = (
        inputxds["CORRECTED_PHASE"].sel(pol=self.pol_state).isel(**plane).values
    )

    self.npoint = np.sqrt(inputxds.sizes["l"] ** 2 + inputxds.sizes["m"] ** 2)
    self.amp_unit = "V"
    self.u_axis = inputxds.u_prime.values
    self.v_axis = inputxds.v_prime.values
    self.computephase = False
def _read_panel_xds(self, inputxds):
    """
    Restore the surface state from a previously exported panel dataset

    Args:
        inputxds: X array dataset holding the attributes and variables read below
    """
    attrs = inputxds.attrs
    self.wavelength = attrs["wavelength"]
    self.amp_unit = attrs["amp_unit"]
    self.panelmodel = attrs["panel_model"]
    self.panel_margins = attrs["panel_margin"]
    self.clip = attrs["clip"]
    self.solved = attrs["solved"]
    self.fitted = attrs["fitted"]
    # Datasets without a stored polarization state are assumed to be Stokes I
    self.pol_state = attrs["pol_state"] if "pol_state" in attrs else "I"

    # Maps that are always present
    for attribute, variable in (
        ("amplitude", "AMPLITUDE"),
        ("phase", "PHASE"),
        ("deviation", "DEVIATION"),
        ("mask", "MASK"),
    ):
        setattr(self, attribute, inputxds[variable].values)
    self.u_axis = inputxds.u.values
    self.v_axis = inputxds.v.values
    self.panel_distribution = inputxds["PANEL_DISTRIBUTION"].values
    self.amplitude_noise = inputxds["AMP_NOISE"].values

    if not self.solved:
        return

    # Products that are only read when the stored surface was solved
    for attribute, variable in (
        ("panel_fallback", "PANEL_FALLBACK"),
        ("panel_model_array", "PANEL_MODEL"),
        ("phase_residuals", "PHASE_RESIDUALS"),
        ("residuals", "RESIDUALS"),
        ("phase_corrections", "PHASE_CORRECTIONS"),
        ("corrections", "CORRECTIONS"),
        ("panel_pars", "PANEL_PARAMETERS"),
        ("screw_adjustments", "PANEL_SCREWS"),
    ):
        setattr(self, attribute, inputxds[variable].values)
    self.ingains = [attrs["input_gain"], attrs["theoretical_gain"]]
    self.ougains = [attrs["output_gain"], attrs["theoretical_gain"]]
    self.panel_labels = inputxds.labels.values
def _read_xds(self, inputxds):
    """
    Read the input XDS with the reader matching its origin (reread of a panel dataset or new processing of
    a holography dataset), then load the elements common to both

    Args:
        inputxds: X array dataset
    """
    # Origin dependant reading
    reader = self._read_panel_xds if self.reread else self._read_holog_xds
    reader(inputxds)

    # Common elements
    summary = inputxds.attrs["summary"]
    self.summary = summary
    self.antenna_name = summary["general"]["antenna name"]
    self.resolution = inputxds.summary["aperture"]["resolution"]
    self.ddi = inputxds.attrs["ddi"]
    self.label = create_dataset_label(self.antenna_name, self.ddi)
    self.telescope = get_proper_telescope(
        summary["general"]["telescope name"], self.antenna_name
    )
def _define_amp_clip(self, clip_type, clip_level):
self.amplitude_noise = np.where(self.base_mask, np.nan, self.amplitude)
if clip_type is None or clip_type == "none":
clip = np.nanmin(self.amplitude)
elif clip_type == "relative":
clip = clip_level * np.nanmax(self.amplitude)
elif clip_type == "absolute":
clip = clip_level
elif clip_type == "sigma":
noise_stats = data_statistics(self.amplitude_noise)
clip = noise_stats["mean"] + clip_level * noise_stats["rms"]
elif clip_type == "noise_threshold":
clip = self._compute_noise_threshold_clip(clip_level)
else:
msg = f"Unrecognized clipping type: {clip_type}"
raise ValueError(msg)
return clip
def _compute_noise_threshold_clip(self, threshold, step_multiplier=0.95):
    """
    Lower the clip from the maximum of the noise map until the requested fraction of dish pixels is kept

    The dish is the ring between the telescope inner radial limit and half its diameter. The clip is
    multiplied by a factor that itself shrinks by step_multiplier at every iteration.

    Args:
        threshold: Fraction of dish pixels whose amplitude must exceed the clip
        step_multiplier: Factor applied to the multiplier after each iteration

    Returns:
        Last clip value tested; if the threshold cannot be reached the search stops as soon as the clip
        can no longer decrease, and a warning is logged
    """
    noise_stats = data_statistics(self.amplitude_noise)

    inside_outer = self.rad < self.telescope.diameter / 2.0
    inside_inner = self.rad < self.telescope.inner_radial_limit
    in_disk = np.where(inside_outer & ~inside_inner, 1.0, np.nan)
    n_in_disk = np.nansum(in_disk)
    in_disk_amp = in_disk * self.amplitude

    fraction_in = 0
    clip = noise_stats["max"]
    multiplier = 1.0
    first_pass = True
    while fraction_in < threshold:
        candidate = clip * multiplier
        # After the first pass the clip must strictly decrease; otherwise
        # (NaN or zero clip, step_multiplier >= 1) the loop would spin forever.
        if not first_pass and not candidate < clip:
            logger.warning(
                f"Noise threshold of {threshold} could not be reached, "
                f"stopping at clip = {clip}"
            )
            break
        clip = candidate
        data_in = np.where(in_disk_amp > clip, 1.0, 0.0)
        fraction_in = np.sum(data_in) / n_in_disk
        multiplier *= step_multiplier
        first_pass = False
    return clip
def _create_aperture_mask(self, clip_type, clip_level, use_detailed_mask):
"""
Build the base aperture mask and, for a new processing, the fitting mask
Args:
clip_type: Type of amplitude clipping, forwarded to _define_amp_clip
clip_level: Level of amplitude clipping, forwarded to _define_amp_clip
use_detailed_mask: forwarded to the telescope's create_aperture_mask
"""
# The telescope object provides the base mask and the polar meshes (stored as rad and phi)
self.base_mask, self.rad, self.phi = self.telescope.create_aperture_mask(
self.u_axis,
self.v_axis,
use_detailed_mask=use_detailed_mask,
return_polar_meshes=True,
use_outer_limit=True,
)
# On a reread the clip and mask come from the panel dataset reader, so nothing is recomputed
if self.reread:
pass
else:
self.clip = self._define_amp_clip(clip_type, clip_level)
# Pixels with amplitude below the clip are dropped from the base mask
self.mask = np.where(self.amplitude < self.clip, False, self.base_mask)
# Pixels with non-finite phase are dropped as well
# NOTE(review): indentation was lost in this copy, so it cannot be told whether this last statement
# belongs to the else branch or also runs on a reread -- confirm against upstream
self.mask = np.where(~np.isfinite(self.phase), False, self.mask)
def _nan_out_of_bounds(self):
self.phase = np.where(self.base_mask, self.phase, np.nan)
self.amplitude = np.where(self.base_mask, self.amplitude, np.nan)
self.deviation = np.where(self.base_mask, self.deviation, np.nan)
def _fetch_panel_ringed(self, ring, panel):
"""
Fetch a panel object from the panel list using its ring and panel numbers,
specific for circular antennas with panels arranged in rings
Args:
ring: Ring number
panel: Panel number
Returns:
Panel object
"""
if ring == 1:
ipanel = panel - 1
else:
ipanel = np.sum(self.telescope.n_panel_per_ring[: ring - 1]) + panel - 1
return self.panels[ipanel]
[docs]
def gains(self):
    """
    Compute the antenna gains at the observed wavelength and cache them on the instance

    Returns:
        Gains before panel fitting when the surface is not solved, otherwise the pair
        (gains before fitting, gains after fitting)
    """
    self.ingains = self.gain_at_wavelength(False, self.wavelength)
    if self.solved:
        self.ougains = self.gain_at_wavelength(True, self.wavelength)
        return self.ingains, self.ougains
    return self.ingains
[docs]
def gain_at_wavelength(self, corrected, wavelength):
    """
    Compute the gain at an arbitrary wavelength by scaling the phase map with the wavelength ratio

    Args:
        corrected: Use the phase residuals (requires fitted panels) instead of the original phase
        wavelength: Wavelength at which the gain is evaluated

    Returns:
        Tuple (real gain, theoretical gain), both converted with to_db

    Raises:
        RuntimeError: if corrected gains are requested before the panels are fitted
    """
    # This is valid for the VLA not sure if valid for anything else...
    ratio = self.wavelength / wavelength
    if not corrected:
        scaled_phase = ratio * self.phase
    elif self.fitted:
        scaled_phase = ratio * self.phase_residuals
    else:
        msg = "Cannot computed gains for corrected dish if panels are not fitted."
        logger.error(msg)
        raise RuntimeError(msg)

    in_dish = self.base_mask
    cossum = np.nansum(np.cos(scaled_phase[in_dish]))
    sinsum = np.nansum(np.sin(scaled_phase[in_dish]))
    # Modulus of the summed unit phasors, normalized by the number of dish pixels
    real_factor = np.sqrt(cossum**2 + sinsum**2) / np.sum(in_dish)
    theo_gain = fourpi * self.telescope.diameter / wavelength
    return to_db(theo_gain * real_factor), to_db(theo_gain)
[docs]
def get_rms(self, unit="mm"):
    """
    Compute the antenna surface RMS before and, when residuals exist, after panel surface fitting

    Args:
        unit: Length unit of the returned values (converted from meters)

    Returns:
        RMS before panel fitting OR tuple with RMS before and after panel fitting
    """
    scale = convert_unit("m", unit, "length")
    self.in_rms = self._compute_rms_array(self.deviation)
    if self.residuals is not None:
        self.out_rms = self._compute_rms_array(self.residuals)
        return scale * self.in_rms, scale * self.out_rms
    return scale * self.in_rms
def _compute_rms_array(self, array):
"""
Factorized the computation of the RMS of an array
Args:
array: Input data array
Returns:
RMS of the input array
"""
return np.sqrt(np.nanmean(array[self.mask] ** 2))
[docs]
def fit_surface(self):
    """
    Solve every panel in turn, mark the surface as fitted and warn about the panels whose solve()
    reported failure
    """
    failed_labels = [panel.label for panel in self.panels if not panel.solve()]
    self.fitted = True
    if failed_labels:
        logger.warning(
            f"{self.label}: Fit failed with the {self.panelmodel} model and a simple mean has been used instead "
            f"for the following panels:"
        )
        logger.warning(str(failed_labels))
[docs]
def correct_surface(self):
    """
    Apply the corrections determined by the panel fits to the antenna surface, producing the correction
    and residual maps in both deviation and phase, and mark the surface as solved

    Raises:
        RuntimeError: if the panels have not been fitted yet
    """
    if not self.fitted:
        raise RuntimeError("Panels must be fitted before atempting a correction")

    # Zero inside the fitting mask, NaN elsewhere
    self.corrections = np.where(self.mask, 0, np.nan)
    for panel in self.panels:
        for entry in panel.get_corrections():
            # First two items are the pixel indices, the last one the fitted value
            row, col = int(entry[0]), int(entry[1])
            self.corrections[row, col] = -entry[-1]

    self.residuals = self.deviation + self.corrections

    to_phase = self.telescope.deviation_to_phase
    self.phase_corrections = to_phase(
        self.u_axis, self.v_axis, self.mask, self.corrections, self.wavelength
    )
    self.phase_residuals = to_phase(
        self.u_axis, self.v_axis, self.mask, self.residuals, self.wavelength
    )
    self._build_panel_data_arrays()
    self.solved = True
[docs]
def print_misc(self):
    """
    Ask every panel of the antenna surface to print its miscellaneous information
    """
    for each_panel in self.panels:
        each_panel.print_misc()
[docs]
def plot_mask(self, basename, caller, parm_dict):
    """
    Plot the mask used in the selection of points to be fitted

    Args:
        basename: basename for the plot, the prefix '{caller}_mask' will be added to it
        caller: Which mds called this plotting function
        parm_dict: dictionary with plotting parameters; its "z_lim" and "unit" entries are overwritten
    """
    # Masked-in pixels are drawn as 1, everything else is left blank
    mask_image = np.where(self.mask, 1, np.nan)
    plotname = add_prefix(basename, f"{caller}_mask")
    parm_dict.update({"z_lim": [0, 1], "unit": " "})
    self._plot_map(plotname, mask_image, "Mask", parm_dict)
[docs]
def plot_amplitude(self, basename, caller, parm_dict):
    """
    Plot the amplitude map, with amplitude and noise statistics in the title

    Args:
        basename: basename for the plot, the prefix '{caller}_amplitude' will be added to it
        caller: Which mds called this plotting function
        parm_dict: dictionary with plotting parameters; its "z_lim" and "unit" entries are overwritten
    """
    limits = parm_dict["amplitude_limits"]
    if limits is None or limits == "None":
        # No user limits: span the full amplitude range
        parm_dict["z_lim"] = np.nanmin(self.amplitude), np.nanmax(self.amplitude)
    else:
        parm_dict["z_lim"] = limits

    amp_stats = data_statistics(np.where(self.mask, self.amplitude, np.nan))
    noise_stats = data_statistics(self.amplitude_noise)
    title = "".join(
        [
            "Amplitude, ",
            statistics_to_text(amp_stats),
            lnbr,
            "Noise, ",
            statistics_to_text(noise_stats),
        ]
    )
    plotname = add_prefix(basename, f"{caller}_amplitude")
    parm_dict["unit"] = self.amp_unit
    self._plot_map(plotname, self.amplitude, title, parm_dict)
[docs]
def plot_phase(self, basename, caller, parm_dict):
    """
    Plot phase map(s)

    Args:
        basename: basename for the plot(s); caller, a 'phase' (or 'corrected' for the image caller) prefix
            and the map label are added to it
        caller: Which mds called this plotting function
        parm_dict: dictionary with plotting parameters; its "unit" and "z_lim" entries are overwritten
    """
    requested_unit = parm_dict["phase_unit"]
    parm_dict["unit"] = "deg" if requested_unit is None else requested_unit
    parm_dict["z_lim"] = parm_dict["phase_limits"]
    fac = convert_unit("rad", parm_dict["unit"], "trigonometric")

    if caller == "image":
        prefix, maps, labels = "corrected", [self.phase], ["phase"]
    elif self.residuals is None:
        # Nothing solved yet: only the original phase is available
        prefix, maps, labels = "phase", [self.phase], ["original"]
    else:
        prefix = "phase"
        maps = [self.phase, self.phase_corrections, self.phase_residuals]
        labels = ["original", "correction", "residual"]
    self._multi_plot(maps, labels, prefix, basename, fac, parm_dict, caller)
[docs]
def plot_deviation(self, basename, caller, parm_dict):
    """
    Plot deviation map(s)

    Args:
        basename: basename for the plot(s); caller, a 'deviation' (or 'original' for the image caller) prefix
            and the first word of the map label are added to it
        caller: Which mds called this plotting function
        parm_dict: dictionary with plotting parameters; its "unit" and "z_lim" entries are overwritten
    """
    requested_unit = parm_dict["deviation_unit"]
    parm_dict["unit"] = "mm" if requested_unit is None else requested_unit
    parm_dict["z_lim"] = parm_dict["deviation_limits"]
    unit = parm_dict["unit"]
    fac = convert_unit("m", unit, "length")
    rms = self.get_rms(unit=unit)

    if caller == "image":
        prefix, maps, labels = "original", [self.deviation], ["deviation"]
    elif self.residuals is None:
        # Nothing solved yet: a single RMS value is available
        prefix = "deviation"
        maps = [self.deviation]
        labels = [f"original RMS={rms:.2f} {unit}"]
    else:
        prefix = "deviation"
        maps = [self.deviation, self.corrections, self.residuals]
        labels = [
            f"original RMS={rms[0]:.2f} {unit}",
            "correction",
            f"residual RMS={rms[1]:.2f} {unit}",
        ]
    self._multi_plot(maps, labels, prefix, basename, fac, parm_dict, caller)
def _multi_plot(self, maps, labels, prefix, basename, factor, parm_dict, caller):
if len(maps) != len(labels):
raise ValueError("Map list and label list must be of the same size")
nplots = len(maps)
if parm_dict["z_lim"] is None or parm_dict["z_lim"] == "None":
vmax = np.nanmax(
np.abs(factor * maps[0])
) # Gotten from the original map (displays the biggest variation)
parm_dict["z_lim"] = [-vmax, vmax]
for iplot in range(nplots):
title = f"{prefix.capitalize()} {labels[iplot]}"
plotname = add_prefix(basename, labels[iplot].split()[0])
plotname = add_prefix(plotname, prefix)
plotname = add_prefix(plotname, caller)
self._plot_map(plotname, factor * maps[iplot], title, parm_dict)
def _plot_map(self, filename, data, title, parm_dict, add_colorbar=True):
    """
    Plot a single map restricted to the fitting mask, with the resolution marker and the panel outlines

    Args:
        filename: Name of the output file for the plot
        data: Map to plot, on the same grid as the mask
        title: Title of the plot
        parm_dict: dictionary with plotting parameters
        add_colorbar: Passed on to the imshow helper
    """
    cmap = parm_dict["colormap"]
    fig, ax = create_figure_and_axes(parm_dict["figure_size"], [1, 1])
    masked_data = np.where(self.mask, data, np.nan)
    simple_imshow_map_plot(
        ax,
        fig,
        self.u_axis,
        self.v_axis,
        masked_data,
        title,
        cmap,
        parm_dict["z_lim"],
        z_label="Z Scale [" + parm_dict["unit"] + "]",
        add_colorbar=add_colorbar,
    )
    self._add_resolution_to_plot(ax)
    ax.set_xlabel("X axis [m]")
    ax.set_ylabel("Y axis [m]")
    for each_panel in self.panels:
        each_panel.plot(
            ax, screws=parm_dict["plot_screws"], label=parm_dict["panel_labels"]
        )
    figure_title = f"{self.label}, Pol. state: {self.pol_state}"
    close_figure(fig, figure_title, filename, parm_dict["dpi"], parm_dict["display"])
def _add_resolution_to_plot(self, ax, xpos=0.9, ypos=0.1):
    """
    Draw the aperture resolution as an ellipse with a cross-hair on a map plot

    Nothing is drawn when the resolution is None.

    Args:
        ax: Matplotlib axes to draw on
        xpos: Horizontal position of the marker, as a fraction of the u axis span
        ypos: Vertical position of the marker, as a fraction of the v axis span
    """
    if self.resolution is None:
        return
    lw = 0.5

    # Both axes are measured from their first sample
    minx = self.u_axis[0]
    miny = self.v_axis[0]
    dx = self.u_axis[-1] - minx
    dy = self.v_axis[-1] - miny
    center = (minx + xpos * dx, miny + ypos * dy)

    # Element access (not array division) so a plain list resolution also works
    beam_x = self.resolution[0]
    beam_y = self.resolution[1]
    resolution = patches.Ellipse(
        center,
        beam_x,
        beam_y,
        angle=0.0,
        linewidth=lw,
        color="black",
        zorder=2,
        fill=False,
    )
    ax.add_patch(resolution)

    # axvline/axhline take their extents as axes fractions, so each half beam
    # is normalized by the span of its own axis
    half_x = beam_x / dx / 2
    half_y = beam_y / dy / 2
    ax.axvline(
        x=center[0],
        ymin=ypos - half_y,
        ymax=ypos + half_y,
        color="black",
        lw=lw / 2,
    )
    ax.axhline(
        y=center[1],
        xmin=xpos - half_x,
        xmax=xpos + half_x,
        color="black",
        lw=lw / 2,
    )
[docs]
def plot_screw_adjustments(self, filename, parm_dict):
    """
    Plot screw adjustments as circles over a blank canvas with the panel layout

    Args:
        filename: Name of the output filename for the plot
        parm_dict: Dictionary with plotting parameters
    """
    unit = parm_dict["unit"]
    threshold = parm_dict["threshold"]
    cmap = get_proper_color_map(parm_dict["colormap"], default_cmap="RdBu_r")
    fig, ax = create_figure_and_axes(parm_dict["figure_size"], [1, 1])
    fac = convert_unit("m", unit, "length")

    # Symmetric color scale around zero
    vmax = np.nanmax(np.abs(fac * self.screw_adjustments))
    vmin = -vmax
    use_default = threshold is None or threshold == "None"
    threshold = 0.1 * vmax if use_default else np.abs(threshold)
    ax.set_title(f"\nThreshold = {threshold:.2f} {unit}", fontsize="small")

    # Blank canvas spanning the limits of the data
    blank = np.full_like(self.deviation, fill_value=np.nan)
    im = ax.imshow(
        blank,
        cmap=cmap,
        interpolation="nearest",
        extent=compute_extent(self.u_axis, self.v_axis),
        vmin=vmin,
        vmax=vmax,
    )
    self._add_resolution_to_plot(ax)
    colorbar = well_positioned_colorbar(
        ax, fig, im, "Screw adjustments [" + unit + "]"
    )

    if threshold > 0:
        # Tick the colorbar at every multiple of the threshold, on both signs
        level = threshold
        while level < vmax:
            for signed_level in (level, -level):
                colorbar.ax.axhline(
                    y=signed_level, color="black", linestyle="-", lw=0.2
                )
            level += threshold

    ax.set_xlabel("X axis [m]")
    ax.set_ylabel("Y axis [m]")
    for index, each_panel in enumerate(self.panels):
        each_panel.plot(ax, screws=False, label=parm_dict["panel_labels"])
        each_panel.plot_corrections(
            ax, cmap, fac * self.screw_adjustments[index], threshold, vmin, vmax
        )
    figure_title = f"{self.label}, Pol. state: {self.pol_state}"
    close_figure(fig, figure_title, filename, parm_dict["dpi"], parm_dict["display"])
def _build_panel_data_arrays(self):
"""
Build arrays with data from the panels so that they can be stored on the XDS
Returns:
List with panel labels, panel fitting parameters, screw_adjustments
"""
npanels = len(self.panels)
# First panel might fail hence we need to check npar for all panels
max_par = 0
for panel in self.panels:
p_npar = panel.model.npar
if p_npar > max_par:
max_par = p_npar
nscrews = self.panels[0].screws.shape[0]
self.panel_labels = np.ndarray([npanels], dtype="U22")
self.panel_model_array = np.ndarray([npanels], dtype="U22")
self.panel_pars = np.full((npanels, max_par), np.nan, dtype=float)
self.screw_adjustments = np.ndarray((npanels, nscrews), dtype=float)
self.panel_fallback = np.ndarray([npanels], dtype=bool)
for ipanel in range(npanels):
self.panel_labels[ipanel] = self.panels[ipanel].label
self.panel_pars[ipanel, :] = self.panels[ipanel].model.parameters
self.screw_adjustments[ipanel, :] = self.panels[ipanel].export_screws(
unit="m"
)
self.panel_model_array[ipanel] = self.panels[ipanel].model_name
self.panel_fallback[ipanel] = self.panels[ipanel].fall_back_fit
[docs]
def export_screws(self, filename, unit="mm", comment_char="#"):
    """
    Export screw adjustments for all panels onto an ASCII file

    The file starts with header lines (all prefixed with comment_char) giving the frequency, the surface
    RMS before and after adjustment and the sign convention, followed by one row per panel.

    Args:
        filename: ASCII file name/path
        unit: unit for panel screw adjustments; when it is not "mm" the RMS values are also given in mm
        comment_char: Character used to prefix every header line
    """
    cc = comment_char
    spc = " "
    pieces = [
        f"{cc} Screw adjustments for {self.telescope.name}'s {self.label}, pol. state {self.pol_state}\n"
    ]
    freq = clight / self.wavelength
    rmses = self.get_rms(unit)
    pieces.append(f"{cc} Frequency = {format_frequency(freq)}{lnbr}")

    # RMS values are repeated in mm when another unit was requested
    mmrms = None if unit == "mm" else self.get_rms("mm")
    for index, stage in enumerate(("before", "after")):
        text = (
            f"{cc} Antenna surface RMS {stage} adjustment: "
            f"{format_value_unit(rmses[index], unit)}"
        )
        if mmrms is not None:
            text += f' or {format_value_unit(mmrms[index], "mm")}'
        pieces.append(text + "\n")

    pieces.extend(
        [
            f"{cc} Lower means away from subreflector" + lnbr,
            f"{cc} Raise means toward the subreflector" + lnbr,
            f"{cc} LOWER the panel if the number is POSITIVE" + lnbr,
            f"{cc} RAISE the panel if the number is NEGATIVE" + lnbr,
            f"{cc} Adjustments are in " + unit + lnbr,
            lnbr,
        ]
    )

    # Column header
    pieces.append(f"{cc} Panel{2*spc}")
    screw_names = self.telescope.screw_description
    nscrews = len(screw_names)
    pieces.extend(f"{4*spc}{screw:2s}{4*spc}" for screw in screw_names)
    pieces.append(f"Fallback{4*spc}Model{lnbr}")

    # One row per panel: label, screw adjustments, fallback flag and model
    fac = convert_unit("m", unit, "length")
    for ipanel, panel_label in enumerate(self.panel_labels):
        pieces.append("{0:>5s}".format(panel_label))
        for iscrew in range(nscrews):
            pieces.append(
                " {0:>9.2f}".format(fac * self.screw_adjustments[ipanel, iscrew])
            )
        pieces.append(
            f"{5*spc}{bool_to_str(self.panel_fallback[ipanel]):>3s}{7*spc}{self.panel_model_array[ipanel]}"
            + lnbr
        )
    string_to_ascii_file("".join(pieces), filename)
[docs]
def export_xds(self):
    """
    Export all the data to an Xarray dataset

    Gains and RMS are recomputed before export. Residual/correction maps, panel tables and output
    gain/RMS are only included when the surface is solved.

    Returns:
        XarrayDataSet containing all the relevant information
    """
    xds = xr.Dataset()
    gains = self.gains()
    rms = self.get_rms(unit="m")

    xds.attrs.update(
        {
            "ddi": self.ddi,
            "wavelength": self.wavelength,
            "amp_unit": self.amp_unit,
            "panel_model": self.panelmodel,
            "panel_margin": self.panel_margins,
            "clip": self.clip,
            "solved": self.solved,
            "fitted": self.fitted,
            "aperture_resolution": self.resolution,
            "pol_state": self.pol_state,
            "summary": self.summary,
        }
    )

    # Maps that are always available, all on the (u, v) grid
    for name, data in (
        ("AMPLITUDE", self.amplitude),
        ("PHASE", self.phase),
        ("DEVIATION", self.deviation),
        ("MASK", self.mask),
        ("PANEL_DISTRIBUTION", self.panel_distribution),
        ("AMP_NOISE", self.amplitude_noise),
        ("RADIUS", self.rad),
    ):
        xds[name] = xr.DataArray(data, dims=["u", "v"])

    coords = {"u": self.u_axis, "v": self.v_axis}
    if self.solved:
        for name, data in (
            ("PHASE_RESIDUALS", self.phase_residuals),
            ("RESIDUALS", self.residuals),
            ("PHASE_CORRECTIONS", self.phase_corrections),
            ("CORRECTIONS", self.corrections),
        ):
            xds[name] = xr.DataArray(data, dims=["u", "v"])

        # rms and gains are (before, after) pairs once the surface is solved
        xds.attrs.update(
            {
                "input_rms": rms[0],
                "output_rms": rms[1],
                "input_gain": gains[0][0],
                "output_gain": gains[1][0],
                "theoretical_gain": gains[0][1],
            }
        )

        for name, data, dims in (
            ("PANEL_PARAMETERS", self.panel_pars, ["labels", "pars"]),
            ("PANEL_SCREWS", self.screw_adjustments, ["labels", "screws"]),
            ("PANEL_MODEL", self.panel_model_array, ["labels"]),
            ("PANEL_FALLBACK", self.panel_fallback, ["labels"]),
        ):
            xds[name] = xr.DataArray(data, dims=dims)

        coords["labels"] = self.panel_labels
        coords["screws"] = self.telescope.screw_description
        coords["pars"] = np.arange(self.panel_pars.shape[1])
    else:
        xds.attrs["input_rms"] = rms
        xds.attrs["input_gain"] = gains[0]
        xds.attrs["theoretical_gain"] = gains[1]

    return xds.assign_coords(coords)
[docs]
def export_to_fits(self, basename):
    """
    Write the surface maps to FITS files, one file per map, all sharing the same header

    Data exported: Amplitude, mask, phase, phase_corrections, phase_residuals, deviations,
    deviation_corrections, deviation_residuals; conveniently all data are on the same grid!

    Args:
        basename: basename for the files; a per-map prefix and the '.fits' extension are added to it
    """
    head = {
        "PMODEL": self.panelmodel,
        "PMARGIN": self.panel_margins,
        "CLIP": self.clip,
        "TELESCOP": self.antenna_name,
        "INSTRUME": self.telescope.name,
        "WAVELENG": self.wavelength,
        "FREQUENC": clight / self.wavelength,
    }
    head = put_axis_in_fits_header(head, self.u_axis, 1, "X----LIN", "m")
    head = put_axis_in_fits_header(head, self.v_axis, 2, "Y----LIN", "m")
    head = put_resolution_in_fits_header(head, self.resolution)

    # (image name, data, file prefix, unit) for every exported map, in write order
    products = (
        ("Amplitude", self.amplitude, "amplitude", self.amp_unit),
        ("Mask", np.where(self.mask, 1.0, np.nan), "mask", ""),
        ("Original Phase", self.phase, "phase_original", "rad"),
        ("Phase Corrections", self.phase_corrections, "phase_correction", "rad"),
        ("Phase residuals", self.phase_residuals, "phase_residual", "rad"),
        ("Original Deviation", self.deviation, "deviation_original", "m"),
        ("Deviation Corrections", self.corrections, "deviation_correction", "m"),
        ("Deviation residuals", self.residuals, "deviation_residual", "m"),
    )
    for image_name, data, file_prefix, unit in products:
        write_fits(
            head,
            image_name,
            data,
            add_prefix(basename, file_prefix) + ".fits",
            unit,
            "panel",
        )