import time
import numpy as np
import xarray
from toolviper.utils import logger as logger
from scipy.interpolate import griddata
from numba import njit
from numba.core import types
from numba.typed import List as numbaList
import math
from astrohack.utils.constants import (
sig_2_fwhm,
njit_caching,
)
from astrohack.utils.algorithms import (
find_nearest,
calc_coords,
find_peak_beam_value,
chunked_average,
)
from astrohack.utils.tools import (
get_str_idx_in_list,
raise_type_error,
check_is_proper_array,
check_is_proper_shape,
)
[docs]
def grid_beam(
    ant_ddi_xdt: xarray.DataTree,
    grid_size,
    sky_cell_size,
    avg_chan,
    chan_tol_fac,
    telescope,
    grid_interpolation_mode,
    observation_summary,
    label,
):
    """
    Grids the visibilities onto a 2D plane based on their sky coordinates, using scipy griddata or a gaussian
    convolution.

    Args:
        ant_ddi_xdt: Xarray DataTree containing the visibilities, one child per holography mapping
        grid_size: The size of the beam image grid (pixels)
        sky_cell_size: Size of the beam grid cell in the sky (radians)
        avg_chan: Average channels? (boolean)
        chan_tol_fac: Frequency tolerance to chunk channels together
        telescope: Telescope object containing optical description of the telescope (only its ``diameter``
            attribute is read here)
        grid_interpolation_mode: linear, nearest, cubic or gaussian (convolution)
        observation_summary: Dictionary containing a summary of observation information. Its "beam" entry
            (and its "spectral" entry when averaging channels) is updated in place.
        label: label to be used in messages

    Returns:
        The gridded beam, the list of time centroids (one per mapping), the output frequency axis, the
        polarization axis, the L and M axes, a boolean about the necessity of gridding corrections after
        fourier transform, and the updated observation summary.

    Raises:
        ValueError: If grid_interpolation_mode is not one of the supported modes.
    """
    n_holog_map = len(ant_ddi_xdt.keys())
    map_0_key = list(ant_ddi_xdt.keys())[0]
    # The first mapping is taken as representative for the spectral and polarization setup.
    freq_axis = ant_ddi_xdt[map_0_key].chan.values
    pol_axis = ant_ddi_xdt[map_0_key].pol.values
    n_chan = ant_ddi_xdt[map_0_key].sizes["chan"]
    n_pol = ant_ddi_xdt[map_0_key].sizes["pol"]

    observation_summary["beam"]["grid size"] = [int(grid_size[0]), int(grid_size[1])]
    observation_summary["beam"]["cell size"] = [sky_cell_size[0], sky_cell_size[1]]

    reference_scaling_frequency = np.mean(freq_axis)
    if avg_chan:
        n_chan = 1
        avg_chan_map, avg_freq_axis = _create_average_chan_map(freq_axis, chan_tol_fac)
        output_freq_axis = [np.mean(avg_freq_axis)]
        # A single output channel now spans the full original bandwidth.
        observation_summary["spectral"]["channel width"] *= observation_summary[
            "spectral"
        ]["number of channels"]
        observation_summary["spectral"]["number of channels"] = 1
    else:
        avg_chan_map = None
        avg_freq_axis = None
        output_freq_axis = freq_axis

    l_axis, m_axis, l_grid, m_grid, beam_grid = _create_beam_grid(
        grid_size, sky_cell_size, n_chan, n_pol, n_holog_map
    )

    scipy_interp = ["linear", "nearest", "cubic"]
    time_centroid = []
    grid_corr = False
    for holog_map_index, map_xdt in enumerate(ant_ddi_xdt.values()):
        # Grid the data
        vis = map_xdt.VIS.values
        # NaN never compares equal to anything (itself included), so `vis == np.nan` selects nothing;
        # np.isnan is required. A cleaned copy is made only when needed so the input tree is not modified.
        nan_mask = np.isnan(vis)
        if nan_mask.any():
            vis = np.where(nan_mask, 0.0, vis)
        lm = map_xdt.DIRECTIONAL_COSINES.values
        weight = map_xdt.WEIGHT.values

        if avg_chan:
            vis_avg, weight_sum = chunked_average(
                vis, weight, avg_chan_map, avg_freq_axis
            )
            # Sky coordinates are scaled by the ratio of each chunk frequency to the reference frequency.
            lm_freq_scaled = lm[:, :, None] * (
                avg_freq_axis / reference_scaling_frequency
            )
        else:
            vis_avg = vis
            weight_sum = weight
            # No scaling: only add a channel axis to the coordinates.
            lm_freq_scaled = lm[:, :, None] * np.full_like(freq_axis, 1.0)

        if grid_interpolation_mode in scipy_interp:
            beam_grid[holog_map_index, ...] = _scipy_gridding(
                vis_avg,
                lm_freq_scaled,
                l_grid,
                m_grid,
                grid_interpolation_mode,
                avg_chan,
                label,
            )
        elif grid_interpolation_mode == "gaussian":
            # The convolution kernel has to be corrected for after the fourier transform.
            grid_corr = True
            beam_grid[holog_map_index, ...] = _convolution_gridding(
                vis_avg,
                weight_sum,
                lm_freq_scaled,
                telescope.diameter,
                l_axis,
                m_axis,
                sky_cell_size,
                reference_scaling_frequency,
                avg_chan,
                label,
            )
        else:
            msg = f"Unknown grid type {grid_interpolation_mode}."
            logger.error(msg)
            raise ValueError(msg)

        # The time centroid is taken as the middle sample of the time axis.
        time_centroid_index = map_xdt.sizes["time"] // 2
        time_centroid.append(map_xdt.coords["time"][time_centroid_index].values)

        beam_grid[holog_map_index, ...] = _normalize_beam(
            beam_grid[holog_map_index, ...], n_chan, pol_axis
        )

    return (
        beam_grid,
        time_centroid,
        output_freq_axis,
        pol_axis,
        l_axis,
        m_axis,
        grid_corr,
        observation_summary,
    )
[docs]
def gridding_correction(aperture, freq, diameter, sky_cell_size, u_axis, v_axis):
    """
    Apply the gridding correction needed after the fourier transform when the beam was gridded with the
    gaussian convolution.

    Args:
        aperture: Aperture image
        freq: Representative frequency, used to derive the beam size
        diameter: Telescope diameter, used to derive the beam size
        sky_cell_size: Size of the beam grid cell in the sky (radians)
        u_axis: U axis of the aperture grid
        v_axis: V axis of the aperture grid

    Returns:
        The gridding corrected aperture grid
    """
    return _gridding_correction_jit(
        aperture,
        _compute_beam_size(diameter, freq),
        sky_cell_size,
        u_axis,
        v_axis,
    )
def _create_beam_grid(grid_size, sky_cell_size, n_chan, n_pol, n_map):
    """
    Build the sky axes, their 2D meshes and the zero-filled complex array that will hold the beam image.

    Args:
        grid_size: The size of the beam image grid (pixels)
        sky_cell_size: Size of the beam grid cell in the sky (radians)
        n_chan: Number of channels
        n_pol: Number of polarization states
        n_map: Number of mappings

    Returns:
        L and M axes, 2D meshes of the L and M axes (first index runs along L, second along M) and the beam
        grid of shape (n_map, n_chan, n_pol, n_l, n_m)
    """
    l_axis, m_axis = calc_coords(grid_size, sky_cell_size)

    # meshgrid's default layout puts M on the first index; transposing puts L first.
    l_mesh, m_mesh = np.meshgrid(l_axis, m_axis)
    l_grid = l_mesh.T
    m_grid = m_mesh.T

    grid_shape = (n_map, n_chan, n_pol, l_grid.shape[0], l_grid.shape[1])
    beam_grid = np.zeros(grid_shape, dtype=np.complex128)
    return l_axis, m_axis, l_grid, m_grid, beam_grid
def _scipy_gridding(vis, lm, l_grid, m_grid, grid_interpolation_mode, avg_chan, label):
    """
    Interpolate the visibilities onto the beam grid with scipy's griddata.

    Args:
        vis: Visibilities, indexed as [time, chan, pol]
        lm: Visibilities sky coordinates, indexed as [time, l/m, chan]
        l_grid: 2D mesh of the L axis
        m_grid: 2D mesh of the M axis
        grid_interpolation_mode: linear, nearest, cubic
        avg_chan: If true all channels are accumulated onto a single output channel
        label: label to be used in messages

    Returns:
        beam data gridded, indexed as [chan, pol, l, m]
    """
    t_start = time.time()
    n_chan, n_pol = vis.shape[1], vis.shape[2]
    n_out_chan = 1 if avg_chan else n_chan
    beam_grid = np.zeros(
        (n_out_chan, n_pol, l_grid.shape[0], l_grid.shape[1]), dtype=complex
    )

    # The sky coordinates depend on the channel, hence one griddata call per channel.
    for i_chan in range(n_chan):
        interpolated = griddata(
            lm[:, :, i_chan],
            vis[:, i_chan, :],
            (l_grid, m_grid),
            method=grid_interpolation_mode,
            fill_value=0.0,
        )
        # griddata returns the polarization on the last axis, the beam grid wants it first.
        chan_plane = np.moveaxis(interpolated, 2, 0)
        if avg_chan:
            # Channels are summed; no division by the number of channels is done here.
            beam_grid[0, :, :, :] += chan_plane
        else:
            beam_grid[i_chan, :, :, :] = chan_plane

    elapsed = time.time() - t_start
    logger.debug(f"{label}: Interpolation gridding took {elapsed:.3} seconds")
    return beam_grid
def _normalize_beam(beam_grid, n_chan, pol_axis):
    """
    Normalize the gridded beam in place by its peak value.

    Args:
        beam_grid: the gridded beam, indexed as [chan, pol, l, m]
        n_chan: The number of channels in the beam data
        pol_axis: polarization axis

    Returns:
        Normalized beam grid (the same array that was given as input)

    Raises:
        ValueError: If pol_axis contains none of "I", "RR" or "XX".
    """
    if "I" in pol_axis:
        # NOTE(review): only channel 0 of the I plane is normalized here, whereas the branch below loops
        # over n_chan and scales every polarization -- confirm this asymmetry is intended.
        i_stokes = get_str_idx_in_list("I", pol_axis)
        stokes_peak = find_peak_beam_value(beam_grid[0, i_stokes, ...], scaling=0.25)
        beam_grid[0, i_stokes, ...] /= stokes_peak
        return beam_grid

    # Pick the pair of parallel hands matching the polarization basis.
    if "RR" in pol_axis:
        parallel_hands = ("RR", "LL")
    elif "XX" in pol_axis:
        parallel_hands = ("XX", "YY")
    else:
        msg = f"Unknown polarization scheme: {pol_axis}"
        logger.error(msg)
        raise ValueError(msg)
    i_p1 = get_str_idx_in_list(parallel_hands[0], pol_axis)
    i_p2 = get_str_idx_in_list(parallel_hands[1], pol_axis)

    for chan in range(n_chan):
        try:
            p1_peak = find_peak_beam_value(beam_grid[chan, i_p1, ...], scaling=0.25)
            p2_peak = find_peak_beam_value(beam_grid[chan, i_p2, ...], scaling=0.25)
        except IndexError:
            # Peak search failed: fall back to the value at the central pixel for both hands.
            l_center = beam_grid.shape[-2] // 2
            m_center = beam_grid.shape[-1] // 2
            p1_peak = beam_grid[chan, i_p1, l_center, m_center]
            p2_peak = beam_grid[chan, i_p2, l_center, m_center]

        # All polarizations of the channel are scaled by the amplitude of the mean parallel hand peak.
        normalization = np.abs(0.5 * (p1_peak + p2_peak))
        if normalization == 0:
            logger.warning("Peak of zero found! Setting normalization to unity.")
            normalization = 1
        beam_grid[chan, ...] /= normalization

    return beam_grid
def _create_average_chan_map(freq_chan, chan_tolerance_factor):
    """
    Build the map that assigns each channel to the frequency chunk it will be averaged into.

    Args:
        freq_chan: frequency axis
        chan_tolerance_factor: Fraction of the highest frequency used as the maximum chunk bandwidth

    Returns:
        Map of channel chunking (one chunk index per channel), new frequency axis (one value per chunk)
    """
    n_chan = len(freq_chan)
    freq_max = np.max(freq_chan)
    freq_min = np.min(freq_chan)

    tol = freq_max * chan_tolerance_factor
    # `or 1` guarantees at least one chunk when the whole band fits within the tolerance.
    n_pb_chan = int(np.floor((freq_max - freq_min) / tol) + 0.5) or 1

    if n_pb_chan >= n_chan:
        # As many chunks as channels (or more): every channel is its own chunk.
        return np.arange(n_chan), freq_chan

    # Chunk centers are evenly spread over the band.
    pb_delta_bandwidth = (freq_max - freq_min) / n_pb_chan
    pb_freq = np.arange(n_pb_chan) * pb_delta_bandwidth + freq_min + pb_delta_bandwidth / 2

    # Each channel goes to the chunk with the closest center frequency.
    cf_chan_map = np.array(
        [find_nearest(pb_freq, chan_freq)[0] for chan_freq in freq_chan], dtype=int
    )
    return cf_chan_map, pb_freq
[docs]
def grid_1d_data(
    dest_ax,
    orig_ax,
    y_data,
    method,
    orig_label,
    dest_label,
    gaussian_fallback=True,
    return_weights=False,
    second_dim_len=2,
):
    """
    Resample one or more data sets from an origin axis onto a destination axis.

    With the "linear" method, when the origin axis is more finely sampled than the destination axis (median
    spacing) the origin samples are accumulated in the destination cells, otherwise each destination point
    picks up a single origin sample. With the "gaussian" method the data are convolved onto the destination
    axis with a kernel whose width is the destination axis spacing. In all cases the accumulated data are
    divided by the accumulated weights, so destination points that received no weight come out as NaN.

    Args:
        dest_ax: Destination axis (validated with check_is_proper_array(dest_ax, 1))
        orig_ax: Origin axis (validated with check_is_proper_array(orig_ax, 1))
        y_data: A numpy array or a list of numpy arrays, each of shape [orig_ax.shape[0], second_dim_len]
        method: "linear" or "gaussian"
        orig_label: Label of the origin data, used in messages
        dest_label: Label of the destination data, used in messages
        gaussian_fallback: Redo the resampling with the gaussian convolution if the linear method
            produced NaNs
        return_weights: Also return the weights accumulated on the destination axis
        second_dim_len: Expected length of the second dimension of each array in y_data

    Returns:
        The resampled data with shape [number of data sets, dest_ax.shape[0], second dimension], followed by
        the weights if return_weights is True

    Raises:
        ValueError: If method is neither "linear" nor "gaussian".
    """

    def _divide_by_weights(summed, weight_sum):
        # Zero weights are expected, the resulting NaNs are detected and reported by the caller.
        with np.errstate(divide="ignore", invalid="ignore"):
            summed /= weight_sum[np.newaxis, :, np.newaxis]
        return summed

    if isinstance(y_data, np.ndarray):
        y_data = [y_data]
    elif not isinstance(y_data, list):
        raise_type_error("y_data", "list or numpy array")
    # The jitted helpers expect a numba typed list.
    y_data = numbaList(y_data)

    check_is_proper_array(dest_ax, 1)
    check_is_proper_array(orig_ax, 1)
    for datum in y_data:
        check_is_proper_shape(datum, [orig_ax.shape[0], second_dim_len])

    dest_delta = np.median(np.diff(dest_ax))
    orig_delta = np.median(np.diff(orig_ax))

    if method == "linear":
        if orig_delta < dest_delta:
            new_y_data, weights = _linear_interpolate_under_sample(
                dest_ax, orig_ax, dest_delta, y_data
            )
        else:
            new_y_data, weights = _liner_interpolate_over_sample(
                dest_ax, orig_ax, orig_delta, y_data
            )
    elif method == "gaussian":
        new_y_data, weights = _gaussian_convolution_1d_jit(
            dest_ax, orig_ax, dest_delta, y_data
        )
    else:
        raise ValueError(f"{method} is not a valid interpolation methods")

    new_y_data = _divide_by_weights(new_y_data, weights)

    # Any NaN counts, independently of second_dim_len (halving the count missed NaNs when it is 1).
    if np.isnan(new_y_data).any():
        if method == "linear":
            logger.warning(
                f"{orig_label} have produced NaNs when resampled onto {dest_label} using linear "
                "interpolation."
            )
            if gaussian_fallback:
                logger.warning(f"Falling back to Gaussian convolution.")
                new_y_data, weights = _gaussian_convolution_1d_jit(
                    dest_ax, orig_ax, dest_delta, y_data
                )
                # The convolution returns weighted sums, they must be normalized like any other result.
                new_y_data = _divide_by_weights(new_y_data, weights)
            else:
                logger.warning(f"Fallback to gaussian convolution is off.")
        else:
            logger.warning(
                f"{orig_label} have produced NaNs when resampled onto {dest_label} using gaussian "
                "convolution."
            )

    if return_weights:
        return new_y_data, weights
    return new_y_data
@njit(cache=njit_caching, nogil=True)
def _create_new_data_and_weights(dest_ax, y_data):
    """
    Allocate the zero-filled output data and weights for a 1D resampling.

    Assumes y_data is a list of [n, m] arrays; m is taken from the first array of the list.

    Args:
        dest_ax: destination axis
        y_data: Y data list

    Returns:
        new_y_data of shape [len(y_data), dest_ax.shape[0], m] and weights shaped like dest_ax
    """
    out_shape = (len(y_data), dest_ax.shape[0], y_data[0].shape[1])
    return np.zeros(out_shape), np.zeros_like(dest_ax)
@njit(cache=njit_caching, nogil=True)
def _get_ordered_axis_index(coor, i_pos, axis, half_int):
    """
    Walk forward along an axis, starting at i_pos, until reaching the first element whose upper edge
    (element + half_int) is not below coor.

    Args:
        coor: Coordinate to locate
        i_pos: Index from which to start the search
        axis: Axis being searched
        half_int: Half of the axis interval

    Returns:
        Index of the element found, or -1 if the end of the axis is reached first
    """
    n_axis = axis.shape[0]
    while i_pos != n_axis:
        if coor > axis[i_pos] + half_int:
            i_pos += 1
        else:
            return i_pos
    return -1
@njit(cache=njit_caching, nogil=True)
def _linear_interpolate_under_sample(dest_ax, orig_ax, dest_delta, y_data):
    """
    Accumulate the origin samples into the destination cells they fall in.

    Args:
        dest_ax: Destination axis
        orig_ax: Origin axis
        dest_delta: Interval of the destination axis
        y_data: List of [n, m] arrays sampled on orig_ax

    Returns:
        Summed data per destination cell and the number of samples summed in each cell
    """
    half_cell = dest_delta / 2
    new_y_data, weights = _create_new_data_and_weights(dest_ax, y_data)
    n_data = len(y_data)
    n_cols = new_y_data.shape[2]

    i_dest = 0
    for i_orig in range(orig_ax.shape[0]):
        coor = orig_ax[i_orig]
        # Samples lying before the lower edge of the current cell are dropped.
        if coor < dest_ax[i_dest] - half_cell:
            continue
        i_dest = _get_ordered_axis_index(coor, i_dest, dest_ax, half_cell)
        # Past the end of the destination axis: nothing left to accumulate.
        if i_dest < 0:
            break
        weights[i_dest] += 1
        for i_data in range(n_data):
            datum = y_data[i_data]
            for i_col in range(n_cols):
                new_y_data[i_data, i_dest, i_col] += datum[i_orig, i_col]
    return new_y_data, weights
@njit(cache=njit_caching, nogil=True)
def _liner_interpolate_over_sample(dest_ax, orig_ax, orig_delta, y_data):
    """
    Fill every destination point with the single origin sample found for it along the origin axis.

    Args:
        dest_ax: Destination axis
        orig_ax: Origin axis
        orig_delta: Interval of the origin axis
        y_data: List of [n, m] arrays sampled on orig_ax

    Returns:
        Resampled data and weights (one for every destination point)
    """
    half_cell = orig_delta / 2
    new_y_data, weights = _create_new_data_and_weights(dest_ax, y_data)
    n_data = len(y_data)
    n_cols = new_y_data.shape[2]

    i_orig = 0
    for i_dest in range(dest_ax.shape[0]):
        # NOTE(review): a -1 (not found) result is used as an index below, i.e. the last origin sample
        # is taken -- confirm this is the intended behavior beyond the end of the origin axis.
        i_orig = _get_ordered_axis_index(dest_ax[i_dest], i_orig, orig_ax, half_cell)
        weights[i_dest] += 1
        for i_data in range(n_data):
            datum = y_data[i_data]
            for i_col in range(n_cols):
                new_y_data[i_data, i_dest, i_col] += datum[i_orig, i_col]
    return new_y_data, weights
@njit(cache=njit_caching, nogil=True)
def _gaussian_convolution_1d_jit(dest_ax, orig_ax, hpkw, y_data):
    """
    Convolve the origin samples onto the destination axis with the exponential kernel.

    Args:
        dest_ax: Destination axis
        orig_ax: Origin axis
        hpkw: Value used both as the kernel beam size and as the kernel cell size
        y_data: List of [n, m] arrays sampled on orig_ax

    Returns:
        Kernel weighted sums of the data and the summed kernel weights (the division is left to the caller)
    """
    kernel = _create_exponential_kernel(hpkw, hpkw)
    new_y_data, weights = _create_new_data_and_weights(dest_ax, y_data)
    n_data = len(y_data)
    n_cols = new_y_data.shape[2]

    for i_orig in range(orig_ax.shape[0]):
        coor = orig_ax[i_orig]
        # Only destination points within the kernel support receive a contribution.
        i_first, i_stop = _compute_kernel_range(kernel, coor, dest_ax)
        for i_dest in range(i_first, i_stop):
            conv_fact = _convolution_factor(kernel, dest_ax[i_dest] - coor)
            weights[i_dest] += conv_fact
            for i_data in range(n_data):
                datum = y_data[i_data]
                for i_col in range(n_cols):
                    new_y_data[i_data, i_dest, i_col] += conv_fact * datum[i_orig, i_col]
    return new_y_data, weights
def _convolution_gridding(
    visibilities,
    weights,
    lmvis,
    diameter,
    l_axis,
    m_axis,
    sky_cell_size,
    reference_scaling_frequency,
    avg_chan,
    label,
):
    """
    Grid the visibility data with a gaussian convolution whose kernel is based on the primary beam size.

    Args:
        visibilities: Visibilities
        weights: Weights
        lmvis: Visibilities sky coordinates
        diameter: Telescope diameter
        l_axis: L axis
        m_axis: M axis
        sky_cell_size: Size of the beam grid cell in the sky (radians)
        reference_scaling_frequency: Frequency at which the primary beam size is computed
        avg_chan: If true all channels are gridded onto a single output channel
        label: label to be used in messages

    Returns:
        beam data gridded (the weight grid produced alongside it is discarded)
    """
    beam_size = _compute_beam_size(diameter, reference_scaling_frequency)

    # Only the jitted gridding itself is timed.
    t_start = time.time()
    gridded = _convolution_gridding_jit(
        visibilities, lmvis, weights, sky_cell_size, l_axis, m_axis, beam_size, avg_chan
    )
    elapsed = time.time() - t_start

    logger.debug(f"{label}: Gaussian convolution gridding took {elapsed:.3} seconds")
    return gridded[0]
@njit(cache=njit_caching, nogil=True)
def _convolution_gridding_jit(
    visibilities, lmvis, weights, sky_cell_size, l_axis, m_axis, beam_size, avg_chan
):
    """
    Actual Gridding of the visibility data using a gaussian convolution with a kernel based on primary beam size,
    using numba jit for fast code

    Args:
        visibilities: Visibilities, indexed as [time, chan, pol]
        lmvis: Visibilities sky coordinates, indexed as [time, l/m, chan]
        weights: Weights, indexed like the visibilities
        sky_cell_size: Size of the beam grid cell in the sky (radians), one value for L and one for M
        l_axis: L axis
        m_axis: M axis
        beam_size: Primary beam size
        avg_chan: If true all channels are accumulated onto a single output channel

    Returns:
        beam data gridded and the grid of accumulated weights, both indexed as [chan, pol, l, m]
    """
    ntime, nchan, npol = visibilities.shape
    # One kernel per sky axis, as the cell sizes in L and M may differ.
    l_kernel = _create_exponential_kernel(beam_size, sky_cell_size[0])
    m_kernel = _create_exponential_kernel(beam_size, sky_cell_size[1])

    if avg_chan:
        grid_shape = (1, npol, l_axis.shape[0], m_axis.shape[0])
    else:
        grid_shape = (nchan, npol, l_axis.shape[0], m_axis.shape[0])
    # This type has to be changed to np.complex128 when debugging with jit off
    beam_grid = np.zeros(grid_shape, dtype=types.complex128)
    weig_grid = np.zeros(grid_shape)
    # o_chan maps each input channel onto its output channel: identity by default, all zeros when averaging.
    o_chan = np.arange(visibilities.shape[1])
    if avg_chan:
        o_chan[:] = 0

    for i_time in range(ntime):
        for i_chan in range(nchan):
            lval, mval = lmvis[i_time, :, i_chan]
            # Restrict the convolution to the pixels within the kernel support around the visibility.
            i_lmin, i_lmax = _compute_kernel_range(l_kernel, lval, l_axis)
            i_mmin, i_mmax = _compute_kernel_range(m_kernel, mval, m_axis)
            for i_pol in range(npol):
                for il in range(i_lmin, i_lmax):
                    l_fac = _convolution_factor(l_kernel, l_axis[il] - lval)
                    for im in range(i_mmin, i_mmax):
                        m_fac = _convolution_factor(m_kernel, m_axis[im] - mval)
                        # The kernel is separable: product of the L and M factors, scaled by the data weight.
                        conv_fact = l_fac * m_fac * weights[i_time, i_chan, i_pol]
                        beam_grid[o_chan[i_chan], i_pol, il, im] += (
                            conv_fact * visibilities[i_time, i_chan, i_pol]
                        )
                        weig_grid[o_chan[i_chan], i_pol, il, im] += conv_fact

    # Turn the weighted sums into weighted averages; NaNs left by the division are replaced by nan_to_num.
    beam_grid /= weig_grid
    beam_grid = np.nan_to_num(beam_grid)
    return beam_grid, weig_grid
@njit(cache=njit_caching, nogil=True)
def _find_nearest(value, array):
    """
    Locate the array element closest to a value.

    Args:
        value: value to test
        array: array in which to look for the nearest element

    Returns:
        Index of the array element with the smallest absolute difference to value
    """
    return np.abs(array - value).argmin()
@njit(cache=njit_caching, nogil=True)
def _create_exponential_kernel(
    beam_size, sky_cell_size, exponent=2, oversampling=100, hpbw_width=4
):
    """
    Creates an exponential kernel to use in convolution

    Args:
        beam_size: Beam size (used to determine kernel's width, radians)
        sky_cell_size: Size of the beam grid cell in the sky (radians)
        exponent: exponent of the kernels exponent
        oversampling: Number of kernel samples per grid pixel
        hpbw_width: Kernel support, in units of beam_size

    Returns:
        A dictionary containing the convolution kernel, with the keys "bias", "u_axis", "kernel",
        "user_support", "user_width", "pix_support", "oversampling", "sky_cell_size" and "kernel_size"
    """
    smoothing = beam_size
    support = hpbw_width * smoothing
    # sig_2_fwhm is imported from the package constants; presumably the sigma to FWHM factor -- confirm.
    width = smoothing / sig_2_fwhm
    # Support and width expressed in grid pixels.
    pix_support = support / np.abs(sky_cell_size)
    pix_width = width / np.abs(sky_cell_size)

    if pix_support < 1.0:
        used_support = 2 * (pix_support + 0.995) + 1
    else:
        used_support = 2 * pix_support + 1

    # The kernel length is rounded up to the next power of two.
    kernel_size = used_support * oversampling + 1
    k_coeff = np.log(kernel_size) / np.log(2)
    k_integer = math.ceil(k_coeff)
    kernel_size = np.power(2, k_integer)

    # bias is the (fractional) kernel index at which u_axis is zero, i.e. the kernel center.
    bias = oversampling / 2 * used_support + 1.0
    # u_axis is in units of grid pixels.
    u_axis = (np.arange(kernel_size) - bias) / oversampling
    kernel = np.exp(-((u_axis / pix_width) ** exponent))

    ker_dict = {
        "bias": bias,
        "u_axis": u_axis,
        "kernel": kernel,
        "user_support": support,
        "user_width": width,
        "pix_support": pix_support,
        "oversampling": oversampling,
        "sky_cell_size": sky_cell_size,
        "kernel_size": kernel_size,
    }
    return ker_dict
@njit(cache=njit_caching, nogil=True)
def _compute_kernel_range(kernel, coor, axis):
    """
    Determine which pixels of an axis fall within the kernel support around a coordinate.

    Args:
        kernel: Convolution kernel
        coor: Coordinate of the visibility
        axis: axis over which convolution is being done

    Returns:
        First pixel and one past the last pixel over which to perform the convolution, clipped to the axis
    """
    i_center = _find_nearest(coor, axis)
    half_span = kernel["pix_support"]
    i_min = max(round(i_center - half_span), 0)
    i_max = min(round(i_center + half_span) + 1, axis.shape[0])
    return i_min, i_max
@njit(cache=njit_caching, nogil=True)
def _convolution_factor(kernel, delta):
    """
    Look up the kernel value that applies to a pixel at a given distance from the visibility.

    Args:
        kernel: convolution kernel
        delta: Distance of pixel to the central pixel (same units as the kernel's sky cell size)

    Returns:
        Kernel value at delta, or 0 if delta falls outside the sampled kernel
    """
    cell_size = np.abs(kernel["sky_cell_size"])
    # Position in the oversampled kernel: distance in pixels, times oversampling, plus the center offset.
    ikern = round(kernel["oversampling"] * (delta / cell_size) + kernel["bias"])
    if 0 <= ikern <= kernel["kernel_size"] - 1:
        return kernel["kernel"][ikern]
    return 0
@njit(cache=njit_caching, nogil=True)
def _compute_kernel_correction(kernel, grid_size):
    """
    Compute the correction for the convolution kernel in the fourier domain, as a cosine sum over the
    kernel samples.

    Args:
        kernel: the convolution kernel
        grid_size: the size of the output grid

    Returns:
        the convolution correction (1D array with grid_size elements)
    """
    ker_val = kernel["kernel"]
    bias = kernel["bias"]
    # NOTE(review): the "+ 1" looks like a 1-based midpoint used with 0-based indices -- confirm.
    m_point = grid_size / 2 + 1
    kw_coeff = np.pi / m_point / kernel["oversampling"]

    correction = np.zeros(grid_size)
    for i_kern in range(ker_val.shape[0]):
        amplitude = ker_val[i_kern]
        # Negligible kernel samples are skipped.
        if not amplitude > 1e-30:
            continue
        kx_coeff = kw_coeff * (i_kern - bias)
        for i_corr in range(grid_size):
            correction[i_corr] += amplitude * np.cos(kx_coeff * (i_corr - m_point))
    return correction
def _compute_beam_size(diameter, frequency):
    """
    Compute the primary beam size for a telescope diameter and a frequency.

    Args:
        diameter: telescope diameter
        frequency: frequency of observation; if an array, list or tuple is given only its first element
            is used

    Returns:
        primary beam size, in radians (the result is converted from arcseconds)
    """
    freq = frequency[0] if isinstance(frequency, (np.ndarray, list, tuple)) else frequency

    # This beam size is anchored at NOEMA beam measurements we might need a more general formula:
    # 41 arcsec at 115 GHz for a 15 m dish, scaled inversely with frequency and diameter.
    size = 41 * (115e9 / freq)
    size = size * np.sqrt(2.0)
    size = size * (15.0 / diameter)
    # Arcseconds to radians.
    size = size * np.pi / 180 / 3600
    return size
@njit(cache=njit_caching, nogil=True)
def _get_normalized_correction(u_corr, v_corr):
    """
    Combine the 1D corrections over U and V into a 2D correction normalized to one reference pixel.

    Args:
        u_corr: Correction over U axis
        v_corr: Correction over V axis

    Returns:
        Normalized gridding correction (2D), indexed as [u, v]
    """
    u_size = u_corr.shape[0]
    v_size = v_corr.shape[0]
    # The reference pixel is at index (size // 2 + 1) along each axis.
    # NOTE(review): this looks like a 1-based midpoint used as a 0-based index -- confirm.
    norm_coeff = u_corr[u_size // 2 + 1] * v_corr[v_size // 2 + 1]

    norm_corr = np.zeros((u_size, v_size), dtype=types.float64)
    for i_u in range(u_size):
        u_val = u_corr[i_u]
        for i_v in range(v_size):
            norm_corr[i_u, i_v] = u_val * v_corr[i_v] / norm_coeff
    return norm_corr
@njit(cache=njit_caching, nogil=True)
def _gridding_correction_jit(aperture, beam_size, sky_cell_size, u_axis, v_axis):
    """
    Divide every aperture plane by the normalized correction of the convolution kernel (numba jitted).

    Args:
        aperture: Aperture image grid, whose first three axes are time, channel and polarization;
            it is modified in place
        beam_size: Primary beam size (radians)
        sky_cell_size: Size of the beam grid cell in the sky (radians)
        u_axis: Aperture U axis
        v_axis: Aperture V axis

    Returns:
        convolution corrected aperture (the same array that was given as input)
    """
    # The kernels are rebuilt with the same parameters used when gridding the beam.
    u_corr = _compute_kernel_correction(
        _create_exponential_kernel(beam_size, sky_cell_size[0]), u_axis.shape[0]
    )
    v_corr = _compute_kernel_correction(
        _create_exponential_kernel(beam_size, sky_cell_size[1]), v_axis.shape[0]
    )
    norm_corr = _get_normalized_correction(u_corr, v_corr)

    for i_time in range(aperture.shape[0]):
        for i_chan in range(aperture.shape[1]):
            for i_pol in range(aperture.shape[2]):
                aperture[i_time, i_chan, i_pol] /= norm_corr
    return aperture