Source code for astrohack.core.image_comparison_tool

import numpy as np
from scipy.interpolate import griddata
from matplotlib import pyplot as plt
import xarray as xr
import pathlib

from astrohack.antenna.telescope import get_proper_telescope
from astrohack.utils.conversion import convert_unit
from astrohack.utils.text import (
    statistics_to_text,
    dynamic_format,
    create_pretty_table,
    string_to_ascii_file,
)
from astrohack.utils.algorithms import (
    data_statistics,
    are_axes_equal,
)
from astrohack.visualization.plot_tools import well_positioned_colorbar, compute_extent
from astrohack.visualization.plot_tools import (
    close_figure,
    get_proper_color_map,
    scatter_plot,
)
from astrohack.utils.fits import (
    read_fits,
    get_axis_from_fits_header,
    get_stokes_axis_iaxis,
    put_axis_in_fits_header,
    write_fits,
)


[docs] def test_image(fits_image): if isinstance(fits_image, FITSImage): pass else: raise TypeError("Reference image is not a FITSImage object")
[docs] class FITSImage: def __init__(self): """ Blank slate initialization of the FITSImage object """ # Attributes:
[docs] self.filename = None
[docs] self.telescope_name = None
[docs] self.rootname = None
[docs] self.factor = 1.0
[docs] self.reference_name = None
[docs] self.resampled = False
# Metadata
[docs] self.header = None
[docs] self.unit = None
[docs] self.x_axis = None
[docs] self.y_axis = None
[docs] self.original_x_axis = None
[docs] self.original_y_axis = None
[docs] self.x_unit = None
[docs] self.y_unit = None
# Data variables
[docs] self.original_data = None
[docs] self.data = None
[docs] self.residuals = None
[docs] self.residuals_percent = None
[docs] self.divided_image = None
@classmethod
[docs] def from_xds(cls, xds): """ Initialize a FITSImage object using as a base a Xarray dataset Args: xds: Xarray dataset Returns: FITSImage object initialized from a xds """ return_obj = cls() return_obj._init_as_xds(xds) return return_obj
@classmethod
[docs] def from_fits_file(cls, fits_filename, telescope_name): """ Initialize a FITSImage object using as a base a FITS file. Args: fits_filename: FITS file on disk telescope_name: Name of the telescope used Returns: FITSImage object initialized from a FITS file """ return_obj = cls() return_obj._init_as_fits(fits_filename, telescope_name) return return_obj
@classmethod
[docs] def from_zarr(cls, zarr_filename): """ Initialize a FITSImage object using as a base a Xarray dataset store on disk in a zarr container Args: zarr_filename: Xarray dataset on disk as a zarr container Returns: FITSImage object initialized from a xds """ return_obj = cls() xds = xr.open_zarr(zarr_filename) return_obj._init_as_xds(xds) return return_obj
def _init_as_fits(self, fits_filename, telescope_name, istokes=0, ichan=0): """ Backend for FITSImage.from_fits_file Args: fits_filename: FITS file on disk telescope_name: Name of the telescope used istokes: Stokes axis element to be fetched, should always be zero (singleton stokes axis or fetching I) ichan: Channel axis element to be fetched, should be zero for most cases, unless image has multiple channels Returns: None """ self.filename = fits_filename self.telescope_name = telescope_name fits_real_filename = fits_filename.split("/")[-1] self.rootname = ".".join(fits_real_filename.split(".")[:-1]) + "." self.header, self.data = read_fits(self.filename, header_as_dict=True) stokes_iaxis = get_stokes_axis_iaxis(self.header) self.unit = self.header["BUNIT"] if len(self.data.shape) == 4: if stokes_iaxis == 4: self.data = self.data[istokes, ichan, ...] else: self.data = self.data[ichan, istokes, ...] elif len(self.data.shape) == 2: pass # image is already as expected else: raise ValueError(f"FITS image has an unsupported shape: {self.data.shape}") self.original_data = np.copy(self.data) if "AIPS" in self.header["ORIGIN"]: self.x_axis, _, self.x_unit = get_axis_from_fits_header( self.header, 1, pixel_offset=False ) self.y_axis, _, self.y_unit = get_axis_from_fits_header( self.header, 2, pixel_offset=False ) offset_scale = 1.5 x_offset = offset_scale * np.unique(np.diff(self.x_axis))[0] y_offset = offset_scale * np.unique(np.diff(self.y_axis))[0] self.x_axis = np.flip(self.x_axis + x_offset) self.y_axis = np.flip(self.y_axis + y_offset) self.x_unit = "m" self.y_unit = "m" elif "Astrohack" in self.header["ORIGIN"]: self.x_axis, _, self.x_unit = get_axis_from_fits_header(self.header, 1) self.y_axis, _, self.y_unit = get_axis_from_fits_header(self.header, 2) else: raise NotImplementedError(f'Unrecognized origin:\n{self.header["origin"]}') self._create_base_mask() self.original_x_axis = np.copy(self.x_axis) self.original_y_axis = np.copy(self.y_axis) def _init_as_xds(self, xds): """ Backend for FITSImage.from_xds Args: xds: Xarray DataSet Returns: None """ for key in xds.attrs: setattr(self, key, xds.attrs[key]) self.x_axis = xds.x.values self.y_axis = xds.y.values self.original_x_axis = xds.original_x.values self.original_y_axis = xds.original_y.values for key, value in xds.items(): setattr(self, str(key), xds[key].values) def _create_base_mask(self): """ Create a base mask based on telescope parameters such as arm shadows. Returns: None """ telescope_obj = get_proper_telescope(self.telescope_name) self.base_mask = telescope_obj.create_aperture_mask( self.x_axis, self.y_axis, use_outer_limit=True )
[docs] def resample(self, ref_image): """ Resamples the data on this object onto the grid in ref_image Args: ref_image: Reference FITSImage object Returns: None """ test_image(ref_image) x_mesh_orig, y_mesh_orig = np.meshgrid(self.x_axis, self.y_axis, indexing="ij") x_mesh_dest, y_mesh_dest = np.meshgrid( ref_image.x_axis, ref_image.y_axis, indexing="ij" ) raveled_data = self.data.ravel() valid_data = np.isfinite(raveled_data) resamp = griddata( (x_mesh_orig.ravel()[valid_data], y_mesh_orig.ravel()[valid_data]), raveled_data[valid_data], (x_mesh_dest.ravel(), y_mesh_dest.ravel()), method="nearest", ) size = ref_image.x_axis.shape[0], ref_image.y_axis.shape[0] self.x_axis = ref_image.x_axis self.y_axis = ref_image.y_axis self.data = resamp.reshape(size) self._create_base_mask() self.resampled = True
[docs] def compare_difference(self, ref_image): """ Does the difference comparison between self and ref_image. Args: ref_image: Reference FITSImage object Returns: None """ test_image(ref_image) if not self.image_has_same_sampling(ref_image): self.resample(ref_image) self.residuals = ref_image.data - (self.data * self.factor) self.residuals_percent = 100 * self.residuals / ref_image.data self.reference_name = ref_image.filename
[docs] def compare_scaled_difference(self, ref_image, rejection=10): """ Does the scaled difference comparison between self and ref_image. Args: ref_image: Reference FITSImage object rejection: rejection level for scaling factor Returns: None """ test_image(ref_image) if not self.image_has_same_sampling(ref_image): self.resample(ref_image) simple_division = ref_image.data / self.data rough_factor = np.nanmean(simple_division[self.base_mask]) self.divided_image = np.where( np.abs(simple_division) > rejection * rough_factor, np.nan, simple_division ) self.factor = np.nanmedian(self.divided_image) self.compare_difference(ref_image)
[docs] def image_has_same_sampling(self, ref_image): """ Tests if self has the same X and Y sampling as ref_image Args: ref_image: Reference FITSImage object Returns: True or False """ test_image(ref_image) return are_axes_equal(self.x_axis, ref_image.x_axis) and are_axes_equal( self.y_axis, ref_image.y_axis )
[docs] def mask_array(self, image_array): """ Applies base mask to image_array Args: image_array: Data array to be masked Returns: Masked array """ return np.where(self.base_mask, image_array, np.nan)
[docs] def mask_original(self): """ Applies base mask equivalent to original data Returns: Masked original data """ telescope_obj = get_proper_telescope(self.telescope_name) orig_mask = telescope_obj.create_aperture_mask( self.original_x_axis, self.original_y_axis, use_outer_limit=True ) return np.where(orig_mask, self.original_data, np.nan)
[docs] def plot_images( self, destination, ref_image, plot_resampled=False, plot_percentuals=False, plot_reference=False, plot_original=False, plot_divided_image=False, z_scale=None, colormap="viridis", dpi=300, display=False, ): """ Plot image contents of the FITSImage object, always plots the residuals when called Args: destination: Location onto which save plot files ref_image: reference image plot_resampled: Also plot data array? plot_percentuals: Also plot percentual residuals array? plot_reference: Also plot reference image? plot_original: Also plot original unresampled image? plot_divided_image: Also plot divided image? z_scale: Z scale for original, resampled, reference and residual images colormap: Colormap name for image plots dpi: png resolution on disk display: Show interactive view of plots Returns: None """ extent = compute_extent(self.x_axis, self.y_axis, 0.0) cmap = get_proper_color_map(colormap) base_name = f"{destination}/{self.rootname}" if self.residuals is None: raise RuntimeError("Cannot plot results as they don't exist yet.") self._plot_map( self.mask_array(self.residuals), f"Residuals, {self.reference_name} - {self.filename}", f"Residuals [{self.unit}]", f"{base_name}residuals.png", cmap, extent, z_scale, dpi, display, add_statistics=True, ) if plot_resampled: self._plot_map( self.mask_array(self.data), "Resampled Data", f"Data [{self.unit}]", f"{base_name}resampled.png", cmap, extent, z_scale, dpi, display, add_statistics=True, ) if plot_reference: self._plot_map( self.mask_array(ref_image.data), f"Reference: {self.reference_name}", f"Data [{self.unit}]", f"{base_name}reference.png", cmap, extent, z_scale, dpi, display, add_statistics=True, ) if plot_original: self._plot_map( self.mask_original(), f"Unresampled data", f"Data [{self.unit}]", f"{base_name}original.png", cmap, extent, z_scale, dpi, display, add_statistics=True, ) if plot_percentuals: if self.residuals is None: raise RuntimeError("Cannot plot results as they don't exist yet.") self._plot_map( self.mask_array(self.residuals_percent), f"Residuals in %, {self.reference_name} - {self.filename}", f"Residuals [%]", f"{base_name}residuals_percent.png", cmap, extent, "symmetrical", dpi, display, add_statistics=True, ) if plot_divided_image: if self.divided_image is None: pass else: self._plot_map( self.mask_array(self.divided_image), f"Divided image, {self.reference_name} / {self.filename}, scaling={self.factor:.4f}", f"Division [ ]", f"{base_name}divided.png", cmap, extent, None, dpi, display, add_statistics=True, )
def _plot_map( self, data, title, zlabel, filename, cmap, extent, zscale, dpi, display, add_statistics=False, ): """ Backend for plot_images Args: data: Data array to be plotted title: Title to appear on plot zlabel: Label for the colorbar filename: name for the png file on disk cmap: Colormap object for plots extent: extents of the X and Y axes zscale: Constraints on the Z axes. dpi: png resolution on disk display: Show interactive view of plots add_statistics: Add simple statistics to plot's subtitle Returns: None """ fig, ax = plt.subplots(1, 1, figsize=[10, 8]) if zscale == "symmetrical": scale = max(np.abs(np.nanmin(data)), np.abs(np.nanmax(data))) vmin, vmax = -scale, scale elif zscale is None: vmin = np.nanmin(data) vmax = np.nanmax(data) else: vmin, vmax = zscale if vmin == "None" or vmin is None: vmin = np.nanmin(data) if vmax == "None" or vmax is None: vmax = np.nanmax(data) im = ax.imshow( data, cmap=cmap, interpolation="nearest", extent=extent, vmin=vmin, vmax=vmax, ) well_positioned_colorbar( ax, fig, im, zlabel, location="right", size="5%", pad=0.05 ) ax.set_xlabel(f"X axis [{self.x_unit}]") ax.set_ylabel(f"Y axis [{self.y_unit}]") if add_statistics: data_stats = data_statistics(data) ax.set_title(statistics_to_text(data_stats, num_format="dynamic")) close_figure(fig, title, filename, dpi, display)
[docs] def export_as_xds(self): """ Create a Xarray DataSet from the FITSImage object Returns: Xarray DataSet """ xds = xr.Dataset() obj_dict = vars(self) coords = { "x": self.x_axis, "y": self.y_axis, "original_x": self.original_x_axis, "original_y": self.original_y_axis, } for key, value in obj_dict.items(): failed = False if isinstance(value, np.ndarray): if len(value.shape) == 2: if "original" in key: xds[key] = xr.DataArray( value, dims=["original_x", "original_y"] ) else: xds[key] = xr.DataArray(value, dims=["x", "y"]) elif len(value.shape) == 1: pass # Axes else: failed = True else: xds.attrs[key] = value if failed: raise ValueError(f"Don't know what to do with: {key}") xds = xds.assign_coords(coords) return xds
[docs] def to_zarr(self, zarr_filename): """ Saves a xds representation of self on disk using the zarr format. Args: zarr_filename: Name for the zarr container on disk Returns: None """ xds = self.export_as_xds() xds.to_zarr(zarr_filename, mode="w", compute=True, consolidated=True)
def __repr__(self): """ Print method Returns: A String summary of the current status of self. """ obj_dict = vars(self) outstr = "" for key, value in obj_dict.items(): if isinstance(value, np.ndarray): outstr += f"{key:17s} -> {value.shape}" elif isinstance(value, dict): outstr += f"{key:17s} -> dict()" else: outstr += f"{key:17s} = {value}" outstr += "\n" return outstr
[docs] def export_to_fits(self, destination): """ Export internal images to FITS files. Args: destination: location to store FITS files Returns: None """ pathlib.Path(destination).mkdir(exist_ok=True) ext_fits = ".fits" out_header = self.header.copy() put_axis_in_fits_header(out_header, self.x_axis, 1, "", self.x_unit) put_axis_in_fits_header(out_header, self.y_axis, 2, "", self.y_unit) obj_dict = vars(self) for key, value in obj_dict.items(): if isinstance(value, np.ndarray): if len(value.shape) == 2: if "original" in key: pass else: if key == "base_mask" or key == "divided_image": unit = "" elif key == "residuals_percent": unit = "%" else: unit = self.unit filename = f"{destination}/{self.rootname}{key}{ext_fits}" write_fits( out_header, key, np.fliplr(value.astype(float)), filename, unit, reorder_axis=False, )
[docs] def scatter_plot( self, destination, ref_image, dpi=300, display=False, max_radius=None, min_radius=None, ): """ Produce a scatter plot of self.data agains ref_image.data Args: destination: Location to store scatter plot ref_image: Reference FITSImage object dpi: png resolution on disk display: Show interactive view of plot max_radius: Maximum radius for scatter plot comparison as the outer panels can be crappy. min_radius: Minimum radius for scatter plot comparison as the innermost panels can be crappy. Returns: None """ test_image(ref_image) if not self.image_has_same_sampling(ref_image): self.resample(ref_image) fig, ax = plt.subplots(1, 1, figsize=[10, 8]) x_mesh_orig, y_mesh_orig = np.meshgrid(self.x_axis, self.y_axis, indexing="ij") radius = np.sqrt(x_mesh_orig**2 + y_mesh_orig**2) telescope = get_proper_telescope(self.telescope_name) if min_radius is None: min_radius = telescope.inner_radial_limit if max_radius is None: max_radius = telescope.outer_radial_limit - 1.0 scatter_mask = np.isfinite(ref_image.data) scatter_mask = np.where(np.isfinite(self.data), scatter_mask, False) scatter_mask = np.where(radius < max_radius, scatter_mask, False) scatter_mask = np.where(radius > min_radius, scatter_mask, False) ydata = self.data[scatter_mask] xdata = ref_image.data[scatter_mask] pl_max = np.max((np.max(xdata), np.max(ydata))) pl_min = np.min((np.min(xdata), np.min(ydata))) scatter_plot( ax, xdata, f"Reference image {ref_image.filename} [{ref_image.unit}]", ydata, f"{self.filename} [{self.unit}]", add_regression=True, regression_method="siegelslopes", add_regression_reference=True, regression_reference_label="Perfect agreement", xlim=(pl_min, pl_max), ylim=[pl_min, pl_max], force_equal_aspect=True, ) close_figure( fig, "Scatter plot against reference image", f"{destination}/{self.rootname}scatter.png", dpi, display, )
[docs] def image_comparison_chunk(compare_params): """ Chunk function for parallel execution of the image comparison tool. Args: compare_params: Parameter dictionary for workflow control. Returns: A DataTree containing the Image and its reference Image. """ image = FITSImage.from_fits_file( compare_params["this_image"], compare_params["telescope_name"] ) ref_image = FITSImage.from_fits_file( compare_params["this_reference_image"], compare_params["telescope_name"] ) plot_resampled = compare_params["plot_resampled"] plot_percentuals = compare_params["plot_percentuals"] plot_divided = compare_params["plot_divided_image"] plot_reference = compare_params["plot_reference"] plot_original = compare_params["plot_original"] destination = compare_params["destination"] z_scale = compare_params["z_scale_limits"] colormap = compare_params["colormap"] dpi = compare_params["dpi"] display = compare_params["display"] if compare_params["comparison"] == "direct": image.compare_difference(ref_image) image.plot_images( destination, ref_image, plot_resampled, plot_percentuals, plot_reference, plot_original, False, z_scale=z_scale, colormap=colormap, dpi=dpi, display=display, ) elif compare_params["comparison"] == "scaled": image.compare_scaled_difference(ref_image) image.plot_images( destination, ref_image, plot_resampled, plot_percentuals, plot_reference, plot_original, plot_divided, z_scale=z_scale, colormap=colormap, dpi=dpi, display=display, ) else: raise ValueError(f'Unknown comparison type {compare_params["comparison"]}') if compare_params["export_to_fits"]: image.export_to_fits(destination) if compare_params["plot_scatter"]: image.scatter_plot(destination, ref_image, dpi=dpi, display=display) img_node = xr.DataTree(name=image.rootname, dataset=image.export_as_xds()) ref_node = xr.DataTree(name=ref_image.rootname, dataset=ref_image.export_as_xds()) tree_node = xr.DataTree( name=image.rootname[:-1], children={"Reference": ref_node, "Image": img_node} ) return tree_node
[docs] def extract_rms_from_xds(xds): """ This simple function extracts FITSImage RMSes for a xds describing a FITSImage obj Args: xds: xds describing a FITSImage obj Returns: dict with RMS values """ rms_dict = {} img_obj = FITSImage.from_xds(xds) if img_obj.residuals is None: rms_dict["resampled"] = np.nan rms_dict["residuals"] = np.nan else: rms_dict["resampled"] = np.nanstd(img_obj.mask_array(img_obj.data)) rms_dict["residuals"] = np.nanstd(img_obj.mask_array(img_obj.residuals)) rms_dict["original"] = np.nanstd(img_obj.mask_original()) return rms_dict
[docs] def create_fits_comparison_rms_table(parameters, xdt): image_list = xdt.children rms_unit = parameters["rms_unit"] fields = [ "Image", "Reference", f"Original RMS [{rms_unit}]", f"Resampled RMS [{rms_unit}]", f"Reference RMS [{rms_unit}]", f"Residuals RMS [{rms_unit}]", ] factor = convert_unit("m", rms_unit, "length") table = create_pretty_table(fields) for image in image_list: image_xds = xdt[image]["Image"].to_dataset() reference_xds = xdt[image]["Reference"].to_dataset() img_rms_dict = extract_rms_from_xds(image_xds) ref_rms_dict = extract_rms_from_xds(reference_xds) values = np.array( [ img_rms_dict["original"], img_rms_dict["resampled"], ref_rms_dict["original"], img_rms_dict["residuals"], ] ) values *= factor row = [image_xds.attrs["filename"], reference_xds.attrs["filename"]] for val in values: row.append(f"{val:{dynamic_format(val)}}") table.add_row(row) outstr = f'RMS comparison table from {parameters["zarr_data_tree"]}:\n' outstr += table.get_string() string_to_ascii_file(outstr, parameters["table_file"]) if parameters["print_table"]: print(table) return