|
1 |
| -import unittest |
2 |
| - |
3 | 1 | import numpy as np
|
4 | 2 | import numpy.testing as npt
|
| 3 | +import pytest |
5 | 4 | from pandas import testing as pdt
|
6 | 5 |
|
7 | 6 | from gaiaxpy import calibrate
|
|
12 | 11 | from gaiaxpy.input_reader.required_columns import MANDATORY_INPUT_COLS, CORR_INPUT_COLUMNS
|
13 | 12 | from gaiaxpy.spectrum.absolute_sampled_spectrum import AbsoluteSampledSpectrum
|
14 | 13 | from gaiaxpy.spectrum.sampled_basis_functions import SampledBasisFunctions
|
15 |
| -from tests.files.paths import mean_spectrum_csv_file, mean_spectrum_avro_file, mean_spectrum_fits_file, \ |
16 |
| - mean_spectrum_xml_file, mean_spectrum_xml_plain_file, mean_spectrum_ecsv_file |
17 |
| -from tests.test_calibrator.calibrator_solutions import solution_default_df, solution_custom_df, \ |
18 |
| - solution_v211w_default_df, solution_v211w_custom_df, sol_custom_sampling_array, sol_v211w_default_sampling_array, \ |
19 |
| - sol_default_sampling_array |
| 14 | +from tests.files.paths import (mean_spectrum_csv_file, mean_spectrum_avro_file, mean_spectrum_fits_file, |
| 15 | + mean_spectrum_xml_file, mean_spectrum_xml_plain_file, mean_spectrum_ecsv_file) |
| 16 | +from tests.test_calibrator.calibrator_solutions import (solution_default_df, solution_custom_df, |
| 17 | + solution_v211w_default_df, solution_v211w_custom_df, |
| 18 | + sol_custom_sampling_array, sol_v211w_default_sampling_array, |
| 19 | + sol_default_sampling_array) |
20 | 20 | from tests.utils.utils import is_instance_err_message, npt_array_err_message
|
21 | 21 |
|
22 | 22 | # Load variables
|
|
32 | 32 | mean_spectrum_xml_file, mean_spectrum_xml_plain_file]
|
33 | 33 |
|
34 | 34 |
|
35 |
| -def generate_single_spectrum(mean_spectrum_path): |
36 |
| - # Read mean Spectrum |
37 |
| - parser = InternalContinuousParser(MANDATORY_INPUT_COLS['calibrate'] + CORR_INPUT_COLUMNS) |
38 |
| - parsed_spectrum_file, extension = parser.parse_file(mean_spectrum_path) |
39 |
| - # Create sampled basis functions |
40 |
| - sampled_basis_func = {band: SampledBasisFunctions.from_design_matrix(xp_sampling_grid, xp_design_matrices[band]) |
41 |
| - for band in BANDS} |
42 |
| - first_row = parsed_spectrum_file.iloc[0] |
43 |
| - return _create_spectrum(first_row, truncation=False, design_matrix=sampled_basis_func, merge=xp_merge) |
44 |
| - |
45 |
| - |
46 |
| -class TestCalibrator(unittest.TestCase): |
47 |
| - |
48 |
| - def test_create_spectrum(self): |
49 |
| - for file in cal_input_files: |
50 |
| - spectrum = generate_single_spectrum(file) |
51 |
| - instance = AbsoluteSampledSpectrum |
52 |
| - self.assertIsInstance(spectrum, instance, msg=is_instance_err_message(file, instance)) |
53 |
| - |
54 |
| - def test_calibrate_both_bands_default_calibration_model(self): |
55 |
| - # Default sampling and default calibration sampling |
56 |
| - for file in cal_input_files: |
57 |
| - spectra_df_csv, positions = calibrate(file, save_file=False) |
58 |
| - npt.assert_array_equal(positions, sol_default_sampling_array, err_msg=npt_array_err_message(file)) |
59 |
| - # Assert_frame_equal doesn't seem to have a parameter to print an error message with details. |
60 |
| - pdt.assert_frame_equal(spectra_df_csv, solution_default_df, atol=_atol, rtol=_rtol) |
61 |
| - |
62 |
| - def test_custom_sampling_default_calibration_model(self): |
63 |
| - for file in cal_input_files: |
64 |
| - spectra_df_custom_sampling, positions = calibrate(file, sampling=np.arange(350, 1050, 200), save_file=False) |
65 |
| - npt.assert_array_equal(positions, sol_custom_sampling_array, err_msg=npt_array_err_message(file)) |
66 |
| - pdt.assert_frame_equal(spectra_df_custom_sampling, solution_custom_df, atol=_atol, rtol=_rtol) |
67 |
| - |
68 |
| - def test_calibrate_both_bands_v211w_model(self): |
69 |
| - for file in cal_input_files: |
70 |
| - spectra_df_csv, positions = _calibrate(file, save_file=False, bp_model=bp_model) |
71 |
| - npt.assert_array_equal(positions, sol_v211w_default_sampling_array, err_msg=npt_array_err_message(file)) |
72 |
| - pdt.assert_frame_equal(spectra_df_csv, solution_v211w_default_df, atol=_atol, rtol=_rtol) |
73 |
| - |
74 |
| - def test_custom_sampling_v211w_model(self): |
75 |
| - for file in cal_input_files: |
76 |
| - spectra_df_custom_sampling, positions = _calibrate(file, sampling=np.arange(350, 1050, 200), |
77 |
| - save_file=False, bp_model=bp_model) |
78 |
| - npt.assert_array_equal(positions, sol_custom_sampling_array, err_msg=npt_array_err_message(file)) |
79 |
| - pdt.assert_frame_equal(spectra_df_custom_sampling, solution_v211w_custom_df, atol=_atol, rtol=_rtol) |
80 |
| - |
81 |
| - |
82 |
| -class TestCalibratorSamplingRange(unittest.TestCase): |
83 |
| - |
84 |
| - def test_sampling_wrong_array(self): |
85 |
| - with self.assertRaises(ValueError): |
86 |
| - calibrate(mean_spectrum_avro_file, sampling=np.linspace(800, 350, 600), save_file=False) |
87 |
| - |
88 |
| - def test_sampling_low(self): |
89 |
| - with self.assertRaises(ValueError): |
90 |
| - calibrate(mean_spectrum_avro_file, sampling=np.linspace(300, 850, 300), save_file=False) |
91 |
| - |
92 |
| - def test_sampling_high(self): |
93 |
| - with self.assertRaises(ValueError): |
94 |
| - calibrate(mean_spectrum_avro_file, sampling=np.linspace(350, 1500, 200), save_file=False) |
95 |
| - |
96 |
| - def test_sampling_both_wrong(self): |
97 |
| - with self.assertRaises(ValueError): |
98 |
| - calibrate(mean_spectrum_avro_file, sampling=np.linspace(200, 2000, 100), save_file=False) |
| 35 | +@pytest.mark.parametrize('input_file', cal_input_files) |
| 36 | +def test_create_spectrum(input_file): |
| 37 | + def generate_single_spectrum(mean_spectrum_path): |
| 38 | + # Read mean Spectrum |
| 39 | + parser = InternalContinuousParser(MANDATORY_INPUT_COLS['calibrate'] + CORR_INPUT_COLUMNS) |
| 40 | + parsed_spectrum_file, extension = parser.parse_file(mean_spectrum_path) |
| 41 | + # Create sampled basis functions |
| 42 | + sampled_basis_func = {band: SampledBasisFunctions.from_design_matrix(xp_sampling_grid, xp_design_matrices[band]) |
| 43 | + for band in BANDS} |
| 44 | + return _create_spectrum(parsed_spectrum_file.iloc[0], truncation=False, design_matrix=sampled_basis_func, |
| 45 | + merge=xp_merge) |
| 46 | + |
| 47 | + instance = AbsoluteSampledSpectrum |
| 48 | + spectrum = generate_single_spectrum(input_file) |
| 49 | + assert isinstance(spectrum, instance), is_instance_err_message(input_file, instance) |
| 50 | + |
| 51 | + |
| 52 | +@pytest.mark.parametrize('input_file', cal_input_files) |
| 53 | +def test_calibrate_both_bands_default_calibration_model(input_file, request): |
| 54 | + # Default sampling and default calibration sampling |
| 55 | + spectra_df_csv, positions = calibrate(input_file, save_file=False) |
| 56 | + npt.assert_array_equal(positions, sol_default_sampling_array, err_msg=npt_array_err_message(input_file)) |
| 57 | + try: |
| 58 | + pdt.assert_frame_equal(spectra_df_csv, solution_default_df, atol=_atol, rtol=_rtol) |
| 59 | + except AssertionError as e: |
| 60 | + print(f'{request.node.name} failed for file: {input_file}.') |
| 61 | + raise e |
| 62 | + |
| 63 | + |
| 64 | +@pytest.mark.parametrize('input_file', cal_input_files) |
| 65 | +def test_custom_sampling_default_calibration_model(input_file): |
| 66 | + spectra_df_custom_sampling, positions = calibrate(input_file, sampling=np.arange(350, 1050, 200), save_file=False) |
| 67 | + npt.assert_array_equal(positions, sol_custom_sampling_array, err_msg=npt_array_err_message(input_file)) |
| 68 | + pdt.assert_frame_equal(spectra_df_custom_sampling, solution_custom_df, atol=_atol, rtol=_rtol) |
| 69 | + |
| 70 | + |
| 71 | +@pytest.mark.parametrize('input_file', cal_input_files) |
| 72 | +def test_calibrate_both_bands_v211w_model(input_file): |
| 73 | + spectra_df_csv, positions = _calibrate(input_file, save_file=False, bp_model=bp_model) |
| 74 | + npt.assert_array_equal(positions, sol_v211w_default_sampling_array, err_msg=npt_array_err_message(input_file)) |
| 75 | + pdt.assert_frame_equal(spectra_df_csv, solution_v211w_default_df, atol=_atol, rtol=_rtol) |
| 76 | + |
| 77 | + |
| 78 | +@pytest.mark.parametrize('input_file', cal_input_files) |
| 79 | +def test_custom_sampling_v211w_model(input_file): |
| 80 | + spectra_df_custom_sampling, positions = _calibrate(input_file, sampling=np.arange(350, 1050, 200), save_file=False, |
| 81 | + bp_model=bp_model) |
| 82 | + npt.assert_array_equal(positions, sol_custom_sampling_array, err_msg=npt_array_err_message(input_file)) |
| 83 | + pdt.assert_frame_equal(spectra_df_custom_sampling, solution_v211w_custom_df, atol=_atol, rtol=_rtol) |
| 84 | + |
| 85 | + |
| 86 | +def test_sampling_wrong_array(): |
| 87 | + with pytest.raises(ValueError): |
| 88 | + calibrate(mean_spectrum_avro_file, sampling=np.linspace(800, 350, 600), save_file=False) |
| 89 | + |
| 90 | + |
| 91 | +def test_sampling_low(): |
| 92 | + with pytest.raises(ValueError): |
| 93 | + calibrate(mean_spectrum_avro_file, sampling=np.linspace(300, 850, 300), save_file=False) |
| 94 | + |
| 95 | + |
| 96 | +def test_sampling_high(): |
| 97 | + with pytest.raises(ValueError): |
| 98 | + calibrate(mean_spectrum_avro_file, sampling=np.linspace(350, 1500, 200), save_file=False) |
| 99 | + |
| 100 | + |
| 101 | +def test_sampling_both_wrong(): |
| 102 | + with pytest.raises(ValueError): |
| 103 | + calibrate(mean_spectrum_avro_file, sampling=np.linspace(200, 2000, 100), save_file=False) |
0 commit comments