import contextlib
import inspect
import io
import numpy as np
import xarray as xr
import xarray.testing
from PIL import Image, ImageChops
from astrohack.utils.fits import read_fits_no_checks
from astrohack.utils.package_info import get_astrohack_version
[docs]
def are_lists_equal(list_a, list_b):
    """
    Check whether two lists hold the same elements, ignoring order.

    Elements are matched one-to-one, so repeated elements must appear the same
    number of times in both lists (``[1, 1, 2]`` is not equal to ``[1, 2, 2]``).
    Only ``==`` is used, so unhashable elements are supported.

    :param list_a: First list
    :type list_a: list
    :param list_b: Second list
    :type list_b: list
    :return: True if both lists contain the same elements with the same multiplicities
    :rtype: bool
    """
    if len(list_a) != len(list_b):
        return False
    # Consume one matching element of list_b per element of list_a so that
    # duplicates are counted instead of merely checked for presence.
    remaining = list(list_b)
    for item in list_a:
        try:
            remaining.remove(item)
        except ValueError:
            return False
    return True
[docs]
def are_fits_files_close(fits_path1, fits_path2, tol=1e-5):
    """
    Compare the headers and data of two FITS files up to a tolerance.

    Headers are compared with ``are_dicts_close`` (relative tolerance ``tol``),
    ignoring the ``DATE`` and ``ORIGIN`` keys. The data must have the same shape
    and be ``np.allclose`` with absolute tolerance ``tol`` (numpy's default
    relative tolerance still applies); NaNs compare equal.

    :param fits_path1: Path to the first FITS file (read with read_fits_no_checks)
    :param fits_path2: Path to the second FITS file (read with read_fits_no_checks)
    :param tol: Tolerance
    :type tol: float
    :return: True if headers and data are close
    :rtype: bool
    """
    head1, data1 = read_fits_no_checks(fits_path1)
    head2, data2 = read_fits_no_checks(fits_path2)
    if not are_dicts_close(head1, head2, tol=tol, ignored_keys=["DATE", "ORIGIN"]):
        return False
    # np.allclose broadcasts its inputs: without this check data of shape
    # (1, N) would be "close" to data of shape (N,), and non-broadcastable
    # shapes would raise ValueError instead of returning False.
    if np.shape(data1) != np.shape(data2):
        return False
    return bool(np.allclose(data1, data2, equal_nan=True, atol=tol))
[docs]
def are_png_files_close(img_path1, img_path2, tol=1e-5):
    """
    Compare the pixel data of two image files up to a tolerance.

    Both images are converted to RGBA before comparison, so differences in
    colour mode or metadata do not matter. The images are considered close when
    the mean absolute per-channel pixel difference is below ``tol``.

    :param img_path1: Path to the first image
    :param img_path2: Path to the second image
    :param tol: Maximum allowed mean absolute pixel difference
    :type tol: float
    :return: ``(is_close, message)``; message describes the mean difference or the failure
    :rtype: tuple(bool, str)
    """
    try:
        # Pillow decodes pixel data independently of metadata chunks.
        with Image.open(img_path1) as img1, Image.open(img_path2) as img2:
            # Bring both images to a common mode for a reliable comparison.
            rgba1 = img1.convert("RGBA")
            rgba2 = img2.convert("RGBA")
            if rgba1.size != rgba2.size:
                return False, "PNG sizes differ"
            # ImageChops.difference already yields |img1 - img2| for every
            # pixel and channel, so its mean is the mean absolute difference.
            difference = ImageChops.difference(rgba1, rgba2)
            mean_diff = float(np.mean(np.asarray(difference)))
            return mean_diff < tol, f"Mean diff: {mean_diff}"
    except IOError as e:
        print(f"Error opening images: {e}")
        return False, "Failed opening images"
[docs]
def capture_prints_from_function(function, args=None):
    """
    Run ``function`` and return everything it printed to standard output.

    :param function: Callable to execute
    :param args: Optional sequence of positional arguments for ``function``
    :return: Text written to ``sys.stdout`` during the call
    :rtype: str
    """
    positional = () if args is None else args
    buffer = io.StringIO()
    with contextlib.redirect_stdout(buffer):
        function(*positional)
    return buffer.getvalue()
[docs]
def are_txt_files_equal(txt_path1, txt_path2, ignored_key_words=()):
    """
    Compare two text files line by line, ignoring surrounding whitespace.

    The files must have the same number of lines. Two lines that still differ
    after ``str.strip()`` are accepted when some keyword from
    ``ignored_key_words`` occurs in both of them.

    :param txt_path1: Path to the first text file
    :param txt_path2: Path to the second text file
    :param ignored_key_words: Substrings marking lines whose content may differ
    :type ignored_key_words: iterable of str
    :return: True if the files are considered equal
    :rtype: bool
    """
    with open(txt_path1, "r") as first_file:
        first_lines = first_file.read().splitlines()
    with open(txt_path2, "r") as second_file:
        second_lines = second_file.read().splitlines()
    if len(first_lines) != len(second_lines):
        return False
    for left, right in zip(first_lines, second_lines):
        if left.strip() == right.strip():
            continue
        if not any(word in left and word in right for word in ignored_key_words):
            return False
    return True
[docs]
def is_captured_output_equal_to_txt_reference(function, txt_ref, args=None):
    """
    Check whether the text printed by ``function`` exactly matches a reference file.

    The function is executed first; the reference file is read afterwards.

    :param function: Callable whose standard output is captured
    :param txt_ref: Path to the reference text file
    :param args: Optional sequence of positional arguments for ``function``
    :return: True if the captured output equals the file contents
    :rtype: bool
    """
    printed = capture_prints_from_function(function, args)
    with open(txt_ref, "r") as reference_file:
        return reference_file.read() == printed
def _get_ds_metadata(ds):
    """
    Return the metadata attached to ``ds``.

    Objects exposing ``_input_pars`` return that attribute; xarray Datasets and
    DataTrees return their ``attrs``; anything else is expected to have a
    ``root`` whose ``attrs`` are returned.
    """
    if hasattr(ds, "_input_pars"):
        return ds._input_pars
    # Kept as two short-circuiting checks so xr.DataTree is only looked up
    # when ds is not a Dataset.
    if isinstance(ds, xr.Dataset) or isinstance(ds, xr.DataTree):
        return ds.attrs
    return ds.root.attrs
[docs]
def are_dicts_close(dict_a, dict_b, tol=1e-8, ignored_keys=None):
    """
    Compares dictionaries and returns True if data is close up to tolerance.

    Every key of ``dict_a`` that is not in ``ignored_keys`` must be found in
    ``dict_b``. A key may match as-is, as ``str(key)``, or converted to the type
    of the first key of ``dict_b``. The matched values are then compared:

    * two dicts are compared recursively (``ignored_keys`` is not propagated);
    * values of different types never match;
    * ``None`` matches ``None``;
    * numbers, numpy scalars and numpy arrays must have the same shape and be
      ``np.allclose`` with ``rtol=tol`` (NaNs compare equal);
    * lists/tuples whose first element is a number are compared the same way
      after conversion to arrays; other lists/tuples and strings must be equal.

    Keys that exist only in ``dict_b`` are not checked. The first mismatch is
    printed as ``Key: <key> => <mode>`` and ends the comparison.

    :param dict_a: First dictionary; its keys drive the comparison
    :type dict_a: dict
    :param dict_b: Second dictionary
    :type dict_b: dict
    :param tol: Tolerance
    :type tol: float
    :param ignored_keys: Keys to be ignored in comparison
    :type ignored_keys: list, NoneType
    :return: is_close
    :rtype: bool
    :raises TypeError: If a value in ``dict_a`` has a type that cannot be compared
    """
    if ignored_keys is None:
        ignored_keys = []
    b_keys = list(dict_b.keys())
    for key, a_value in dict_a.items():
        if key in ignored_keys:
            continue
        key_in_b, key = _match_dict_key(key, b_keys)
        if key_in_b:
            is_close, mode = _compare_dict_values(a_value, dict_b[key], tol)
        else:
            is_close, mode = False, "missing"
        if not is_close:
            print(f"Key: {key} => {mode}")
            return False
    return True


def _match_dict_key(key, b_keys):
    """Return ``(True, matching_key)`` if ``key`` or a converted copy is in ``b_keys``, else ``(False, key)``."""
    if key in b_keys:
        return True, key
    if str(key) in b_keys:
        return True, str(key)
    if not b_keys:
        # No keys on the other side, so there is no key type to convert to.
        return False, key
    b_type = type(b_keys[0])
    try:
        # Try the key type used by dict_b, e.g. "1" -> 1 for integer keys.
        # Conversions such as int("DATE") raise ValueError, not TypeError.
        converted = b_type(key)
        found = converted in b_keys
    except (TypeError, ValueError):
        return False, key
    return (True, converted) if found else (False, key)


def _compare_dict_values(a_value, b_value, tol):
    """Compare one pair of dictionary values; return ``(is_close, mode)``, mode naming the comparison used."""
    if isinstance(a_value, dict) and isinstance(b_value, dict):
        return are_dicts_close(a_value, b_value, tol=tol, ignored_keys=None), "dict"
    if type(a_value) is not type(b_value):
        return False, "type diff"
    if a_value is None:
        # Types already match, so b_value is None as well.
        return True, "None"
    if isinstance(a_value, (np.ndarray, np.number, np.bool_, float, int)):
        return _are_arrays_close(a_value, b_value, tol), "nparray"
    if isinstance(a_value, str):
        return a_value == b_value, "str"
    if isinstance(a_value, (list, tuple)):
        # Empty sequences fall through to plain equality instead of indexing [0].
        if a_value and isinstance(a_value[0], (np.number, float, int)):
            return (
                _are_arrays_close(np.array(a_value), np.array(b_value), tol),
                "number list",
            )
        return a_value == b_value, "obj list"
    raise TypeError(f"Unrecognized type {type(a_value)}")


def _are_arrays_close(a_value, b_value, tol):
    """``np.allclose`` without broadcasting: inputs of different shapes are never close."""
    if np.shape(a_value) != np.shape(b_value):
        return False
    return bool(np.allclose(a_value, b_value, equal_nan=True, rtol=tol))
[docs]
def are_data_trees_close(tree_a, tree_b, tol=1e-8):
    """
    Compares data trees and returns True if data is close up to tolerance.

    At every node, the attributes are compared with ``are_dicts_close``. The
    node's own dataset is compared with ``xarray.testing.assert_allclose``
    (``rtol=tol``). Unless both nodes are leaves, they must then have the same
    child names, and each pair of children is compared recursively. The first
    failure is printed and ends the comparison.

    :param tree_a: First data tree
    :type tree_a: xarray.DataTree
    :param tree_b: Second data tree
    :type tree_b: xarray.DataTree
    :param tol: Tolerance
    :type tol: float
    :return: is_close
    :rtype: bool
    """
    if not are_dicts_close(tree_a.attrs, tree_b.attrs, tol=tol):
        return False
    try:
        xarray.testing.assert_allclose(tree_a.dataset, tree_b.dataset, rtol=tol)
    except AssertionError:
        print(f"Failed dataset comparison at {tree_a.name}")
        return False
    if tree_a.is_leaf and tree_b.is_leaf:
        return True
    # Recurse over child nodes only. Iterating the tree itself would also yield
    # this node's data variables (DataArrays), which have no ``.dataset`` and
    # are already covered by the dataset comparison above.
    a_children = tree_a.children
    b_children = tree_b.children
    if not are_lists_equal(list(a_children), list(b_children)):
        print(f"Differing key lists at {tree_a.name}")
        return False
    for key, subtree_a in a_children.items():
        if not are_data_trees_close(subtree_a, b_children[key], tol):
            print(f"Failed key = {key}")
            return False
    return True
[docs]
def add_data_folder_to_names_in_class(class_ref):
    """
    Prefix every ``*name`` string attribute of ``class_ref`` with its ``data_dir``.

    An attribute qualifies when its value is a ``str`` and the part of its
    name after the last underscore is ``name`` (e.g. ``file_name`` or plain
    ``name``). Each qualifying value ``v`` is replaced in place by
    ``f"{class_ref.data_dir}/{v}"``.

    :param class_ref: Class (or object with a ``__dict__``) updated in place
    """
    # Snapshot the attributes so the loop never iterates a mapping that is
    # being written to.
    for attr_name, attr_value in list(class_ref.__dict__.items()):
        if not isinstance(attr_value, str):
            continue
        if attr_name.rsplit("_", 1)[-1] != "name":
            continue
        setattr(class_ref, attr_name, f"{class_ref.data_dir}/{attr_value}")
[docs]
def relative_difference(result, expected):
    """
    Symmetric relative difference ``2 * |result - expected| / (|result| + |expected|)``.

    Works element-wise on numpy arrays. When both inputs are zero, the
    difference is defined as 0 instead of the ``0 / 0 = NaN`` the bare formula
    gives, so identical zero values compare as equal. NaN inputs still produce
    NaN.

    :param result: Computed value(s)
    :param expected: Reference value(s)
    :return: Relative difference, shaped like the broadcast inputs
    """
    numerator = 2 * np.abs(result - expected)
    denominator = np.abs(result) + np.abs(expected)
    with np.errstate(divide="ignore", invalid="ignore"):
        ratio = np.true_divide(numerator, denominator)
    # denominator == 0 only when both inputs are 0, i.e. they are identical.
    # The trailing [()] turns the 0-d array from np.where back into a numpy
    # scalar for scalar inputs; for arrays it is a plain view.
    return np.where(denominator == 0, 0.0, ratio)[()]
[docs]
def analyse_summary(mds_obj, exp_file_name, exp_input_pars, exp_ant_keys_list):
    """
    Parse the text printed by ``mds_obj.summary()`` and assert it matches expectations.

    The captured output is parsed by fixed line positions and table markers:

    * line index 2: the second whitespace-separated word is the file name;
    * line indices 6-8: ``key: value`` origin information;
    * line indices 15 to 15 + len(exp_input_pars) - 1: ``| key | value |`` rows
      holding the input parameters (parsed values are strings);
    * the table following an ``Available methods:`` line: method names taken
      from the first column;
    * the table following a ``Data Contents:`` line: antenna names taken from
      the first column.

    A table ends at the first blank line. Lines starting with ``+`` are skipped,
    as are empty cells and the ``Methods`` / ``Antenna`` cells.

    The expected origin info is ``create_origin_dict("test_summary")``. The
    expected method list is every bound method of ``mds_obj`` whose name does
    not start with ``_``.

    :param mds_obj: Object whose ``summary`` method prints the summary under test
    :param exp_file_name: Expected file name reported in the summary
    :param exp_input_pars: Expected input parameters; compared against the
        parsed string values with ``are_dicts_close``
    :type exp_input_pars: dict
    :param exp_ant_keys_list: Expected antenna names, in any order
    :type exp_ant_keys_list: list
    :raises AssertionError: If any parsed item differs from its expectation
    """
    summ_str = capture_prints_from_function(mds_obj.summary)
    n_input_pars = len(exp_input_pars.keys())
    # NOTE(review): these line offsets are hard-wired to the current layout of
    # the summary printout; they must be updated if that layout changes.
    i_start_input = 15
    i_end_input = i_start_input + n_input_pars
    i_start_orig = 6
    i_end_orig = 9
    inside_method_table = False
    inside_antenna_table = False
    # Expected origin info; the creator function name is fixed to "test_summary".
    exp_orig_info = create_origin_dict("test_summary")
    # Expected methods: every public (non-underscore) bound method of mds_obj.
    method_list = inspect.getmembers(mds_obj, predicate=inspect.ismethod)
    exp_method_list = []
    for name, method in method_list:
        if name[0] == "_":
            continue
        else:
            exp_method_list.append(name)
    this_method_list = []
    this_ant_list = []
    this_input_pars = {}
    this_orig_info = {}
    this_filename = None
    for i_line, line in enumerate(summ_str.splitlines()):
        # Separate `if`, not part of the chain below: the file name is the
        # second whitespace-separated word of line 2, so a name containing
        # whitespace would be truncated.
        if i_line == 2:
            this_filename = line.split()[1]
        if i_start_orig <= i_line < i_end_orig:
            # Origin rows look like "key: value"; only the text between the
            # first and second ":" is kept as the value.
            wrds = line.split(":")
            key = wrds[0].strip()
            value = wrds[1].strip()
            this_orig_info[key] = value
        elif i_start_input <= i_line < i_end_input:
            # Input-parameter rows look like "| key | value |"; values stay strings.
            wrds = line.split("|")
            key = wrds[1].strip()
            value = wrds[2].strip()
            this_input_pars[key] = value
        elif line.strip() == "Available methods:":
            inside_method_table = True
        elif inside_method_table:
            # A blank line ends the table; "+" lines are skipped, as are empty
            # cells and the "Methods" cell (presumably the table header).
            if line.strip() == "":
                inside_method_table = False
            elif line[0] == "+":
                pass
            else:
                method_wrd = line.split("|")[1].strip()
                if method_wrd == "" or method_wrd == "Methods":
                    pass
                else:
                    this_method_list.append(method_wrd)
        elif line.strip() == "Data Contents:":
            inside_antenna_table = True
        elif inside_antenna_table:
            # Same table handling as above, skipping the "Antenna" cell.
            if line.strip() == "":
                inside_antenna_table = False
            elif line[0] == "+":
                pass
            else:
                ant_wrd = line.split("|")[1].strip()
                if ant_wrd == "" or ant_wrd == "Antenna":
                    pass
                else:
                    this_ant_list.append(ant_wrd)
        else:
            pass
    assert (
        this_filename == exp_file_name
    ), "File name in Summary should be equal to the expected one"
    assert are_dicts_close(
        this_input_pars, exp_input_pars
    ), "Input parameter dictionaries should identical"
    assert are_dicts_close(
        this_orig_info, exp_orig_info
    ), "Origin info dictionaries should be identical"
    assert are_lists_equal(
        this_ant_list, exp_ant_keys_list
    ), "Antenna list should be equal to the expected one"
    assert are_lists_equal(
        this_method_list, exp_method_list
    ), "Method list should be equal to the expected one"
[docs]
def create_origin_dict(caller):
    """
    Build the origin-information dictionary for ``caller``.

    :param caller: Name recorded under ``creator_function``
    :return: Dictionary with ``origin`` ("astrohack"), ``version`` (from
        ``get_astrohack_version()``) and ``creator_function`` keys
    :rtype: dict
    """
    return dict(
        origin="astrohack",
        version=get_astrohack_version(),
        creator_function=caller,
    )