# Source code for astrohack.antenna.antenna_surface

import xarray as xr

from matplotlib import patches

import toolviper.utils.logger as logger

from astrohack.antenna.telescope import get_proper_telescope
from astrohack.utils.algorithms import (
    data_statistics,
    phase_wrapping,
)
from astrohack.utils.constants import *
from astrohack.utils.conversion import to_db
from astrohack.utils.conversion import convert_unit
from astrohack.utils.text import (
    add_prefix,
    bool_to_str,
    format_frequency,
    format_value_unit,
    string_to_ascii_file,
    create_dataset_label,
    statistics_to_text,
    lnbr,
)
from astrohack.visualization.plot_tools import (
    create_figure_and_axes,
    close_figure,
    simple_imshow_map_plot,
    get_proper_color_map,
    well_positioned_colorbar,
    compute_extent,
)

from astrohack.utils.fits import (
    write_fits,
    put_resolution_in_fits_header,
    put_axis_in_fits_header,
)

# Polarization states listed as supported: Stokes I plus the circular (RR, LL)
# and linear (XX, YY) parallel-hand products.
SUPPORTED_POL_STATES = [
    "I",
    "RR",
    "LL",
    "XX",
    "YY",
]
class AntennaSurface:
    """
    Antenna surface built from a holography image dataset or from a previously processed panel dataset.

    Computes surface RMS and gains, fits panel surface models to derive screw adjustments and exports the
    results as plots, ASCII screw tables, xarray datasets and FITS images.
    """

    def __init__(
        self,
        inputxds,
        clip_type="sigma",
        clip_level=3,
        pmodel="rigid",
        nan_out_of_bounds=True,
        panel_margins=0.05,
        reread=False,
        pol_state="I",
        patch_phase=False,
        use_detailed_mask=True,
    ):
        """
        Antenna Surface description capable of computing RMS, Gains, and fitting the surface to obtain screw
        adjustments

        Args:
            inputxds: Input xarray dataset
            clip_type: Type of clipping to be applied to amplitude: None/'none', 'relative', 'absolute', 'sigma'
                       or 'noise_threshold'
            clip_level: Level of clipping, its meaning depends on clip_type
            pmodel: model of panel surface fitting, forwarded to the telescope's build_panel_list
            nan_out_of_bounds: Should the region outside the dish be replaced with NaNs? (only applied when not
                               rereading)
            panel_margins: Margin to be ignored at edges of panels when fitting, forwarded to the telescope's
                           build_panel_list
            reread: Read a previously processed holography (a panel dataset such as the one written by export_xds)
            pol_state: Polarization state to select; when rereading, the state stored in the dataset is used
            patch_phase: Phase data from inputxds needs to be wrapped to -pi to pi interval
            use_detailed_mask: use detailed mask (arms shadows, ngvla outer panels)
        """
        self.reread = reread
        self.phase = None
        self.deviation = None
        self.residuals = None
        self.corrections = None
        self.phase_corrections = None
        self.phase_residuals = None
        self.solved = False
        self.ingains = np.nan
        self.ougains = np.nan
        self.in_rms = np.nan
        self.out_rms = np.nan
        self.fitted = False
        self.pol_state = pol_state

        # Read the data in an XDS
        self._read_xds(inputxds)
        if patch_phase:
            self.phase = phase_wrapping(self.phase)

        self._create_aperture_mask(clip_type, clip_level, use_detailed_mask)
        # Deviation is always recomputed from the phase, also when rereading
        self.deviation = self.telescope.phase_to_deviation(
            self.u_axis, self.v_axis, self.mask, self.phase, self.wavelength
        )

        # NOTE(review): on reread the panel list is built from the pmodel/panel_margins arguments, not from the
        # values stored in the dataset (self.panelmodel/self.panel_margins); confirm this is intended.
        self.panels = self.telescope.build_panel_list(pmodel, panel_margins)

        if not self.reread:
            self.panelmodel = pmodel
            self.panel_margins = panel_margins
            self.panel_distribution = self.telescope.attribute_pixels_to_panels(
                self.panels,
                self.u_axis,
                self.v_axis,
                self.rad,
                self.phi,
                self.deviation,
                self.mask,
            )
            # Only freshly processed data is blanked; reread data is used as stored
            if nan_out_of_bounds:
                self._nan_out_of_bounds()

    def _read_holog_xds(self, inputxds):
        """
        Read a holography image dataset: wavelength, amplitude and phase of the selected polarization state, and
        the aperture axes

        Args:
            inputxds: Holography image xarray dataset

        Raises:
            RuntimeError: if the dataset has more than one channel
            ValueError: if self.pol_state is not among the dataset's polarization states
        """
        if "chan" in inputxds.dims:
            if inputxds.sizes["chan"] != 1:
                msg = "Only single channel holographies supported"
                logger.error(msg)
                raise RuntimeError(msg)
            self.wavelength = clight / inputxds.chan.values[0]
            selection = {"time": 0, "chan": 0}
        else:
            self.wavelength = inputxds.attrs["wavelength"]
            # Without a channel dimension there is nothing to select along "chan"
            selection = {"time": 0}

        available_pols = inputxds.coords["pol"].values
        if self.pol_state not in available_pols:
            msg = (
                f"Polarization state {self.pol_state} is not present in the data (available states: "
                f"{available_pols.tolist()})"
            )
            logger.error(msg)
            raise ValueError(msg)

        self.amplitude = (
            inputxds["AMPLITUDE"].sel(pol=self.pol_state).isel(**selection).values
        )
        self.phase = (
            inputxds["CORRECTED_PHASE"].sel(pol=self.pol_state).isel(**selection).values
        )

        self.npoint = np.sqrt(inputxds.sizes["l"] ** 2 + inputxds.sizes["m"] ** 2)
        self.amp_unit = "V"
        self.u_axis = inputxds.u_prime.values
        self.v_axis = inputxds.v_prime.values
        self.computephase = False

    def _read_panel_xds(self, inputxds):
        """
        Read a previously processed panel dataset (same variables and attributes as written by export_xds)

        Args:
            inputxds: Panel xarray dataset
        """
        attrs = inputxds.attrs
        self.wavelength = attrs["wavelength"]
        self.amp_unit = attrs["amp_unit"]
        self.panelmodel = attrs["panel_model"]
        self.panel_margins = attrs["panel_margin"]
        self.clip = attrs["clip"]
        self.solved = attrs["solved"]
        self.fitted = attrs["fitted"]
        # Datasets without pol_state are assumed to be Stokes I, as no one was expected to do panel fitting on
        # anything else
        self.pol_state = attrs.get("pol_state", "I")

        # Arrays
        self.amplitude = inputxds["AMPLITUDE"].values
        self.phase = inputxds["PHASE"].values
        self.deviation = inputxds["DEVIATION"].values
        self.mask = inputxds["MASK"].values
        self.u_axis = inputxds.u.values
        self.v_axis = inputxds.v.values
        self.panel_distribution = inputxds["PANEL_DISTRIBUTION"].values
        self.amplitude_noise = inputxds["AMP_NOISE"].values

        # Fitting products only exist for solved surfaces
        if self.solved:
            self.panel_fallback = inputxds["PANEL_FALLBACK"].values
            self.panel_model_array = inputxds["PANEL_MODEL"].values
            self.phase_residuals = inputxds["PHASE_RESIDUALS"].values
            self.residuals = inputxds["RESIDUALS"].values
            self.phase_corrections = inputxds["PHASE_CORRECTIONS"].values
            self.corrections = inputxds["CORRECTIONS"].values
            self.panel_pars = inputxds["PANEL_PARAMETERS"].values
            self.screw_adjustments = inputxds["PANEL_SCREWS"].values
            self.ingains = [attrs["input_gain"], attrs["theoretical_gain"]]
            self.ougains = [attrs["output_gain"], attrs["theoretical_gain"]]
            self.panel_labels = inputxds.labels.values

    def _read_xds(self, inputxds):
        """
        Read input XDS, the reading function depending on if it is a reread or a new processing

        Args:
            inputxds: X array dataset
        """
        # Origin dependant reading
        if self.reread:
            self._read_panel_xds(inputxds)
        else:
            self._read_holog_xds(inputxds)

        # Common elements
        self.summary = inputxds.attrs["summary"]
        self.antenna_name = self.summary["general"]["antenna name"]
        self.resolution = self.summary["aperture"]["resolution"]
        self.ddi = inputxds.attrs["ddi"]
        self.label = create_dataset_label(self.antenna_name, self.ddi)
        self.telescope = get_proper_telescope(
            self.summary["general"]["telescope name"], self.antenna_name
        )

    def _define_amp_clip(self, clip_type, clip_level):
        """
        Compute the amplitude clip level and store the out-of-aperture amplitude as noise sample

        Args:
            clip_type: None/'none', 'relative', 'absolute', 'sigma' or 'noise_threshold'
            clip_level: Clipping level, interpreted according to clip_type

        Returns:
            Amplitude clip level

        Raises:
            ValueError: for an unrecognized clip_type
        """
        # Amplitude outside the aperture mask is used as the noise sample
        self.amplitude_noise = np.where(self.base_mask, np.nan, self.amplitude)

        if clip_type is None or clip_type == "none":
            clip = np.nanmin(self.amplitude)
        elif clip_type == "relative":
            clip = clip_level * np.nanmax(self.amplitude)
        elif clip_type == "absolute":
            clip = clip_level
        elif clip_type == "sigma":
            noise_stats = data_statistics(self.amplitude_noise)
            clip = noise_stats["mean"] + clip_level * noise_stats["rms"]
        elif clip_type == "noise_threshold":
            clip = self._compute_noise_threshold_clip(clip_level)
        else:
            msg = f"Unrecognized clipping type: {clip_type}"
            logger.error(msg)
            raise ValueError(msg)
        return clip

    def _compute_noise_threshold_clip(
        self, threshold, step_multiplier=0.95, max_iterations=1000
    ):
        """
        Find an amplitude clip such that a fraction of at least ``threshold`` of the in-dish pixels (radius between
        the telescope's inner radial limit and half its diameter) has amplitude above it.

        The search starts at the maximum noise amplitude and lowers the clip by the constant factor
        ``step_multiplier`` at every step.

        Args:
            threshold: Required fraction of in-dish pixels above the clip
            step_multiplier: Factor applied to the clip at each step
            max_iterations: Maximum number of steps, guards against thresholds that can never be reached

        Returns:
            Amplitude clip level
        """
        noise_stats = data_statistics(self.amplitude_noise)
        in_disk = np.where(self.rad < self.telescope.diameter / 2.0, 1.0, np.nan)
        in_disk = np.where(
            self.rad < self.telescope.inner_radial_limit, np.nan, in_disk
        )
        n_in_disk = np.nansum(in_disk)
        in_disk_amp = in_disk * self.amplitude

        clip = noise_stats["max"]
        if n_in_disk == 0:
            logger.warning(
                f"{self.label}: No pixels inside the dish, using the maximum noise amplitude as clip"
            )
            return clip

        for _ in range(max_iterations):
            # NaN (out of disk) pixels compare False and are never counted
            fraction_in = np.count_nonzero(in_disk_amp > clip) / n_in_disk
            if fraction_in >= threshold:
                return clip
            # Constant geometric step, so the clip resolution does not degrade as the search proceeds
            clip *= step_multiplier

        logger.warning(
            f"{self.label}: Could not get a fraction of {threshold} of the dish above the clip level in "
            f"{max_iterations} iterations, using clip = {clip}"
        )
        return clip

    def _create_aperture_mask(self, clip_type, clip_level, use_detailed_mask):
        """
        Build the aperture (base) mask and polar meshes and, for new processings, the fitting mask: aperture
        pixels whose amplitude is not below the clip and whose phase is finite

        Args:
            clip_type: Amplitude clipping type, see _define_amp_clip
            clip_level: Amplitude clipping level, see _define_amp_clip
            use_detailed_mask: use detailed mask (arms shadows, ngvla outer panels)
        """
        self.base_mask, self.rad, self.phi = self.telescope.create_aperture_mask(
            self.u_axis,
            self.v_axis,
            use_detailed_mask=use_detailed_mask,
            return_polar_meshes=True,
            use_outer_limit=True,
        )
        # On reread, clip and mask come from the stored dataset
        if not self.reread:
            self.clip = self._define_amp_clip(clip_type, clip_level)
            self.mask = np.where(self.amplitude < self.clip, False, self.base_mask)
            self.mask = np.where(~np.isfinite(self.phase), False, self.mask)

    def _nan_out_of_bounds(self):
        """
        Replace phase, amplitude and deviation outside the aperture (base) mask with NaNs
        """
        self.phase = np.where(self.base_mask, self.phase, np.nan)
        self.amplitude = np.where(self.base_mask, self.amplitude, np.nan)
        self.deviation = np.where(self.base_mask, self.deviation, np.nan)

    def _fetch_panel_ringed(self, ring, panel):
        """
        Fetch a panel object from the panel list using its ring and panel numbers, specific for circular antennas
        with panels arranged in rings

        Args:
            ring: Ring number (1-based)
            panel: Panel number within the ring (1-based)

        Returns:
            Panel object
        """
        # Panels of all inner rings come first in the list; the sum of an empty slice (ring 1) is 0
        ipanel = int(sum(self.telescope.n_panel_per_ring[: ring - 1])) + panel - 1
        return self.panels[ipanel]

    def gains(self):
        """
        Computes antenna gains in decibels before and after panel surface fitting

        Returns:
            (gain, theoretical gain) before panel fitting if the surface is not solved, otherwise a pair of such
            tuples: before and after the panel corrections
        """
        self.ingains = self.gain_at_wavelength(False, self.wavelength)
        if not self.solved:
            return self.ingains
        self.ougains = self.gain_at_wavelength(True, self.wavelength)
        return self.ingains, self.ougains

    def gain_at_wavelength(self, corrected, wavelength):
        """
        Compute the antenna gain at a given wavelength by rescaling the phase measured at self.wavelength

        Args:
            corrected: If True use the phase residuals after panel corrections, otherwise the original phase
            wavelength: Wavelength at which to compute the gain, same unit as self.wavelength

        Returns:
            Tuple (gain, theoretical gain), both in dB

        Raises:
            RuntimeError: if corrected is True and panel corrections are not available
        """
        # This is valid for the VLA not sure if valid for anything else...
        wavelength_scaling = self.wavelength / wavelength
        dish_mask = self.base_mask

        if corrected:
            # Phase residuals only exist once corrections were applied (correct_surface or reread of a solved
            # dataset); fitting alone is not enough
            if not self.fitted or self.phase_residuals is None:
                msg = "Cannot compute gains for corrected dish if panels are not fitted and corrected."
                logger.error(msg)
                raise RuntimeError(msg)
            scaled_phase = wavelength_scaling * self.phase_residuals
        else:
            scaled_phase = wavelength_scaling * self.phase

        dish_phase = scaled_phase[dish_mask]
        cossum = np.nansum(np.cos(dish_phase))
        sinsum = np.nansum(np.sin(dish_phase))
        # Modulus of the mean phasor over the dish; NaN pixels add nothing but still count in the denominator
        real_factor = np.sqrt(cossum**2 + sinsum**2) / np.sum(dish_mask)
        # NOTE(review): theoretical gain is taken as fourpi * D / lambda and real_factor is not squared; confirm
        # against the intended gain definition
        theo_gain = fourpi * self.telescope.diameter / wavelength
        real_gain = theo_gain * real_factor
        return to_db(real_gain), to_db(theo_gain)

    def get_rms(self, unit="mm"):
        """
        Computes antenna surface RMS before and after panel surface fitting

        Args:
            unit: Length unit of the returned RMS values

        Returns:
            RMS before panel fitting OR RMS before and after panel fitting
        """
        fac = convert_unit("m", unit, "length")
        self.in_rms = self._compute_rms_array(self.deviation)
        if self.residuals is None:
            return fac * self.in_rms
        self.out_rms = self._compute_rms_array(self.residuals)
        return fac * self.in_rms, fac * self.out_rms

    def _compute_rms_array(self, array):
        """
        Factorized the computation of the RMS of an array over the fitting mask, ignoring NaNs

        Args:
            array: Input data array

        Returns:
            RMS of the input array
        """
        return np.sqrt(np.nanmean(array[self.mask] ** 2))

    def fit_surface(self):
        """
        Loops over the panels to fit the panel surfaces, warning about the panels whose fit failed
        """
        failed_panels = [panel.label for panel in self.panels if not panel.solve()]
        self.fitted = True
        if failed_panels:
            msg = (
                f"{self.label}: Fit failed with the {self.panelmodel} model and a simple mean has been used instead "
                f"for the following panels:"
            )
            logger.warning(msg)
            logger.warning(str(failed_panels))

    def correct_surface(self):
        """
        Apply corrections determined by the panel surface fitting methods to the antenna surface

        Raises:
            RuntimeError: if the panels have not been fitted
        """
        if not self.fitted:
            raise RuntimeError("Panels must be fitted before atempting a correction")
        self.corrections = np.where(self.mask, 0, np.nan)
        for panel in self.panels:
            # Each correction row holds the pixel indices first and the correction value last
            for corr in panel.get_corrections():
                ix, iy = int(corr[0]), int(corr[1])
                self.corrections[ix, iy] = -corr[-1]
        self.residuals = self.deviation + self.corrections
        self.phase_corrections = self.telescope.deviation_to_phase(
            self.u_axis, self.v_axis, self.mask, self.corrections, self.wavelength
        )
        self.phase_residuals = self.telescope.deviation_to_phase(
            self.u_axis, self.v_axis, self.mask, self.residuals, self.wavelength
        )
        self._build_panel_data_arrays()
        self.solved = True

    def print_misc(self):
        """
        Print miscelaneous information on the panels in the antenna surface
        """
        for panel in self.panels:
            panel.print_misc()

    def plot_mask(self, basename, caller, parm_dict):
        """
        Plot mask used in the selection of points to be fitted

        Args:
            basename: basename for the plot, the prefix '{caller}_mask' will be added to it
            caller: Which mds called this plotting function
            parm_dict: dictionary with plotting parameters (z_lim and unit are overwritten)
        """
        plotmask = np.where(self.mask, 1, np.nan)
        plotname = add_prefix(basename, f"{caller}_mask")
        parm_dict["z_lim"] = [0, 1]
        parm_dict["unit"] = " "
        self._plot_map(plotname, plotmask, "Mask", parm_dict)

    def plot_amplitude(self, basename, caller, parm_dict):
        """
        Plot Amplitude map

        Args:
            basename: basename for the plot, the prefix '{caller}_amplitude' will be added to it
            caller: Which mds called this plotting function
            parm_dict: dictionary with plotting parameters (z_lim and unit are overwritten)
        """
        limits = parm_dict["amplitude_limits"]
        if limits is None or limits == "None":
            parm_dict["z_lim"] = np.nanmin(self.amplitude), np.nanmax(self.amplitude)
        else:
            parm_dict["z_lim"] = limits

        amp_stats = data_statistics(np.where(self.mask, self.amplitude, np.nan))
        noise_stats = data_statistics(self.amplitude_noise)
        title = (
            "Amplitude, "
            + statistics_to_text(amp_stats)
            + lnbr
            + "Noise, "
            + statistics_to_text(noise_stats)
        )
        plotname = add_prefix(basename, f"{caller}_amplitude")
        parm_dict["unit"] = self.amp_unit
        self._plot_map(plotname, self.amplitude, title, parm_dict)

    def plot_phase(self, basename, caller, parm_dict):
        """
        Plot phase map(s)

        Args:
            basename: basename for the plot(s), the prefix 'phase_{original|corrections|residuals}' will be added
                      to it/them
            caller: Which mds called this plotting function
            parm_dict: dictionary with plotting parameters (z_lim and unit are overwritten)
        """
        if parm_dict["phase_unit"] is None:
            parm_dict["unit"] = "deg"
        else:
            parm_dict["unit"] = parm_dict["phase_unit"]
        parm_dict["z_lim"] = parm_dict["phase_limits"]
        fac = convert_unit("rad", parm_dict["unit"], "trigonometric")
        prefix = "phase"
        if caller == "image":
            prefix = "corrected"
            maps = [self.phase]
            labels = ["phase"]
        elif self.residuals is None:
            maps = [self.phase]
            labels = ["original"]
        else:
            maps = [self.phase, self.phase_corrections, self.phase_residuals]
            labels = ["original", "correction", "residual"]
        self._multi_plot(maps, labels, prefix, basename, fac, parm_dict, caller)

    def plot_deviation(self, basename, caller, parm_dict):
        """
        Plot deviation map(s)

        Args:
            basename: basename for the plot(s), the prefix 'deviation_{original|corrections|residuals}' will be
                      added to it/them
            caller: Which mds called this plotting function
            parm_dict: dictionary with plotting parameters (z_lim and unit are overwritten)
        """
        if parm_dict["deviation_unit"] is None:
            parm_dict["unit"] = "mm"
        else:
            parm_dict["unit"] = parm_dict["deviation_unit"]
        parm_dict["z_lim"] = parm_dict["deviation_limits"]
        unit = parm_dict["unit"]
        fac = convert_unit("m", unit, "length")
        prefix = "deviation"
        rms = self.get_rms(unit=unit)
        if caller == "image":
            prefix = "original"
            maps = [self.deviation]
            labels = ["deviation"]
        elif self.residuals is None:
            maps = [self.deviation]
            labels = [f"original RMS={rms:.2f} {unit}"]
        else:
            maps = [self.deviation, self.corrections, self.residuals]
            labels = [
                f"original RMS={rms[0]:.2f} {unit}",
                "correction",
                f"residual RMS={rms[1]:.2f} {unit}",
            ]
        self._multi_plot(maps, labels, prefix, basename, fac, parm_dict, caller)

    def _multi_plot(self, maps, labels, prefix, basename, factor, parm_dict, caller):
        """
        Plot a series of maps with matching labels, sharing the colour scale

        Args:
            maps: List of maps to plot
            labels: List of labels, the first word of each is used in the plot file name
            prefix: Prefix for titles and file names
            basename: Base name for the plot files
            factor: Unit conversion factor applied to every map
            parm_dict: dictionary with plotting parameters
            caller: Which mds called this plotting function

        Raises:
            ValueError: if maps and labels differ in length
        """
        if len(maps) != len(labels):
            raise ValueError("Map list and label list must be of the same size")
        if parm_dict["z_lim"] is None or parm_dict["z_lim"] == "None":
            # Gotten from the original map (displays the biggest variation)
            vmax = np.nanmax(np.abs(factor * maps[0]))
            parm_dict["z_lim"] = [-vmax, vmax]
        for gmap, label in zip(maps, labels):
            title = f"{prefix.capitalize()} {label}"
            plotname = add_prefix(basename, label.split()[0])
            plotname = add_prefix(plotname, prefix)
            plotname = add_prefix(plotname, caller)
            self._plot_map(plotname, factor * gmap, title, parm_dict)

    def _plot_map(self, filename, data, title, parm_dict, add_colorbar=True):
        """
        Plot a single map (masked by the fitting mask) with the panel layout and beam overlaid

        Args:
            filename: Output plot file name
            data: Map to plot
            title: Plot title
            parm_dict: dictionary with plotting parameters
            add_colorbar: Add a colorbar to the plot
        """
        cmap = parm_dict["colormap"]
        fig, ax = create_figure_and_axes(parm_dict["figure_size"], [1, 1])
        simple_imshow_map_plot(
            ax,
            fig,
            self.u_axis,
            self.v_axis,
            np.where(self.mask, data, np.nan),
            title,
            cmap,
            parm_dict["z_lim"],
            z_label="Z Scale [" + parm_dict["unit"] + "]",
            add_colorbar=add_colorbar,
        )
        self._add_resolution_to_plot(ax)
        ax.set_xlabel("X axis [m]")
        ax.set_ylabel("Y axis [m]")
        for panel in self.panels:
            panel.plot(
                ax, screws=parm_dict["plot_screws"], label=parm_dict["panel_labels"]
            )
        suptitle = f"{self.label}, Pol. state: {self.pol_state}"
        close_figure(fig, suptitle, filename, parm_dict["dpi"], parm_dict["display"])

    def _add_resolution_to_plot(self, ax, xpos=0.9, ypos=0.1):
        """
        Draw the aperture resolution ellipse with cross-hair lines at a position given in axes fractions

        Args:
            ax: matplotlib axes
            xpos: Horizontal position as a fraction of the aperture axis span
            ypos: Vertical position as a fraction of the aperture axis span
        """
        if self.resolution is None:
            return
        lw = 0.5
        # Stored resolution may be a plain list, arithmetic below needs an array
        resolution = np.asarray(self.resolution, dtype=float)
        minx = self.u_axis[0]
        miny = self.v_axis[0]
        dx = self.u_axis[-1] - minx
        dy = self.v_axis[-1] - miny
        center = (minx + xpos * dx, miny + ypos * dy)
        beam = patches.Ellipse(
            center,
            resolution[0],
            resolution[1],
            angle=0.0,
            linewidth=lw,
            color="black",
            zorder=2,
            fill=False,
        )
        ax.add_patch(beam)
        # axvline/axhline extents are axes fractions, so each half beam is scaled by its own axis span
        half_x = resolution[0] / dx / 2
        half_y = resolution[1] / dy / 2
        ax.axvline(
            x=center[0],
            ymin=ypos - half_y,
            ymax=ypos + half_y,
            color="black",
            lw=lw / 2,
        )
        ax.axhline(
            y=center[1],
            xmin=xpos - half_x,
            xmax=xpos + half_x,
            color="black",
            lw=lw / 2,
        )

    def plot_screw_adjustments(self, filename, parm_dict):
        """
        Plot screw adjustments as circles over a blank canvas with the panel layout

        Args:
            filename: Name of the output filename for the plot
            parm_dict: Dictionary with plotting parameters
        """
        unit = parm_dict["unit"]
        threshold = parm_dict["threshold"]
        cmap = get_proper_color_map(parm_dict["colormap"], default_cmap="RdBu_r")
        fig, ax = create_figure_and_axes(parm_dict["figure_size"], [1, 1])

        fac = convert_unit("m", unit, "length")
        vmax = np.nanmax(np.abs(fac * self.screw_adjustments))
        vmin = -vmax
        if threshold is None or threshold == "None":
            threshold = 0.1 * vmax
        else:
            threshold = np.abs(threshold)

        ax.set_title(f"\nThreshold = {threshold:.2f} {unit}", fontsize="small")
        # set the limits of the plot to the limits of the data
        extent = compute_extent(self.u_axis, self.v_axis)
        im = ax.imshow(
            np.full_like(self.deviation, fill_value=np.nan),
            cmap=cmap,
            interpolation="nearest",
            extent=extent,
            vmin=vmin,
            vmax=vmax,
        )
        self._add_resolution_to_plot(ax)
        colorbar = well_positioned_colorbar(
            ax, fig, im, "Screw adjustments [" + unit + "]"
        )
        # Mark every multiple of the threshold on the colorbar
        if threshold > 0:
            line = threshold
            while line < vmax:
                colorbar.ax.axhline(y=line, color="black", linestyle="-", lw=0.2)
                colorbar.ax.axhline(y=-line, color="black", linestyle="-", lw=0.2)
                line += threshold

        ax.set_xlabel("X axis [m]")
        ax.set_ylabel("Y axis [m]")
        for ipanel, panel in enumerate(self.panels):
            panel.plot(ax, screws=False, label=parm_dict["panel_labels"])
            panel.plot_corrections(
                ax, cmap, fac * self.screw_adjustments[ipanel], threshold, vmin, vmax
            )

        suptitle = f"{self.label}, Pol. state: {self.pol_state}"
        close_figure(fig, suptitle, filename, parm_dict["dpi"], parm_dict["display"])

    def _build_panel_data_arrays(self):
        """
        Build arrays with data from the panels so that they can be stored on the XDS: panel labels, model names,
        fitting parameters (NaN padded up to the largest number of parameters), screw adjustments in meters and
        fallback flags
        """
        npanels = len(self.panels)
        # First panel might fail hence we need to check npar for all panels
        max_par = max((panel.model.npar for panel in self.panels), default=0)
        nscrews = self.panels[0].screws.shape[0]

        self.panel_labels = np.empty([npanels], dtype="U22")
        self.panel_model_array = np.empty([npanels], dtype="U22")
        self.panel_pars = np.full((npanels, max_par), np.nan, dtype=float)
        self.screw_adjustments = np.empty((npanels, nscrews), dtype=float)
        self.panel_fallback = np.empty([npanels], dtype=bool)

        for ipanel, panel in enumerate(self.panels):
            self.panel_labels[ipanel] = panel.label
            # Panels with fewer parameters keep NaN padding instead of having their values broadcast
            pars = np.ravel(np.asarray(panel.model.parameters, dtype=float))
            self.panel_pars[ipanel, : pars.size] = pars
            self.screw_adjustments[ipanel, :] = panel.export_screws(unit="m")
            self.panel_model_array[ipanel] = panel.model_name
            self.panel_fallback[ipanel] = panel.fall_back_fit

    def export_screws(self, filename, unit="mm", comment_char="#"):
        """
        Export screw adjustments for all panels onto an ASCII file

        Args:
            filename: ASCII file name/path
            unit: unit for panel screw adjustments ['mm','miliinches']
            comment_char: Character used for every comment line of the header

        Raises:
            RuntimeError: if the surface has not been solved (no screw adjustments available)
        """
        if not self.solved:
            msg = f"{self.label}: Screw adjustments are only available after panel fitting and correction"
            logger.error(msg)
            raise RuntimeError(msg)

        cc = comment_char
        outfile = f"{cc} Screw adjustments for {self.telescope.name}'s {self.label}, pol. state {self.pol_state}\n"
        freq = clight / self.wavelength
        rmses = self.get_rms(unit)
        outfile += f"{cc} Frequency = {format_frequency(freq)}{lnbr}"
        if unit == "mm":
            outfile += f"{cc} Antenna surface RMS before adjustment: {format_value_unit(rmses[0], unit)}\n"
            outfile += f"{cc} Antenna surface RMS after adjustment: {format_value_unit(rmses[1], unit)}\n"
        else:
            mmrms = self.get_rms("mm")
            outfile += (
                f"{cc} Antenna surface RMS before adjustment: {format_value_unit(rmses[0], unit)} or "
                f'{format_value_unit(mmrms[0], "mm")}\n'
            )
            outfile += (
                f"{cc} Antenna surface RMS after adjustment: {format_value_unit(rmses[1], unit)} or "
                f'{format_value_unit(mmrms[1], "mm")}\n'
            )

        outfile += f"{cc} Lower means away from subreflector" + lnbr
        outfile += f"{cc} Raise means toward the subreflector" + lnbr
        outfile += f"{cc} LOWER the panel if the number is POSITIVE" + lnbr
        outfile += f"{cc} RAISE the panel if the number is NEGATIVE" + lnbr
        outfile += f"{cc} Adjustments are in " + unit + lnbr
        outfile += lnbr

        spc = " "
        outfile += f"{cc} Panel{2*spc}"
        for screw in self.telescope.screw_description:
            outfile += f"{4*spc}{screw:2s}{4*spc}"
        outfile += f"Fallback{4*spc}Model{lnbr}"

        nscrews = len(self.telescope.screw_description)
        fac = convert_unit("m", unit, "length")
        for ipanel, panel_label in enumerate(self.panel_labels):
            outfile += "{0:>5s}".format(panel_label)
            for iscrew in range(nscrews):
                outfile += " {0:>9.2f}".format(
                    fac * self.screw_adjustments[ipanel, iscrew]
                )
            outfile += (
                f"{5*spc}{bool_to_str(self.panel_fallback[ipanel]):>3s}{7*spc}{self.panel_model_array[ipanel]}"
                + lnbr
            )
        string_to_ascii_file(outfile, filename)

    def export_xds(self):
        """
        Export all the data to Xarray dataset

        Returns:
            XarrayDataSet containing all the relevant information
        """
        xds = xr.Dataset()
        gains = self.gains()
        rms = self.get_rms(unit="m")

        xds.attrs["ddi"] = self.ddi
        xds.attrs["wavelength"] = self.wavelength
        xds.attrs["amp_unit"] = self.amp_unit
        xds.attrs["panel_model"] = self.panelmodel
        xds.attrs["panel_margin"] = self.panel_margins
        xds.attrs["clip"] = self.clip
        xds.attrs["solved"] = self.solved
        xds.attrs["fitted"] = self.fitted
        xds.attrs["aperture_resolution"] = self.resolution
        xds.attrs["pol_state"] = self.pol_state
        xds.attrs["summary"] = self.summary

        xds["AMPLITUDE"] = xr.DataArray(self.amplitude, dims=["u", "v"])
        xds["PHASE"] = xr.DataArray(self.phase, dims=["u", "v"])
        xds["DEVIATION"] = xr.DataArray(self.deviation, dims=["u", "v"])
        xds["MASK"] = xr.DataArray(self.mask, dims=["u", "v"])
        xds["PANEL_DISTRIBUTION"] = xr.DataArray(
            self.panel_distribution, dims=["u", "v"]
        )
        xds["AMP_NOISE"] = xr.DataArray(self.amplitude_noise, dims=["u", "v"])
        xds["RADIUS"] = xr.DataArray(self.rad, dims=["u", "v"])

        coords = {"u": self.u_axis, "v": self.v_axis}
        if self.solved:
            # gains() and get_rms() return before/after pairs for solved surfaces
            xds["PHASE_RESIDUALS"] = xr.DataArray(self.phase_residuals, dims=["u", "v"])
            xds["RESIDUALS"] = xr.DataArray(self.residuals, dims=["u", "v"])
            xds["PHASE_CORRECTIONS"] = xr.DataArray(
                self.phase_corrections, dims=["u", "v"]
            )
            xds["CORRECTIONS"] = xr.DataArray(self.corrections, dims=["u", "v"])
            xds.attrs["input_rms"] = rms[0]
            xds.attrs["output_rms"] = rms[1]
            xds.attrs["input_gain"] = gains[0][0]
            xds.attrs["output_gain"] = gains[1][0]
            xds.attrs["theoretical_gain"] = gains[0][1]
            xds["PANEL_PARAMETERS"] = xr.DataArray(
                self.panel_pars, dims=["labels", "pars"]
            )
            xds["PANEL_SCREWS"] = xr.DataArray(
                self.screw_adjustments, dims=["labels", "screws"]
            )
            xds["PANEL_MODEL"] = xr.DataArray(self.panel_model_array, dims=["labels"])
            xds["PANEL_FALLBACK"] = xr.DataArray(self.panel_fallback, dims=["labels"])
            coords = {
                **coords,
                "labels": self.panel_labels,
                "screws": self.telescope.screw_description,
                "pars": np.arange(self.panel_pars.shape[1]),
            }
        else:
            xds.attrs["input_rms"] = rms
            xds.attrs["input_gain"] = gains[0]
            xds.attrs["theoretical_gain"] = gains[1]
        xds = xds.assign_coords(coords)
        return xds

    def export_to_fits(self, basename):
        """
        Export amplitude, mask, phase, phase corrections, phase residuals, deviations, deviation corrections and
        deviation residuals to FITS files; conveniently all data are on the same grid, so one header is shared

        Args:
            basename: basename for the FITS files, a prefix naming each map is added to it

        NOTE(review): corrections and residuals are None until the surface is solved; they are passed to
        write_fits regardless, as before.
        """
        head = {
            "PMODEL": self.panelmodel,
            "PMARGIN": self.panel_margins,
            "CLIP": self.clip,
            "TELESCOP": self.antenna_name,
            "INSTRUME": self.telescope.name,
            "WAVELENG": self.wavelength,
            "FREQUENC": clight / self.wavelength,
        }
        head = put_axis_in_fits_header(head, self.u_axis, 1, "X----LIN", "m")
        head = put_axis_in_fits_header(head, self.v_axis, 2, "Y----LIN", "m")
        head = put_resolution_in_fits_header(head, self.resolution)

        # (map name, data, file prefix, unit)
        products = [
            ("Amplitude", self.amplitude, "amplitude", self.amp_unit),
            ("Mask", np.where(self.mask, 1.0, np.nan), "mask", ""),
            ("Original Phase", self.phase, "phase_original", "rad"),
            ("Phase Corrections", self.phase_corrections, "phase_correction", "rad"),
            ("Phase residuals", self.phase_residuals, "phase_residual", "rad"),
            ("Original Deviation", self.deviation, "deviation_original", "m"),
            ("Deviation Corrections", self.corrections, "deviation_correction", "m"),
            ("Deviation residuals", self.residuals, "deviation_residual", "m"),
        ]
        for map_name, data, file_prefix, unit in products:
            write_fits(
                head,
                map_name,
                data,
                add_prefix(basename, file_prefix) + ".fits",
                unit,
                "panel",
            )