Skip to content

Commit 5cc6d9d

Browse files
committed
Replace unittest by pytest
1 parent bacdeb9 commit 5cc6d9d

8 files changed

+284
-285
lines changed

tests/test_calibrator/calibrator_solutions.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -30,8 +30,10 @@
3030
=====================
3131
"""
3232
# Default model
33-
solution_default_df = pd.read_csv(cal_sol_default_path, float_precision='round_trip', converters=calibrator_sol_converters)
34-
solution_custom_df = pd.read_csv(cal_sol_custom_path, float_precision='round_trip', converters=calibrator_sol_converters)
33+
solution_default_df = pd.read_csv(cal_sol_default_path, float_precision='round_trip',
34+
converters=calibrator_sol_converters)
35+
solution_custom_df = pd.read_csv(cal_sol_custom_path, float_precision='round_trip',
36+
converters=calibrator_sol_converters)
3537
# v211w model
3638
solution_v211w_default_df = pd.read_csv(cal_sol_v211w_default_path, float_precision='round_trip',
3739
converters=calibrator_sol_converters)
Lines changed: 76 additions & 71 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,6 @@
1-
import unittest
2-
31
import numpy as np
42
import numpy.testing as npt
3+
import pytest
54
from pandas import testing as pdt
65

76
from gaiaxpy import calibrate
@@ -12,11 +11,12 @@
1211
from gaiaxpy.input_reader.required_columns import MANDATORY_INPUT_COLS, CORR_INPUT_COLUMNS
1312
from gaiaxpy.spectrum.absolute_sampled_spectrum import AbsoluteSampledSpectrum
1413
from gaiaxpy.spectrum.sampled_basis_functions import SampledBasisFunctions
15-
from tests.files.paths import mean_spectrum_csv_file, mean_spectrum_avro_file, mean_spectrum_fits_file, \
16-
mean_spectrum_xml_file, mean_spectrum_xml_plain_file, mean_spectrum_ecsv_file
17-
from tests.test_calibrator.calibrator_solutions import solution_default_df, solution_custom_df, \
18-
solution_v211w_default_df, solution_v211w_custom_df, sol_custom_sampling_array, sol_v211w_default_sampling_array, \
19-
sol_default_sampling_array
14+
from tests.files.paths import (mean_spectrum_csv_file, mean_spectrum_avro_file, mean_spectrum_fits_file,
15+
mean_spectrum_xml_file, mean_spectrum_xml_plain_file, mean_spectrum_ecsv_file)
16+
from tests.test_calibrator.calibrator_solutions import (solution_default_df, solution_custom_df,
17+
solution_v211w_default_df, solution_v211w_custom_df,
18+
sol_custom_sampling_array, sol_v211w_default_sampling_array,
19+
sol_default_sampling_array)
2020
from tests.utils.utils import is_instance_err_message, npt_array_err_message
2121

2222
# Load variables
@@ -32,67 +32,72 @@
3232
mean_spectrum_xml_file, mean_spectrum_xml_plain_file]
3333

3434

35-
def generate_single_spectrum(mean_spectrum_path):
36-
# Read mean Spectrum
37-
parser = InternalContinuousParser(MANDATORY_INPUT_COLS['calibrate'] + CORR_INPUT_COLUMNS)
38-
parsed_spectrum_file, extension = parser.parse_file(mean_spectrum_path)
39-
# Create sampled basis functions
40-
sampled_basis_func = {band: SampledBasisFunctions.from_design_matrix(xp_sampling_grid, xp_design_matrices[band])
41-
for band in BANDS}
42-
first_row = parsed_spectrum_file.iloc[0]
43-
return _create_spectrum(first_row, truncation=False, design_matrix=sampled_basis_func, merge=xp_merge)
44-
45-
46-
class TestCalibrator(unittest.TestCase):
47-
48-
def test_create_spectrum(self):
49-
for file in cal_input_files:
50-
spectrum = generate_single_spectrum(file)
51-
instance = AbsoluteSampledSpectrum
52-
self.assertIsInstance(spectrum, instance, msg=is_instance_err_message(file, instance))
53-
54-
def test_calibrate_both_bands_default_calibration_model(self):
55-
# Default sampling and default calibration sampling
56-
for file in cal_input_files:
57-
spectra_df_csv, positions = calibrate(file, save_file=False)
58-
npt.assert_array_equal(positions, sol_default_sampling_array, err_msg=npt_array_err_message(file))
59-
# Assert_frame_equal doesn't seem to have a parameter to print an error message with details.
60-
pdt.assert_frame_equal(spectra_df_csv, solution_default_df, atol=_atol, rtol=_rtol)
61-
62-
def test_custom_sampling_default_calibration_model(self):
63-
for file in cal_input_files:
64-
spectra_df_custom_sampling, positions = calibrate(file, sampling=np.arange(350, 1050, 200), save_file=False)
65-
npt.assert_array_equal(positions, sol_custom_sampling_array, err_msg=npt_array_err_message(file))
66-
pdt.assert_frame_equal(spectra_df_custom_sampling, solution_custom_df, atol=_atol, rtol=_rtol)
67-
68-
def test_calibrate_both_bands_v211w_model(self):
69-
for file in cal_input_files:
70-
spectra_df_csv, positions = _calibrate(file, save_file=False, bp_model=bp_model)
71-
npt.assert_array_equal(positions, sol_v211w_default_sampling_array, err_msg=npt_array_err_message(file))
72-
pdt.assert_frame_equal(spectra_df_csv, solution_v211w_default_df, atol=_atol, rtol=_rtol)
73-
74-
def test_custom_sampling_v211w_model(self):
75-
for file in cal_input_files:
76-
spectra_df_custom_sampling, positions = _calibrate(file, sampling=np.arange(350, 1050, 200),
77-
save_file=False, bp_model=bp_model)
78-
npt.assert_array_equal(positions, sol_custom_sampling_array, err_msg=npt_array_err_message(file))
79-
pdt.assert_frame_equal(spectra_df_custom_sampling, solution_v211w_custom_df, atol=_atol, rtol=_rtol)
80-
81-
82-
class TestCalibratorSamplingRange(unittest.TestCase):
83-
84-
def test_sampling_wrong_array(self):
85-
with self.assertRaises(ValueError):
86-
calibrate(mean_spectrum_avro_file, sampling=np.linspace(800, 350, 600), save_file=False)
87-
88-
def test_sampling_low(self):
89-
with self.assertRaises(ValueError):
90-
calibrate(mean_spectrum_avro_file, sampling=np.linspace(300, 850, 300), save_file=False)
91-
92-
def test_sampling_high(self):
93-
with self.assertRaises(ValueError):
94-
calibrate(mean_spectrum_avro_file, sampling=np.linspace(350, 1500, 200), save_file=False)
95-
96-
def test_sampling_both_wrong(self):
97-
with self.assertRaises(ValueError):
98-
calibrate(mean_spectrum_avro_file, sampling=np.linspace(200, 2000, 100), save_file=False)
35+
@pytest.mark.parametrize('input_file', cal_input_files)
36+
def test_create_spectrum(input_file):
37+
def generate_single_spectrum(mean_spectrum_path):
38+
# Read mean Spectrum
39+
parser = InternalContinuousParser(MANDATORY_INPUT_COLS['calibrate'] + CORR_INPUT_COLUMNS)
40+
parsed_spectrum_file, extension = parser.parse_file(mean_spectrum_path)
41+
# Create sampled basis functions
42+
sampled_basis_func = {band: SampledBasisFunctions.from_design_matrix(xp_sampling_grid, xp_design_matrices[band])
43+
for band in BANDS}
44+
return _create_spectrum(parsed_spectrum_file.iloc[0], truncation=False, design_matrix=sampled_basis_func,
45+
merge=xp_merge)
46+
47+
instance = AbsoluteSampledSpectrum
48+
spectrum = generate_single_spectrum(input_file)
49+
assert isinstance(spectrum, instance), is_instance_err_message(input_file, instance)
50+
51+
52+
@pytest.mark.parametrize('input_file', cal_input_files)
53+
def test_calibrate_both_bands_default_calibration_model(input_file, request):
54+
# Default sampling and default calibration sampling
55+
spectra_df_csv, positions = calibrate(input_file, save_file=False)
56+
npt.assert_array_equal(positions, sol_default_sampling_array, err_msg=npt_array_err_message(input_file))
57+
try:
58+
pdt.assert_frame_equal(spectra_df_csv, solution_default_df, atol=_atol, rtol=_rtol)
59+
except AssertionError as e:
60+
print(f'{request.node.name} failed for file: {input_file}.')
61+
raise e
62+
63+
64+
@pytest.mark.parametrize('input_file', cal_input_files)
65+
def test_custom_sampling_default_calibration_model(input_file):
66+
spectra_df_custom_sampling, positions = calibrate(input_file, sampling=np.arange(350, 1050, 200), save_file=False)
67+
npt.assert_array_equal(positions, sol_custom_sampling_array, err_msg=npt_array_err_message(input_file))
68+
pdt.assert_frame_equal(spectra_df_custom_sampling, solution_custom_df, atol=_atol, rtol=_rtol)
69+
70+
71+
@pytest.mark.parametrize('input_file', cal_input_files)
72+
def test_calibrate_both_bands_v211w_model(input_file):
73+
spectra_df_csv, positions = _calibrate(input_file, save_file=False, bp_model=bp_model)
74+
npt.assert_array_equal(positions, sol_v211w_default_sampling_array, err_msg=npt_array_err_message(input_file))
75+
pdt.assert_frame_equal(spectra_df_csv, solution_v211w_default_df, atol=_atol, rtol=_rtol)
76+
77+
78+
@pytest.mark.parametrize('input_file', cal_input_files)
79+
def test_custom_sampling_v211w_model(input_file):
80+
spectra_df_custom_sampling, positions = _calibrate(input_file, sampling=np.arange(350, 1050, 200), save_file=False,
81+
bp_model=bp_model)
82+
npt.assert_array_equal(positions, sol_custom_sampling_array, err_msg=npt_array_err_message(input_file))
83+
pdt.assert_frame_equal(spectra_df_custom_sampling, solution_v211w_custom_df, atol=_atol, rtol=_rtol)
84+
85+
86+
def test_sampling_wrong_array():
87+
with pytest.raises(ValueError):
88+
calibrate(mean_spectrum_avro_file, sampling=np.linspace(800, 350, 600), save_file=False)
89+
90+
91+
def test_sampling_low():
92+
with pytest.raises(ValueError):
93+
calibrate(mean_spectrum_avro_file, sampling=np.linspace(300, 850, 300), save_file=False)
94+
95+
96+
def test_sampling_high():
97+
with pytest.raises(ValueError):
98+
calibrate(mean_spectrum_avro_file, sampling=np.linspace(350, 1500, 200), save_file=False)
99+
100+
101+
def test_sampling_both_wrong():
102+
with pytest.raises(ValueError):
103+
calibrate(mean_spectrum_avro_file, sampling=np.linspace(200, 2000, 100), save_file=False)
Lines changed: 49 additions & 48 deletions
Original file line numberDiff line numberDiff line change
@@ -1,58 +1,59 @@
1-
import unittest
2-
31
import numpy.testing as npt
42
import pandas.testing as pdt
3+
import pytest
54

65
from gaiaxpy import calibrate
7-
from tests.test_calibrator.calibrator_solutions import solution_default_df, sol_with_missing_sampling_array, \
8-
sol_default_sampling_array, with_missing_solution_df, missing_solution_df
6+
from tests.test_calibrator.calibrator_solutions import (solution_default_df, sol_with_missing_sampling_array,
7+
sol_default_sampling_array, with_missing_solution_df,
8+
missing_solution_df)
99
from tests.utils.utils import missing_bp_source_id
1010

1111
_rtol = 1e-10
1212
_atol = 1e-10
1313

1414

15-
class TestCalibratorSingleElement(unittest.TestCase):
16-
17-
def test_single_element_query(self):
18-
query = "SELECT * FROM gaiadr3.gaia_source WHERE source_id='5853498713190525696'"
19-
output_df, sampling = calibrate(query, save_file=False)
20-
source_data_output = output_df[output_df['source_id'] == 5853498713190525696]
21-
source_data_solution = solution_default_df[solution_default_df['source_id'] == 5853498713190525696]
22-
pdt.assert_frame_equal(source_data_output, source_data_solution, atol=_atol, rtol=_rtol)
23-
npt.assert_array_equal(sampling, sol_default_sampling_array)
24-
25-
26-
class TestCalibratorMissingBPQueryInput(unittest.TestCase):
27-
28-
def test_missing_bp_query(self):
29-
query = f"SELECT * FROM gaiadr3.gaia_source WHERE source_id IN ('5853498713190525696', " \
30-
f"{missing_bp_source_id}, '5762406957886626816')"
31-
output_df, sampling = calibrate(query, save_file=False)
32-
sorted_output_df = output_df.sort_values('source_id', ignore_index=True)
33-
sorted_solution_df = with_missing_solution_df.sort_values('source_id', ignore_index=True)
34-
pdt.assert_frame_equal(sorted_output_df, sorted_solution_df, atol=_atol, rtol=_rtol)
35-
npt.assert_array_equal(sampling, sol_with_missing_sampling_array)
36-
37-
def test_missing_bp_query_isolated(self):
38-
query = f"SELECT * FROM gaiadr3.gaia_source WHERE source_id IN ({missing_bp_source_id})"
39-
output_df, sampling = calibrate(query, save_file=False)
40-
pdt.assert_frame_equal(output_df, missing_solution_df, atol=_atol, rtol=_rtol)
41-
npt.assert_array_equal(sampling, sol_with_missing_sampling_array)
42-
43-
44-
class TestCalibratorMissingBPListInput(unittest.TestCase):
45-
46-
def test_missing_bp_list(self):
47-
src_list = ['5853498713190525696', str(missing_bp_source_id), '5762406957886626816']
48-
output_df, sampling = calibrate(src_list, save_file=False)
49-
sorted_output_df = output_df.sort_values('source_id', ignore_index=True)
50-
sorted_solution_df = with_missing_solution_df.sort_values('source_id', ignore_index=True)
51-
pdt.assert_frame_equal(sorted_output_df, sorted_solution_df, check_dtype=False, atol=_atol, rtol=_rtol)
52-
npt.assert_array_equal(sampling, sol_with_missing_sampling_array)
53-
54-
def test_missing_bp_isolated(self):
55-
src_list = [missing_bp_source_id]
56-
output_df, sampling = calibrate(src_list, save_file=False)
57-
pdt.assert_frame_equal(output_df, missing_solution_df, check_dtype=False, atol=_atol, rtol=_rtol)
58-
npt.assert_array_equal(sampling, sol_with_missing_sampling_array)
15+
@pytest.mark.archive
16+
def test_single_element_query():
17+
query = "SELECT * FROM gaiadr3.gaia_source WHERE source_id='5853498713190525696'"
18+
output_df, sampling = calibrate(query, save_file=False)
19+
source_data_output = output_df[output_df['source_id'] == 5853498713190525696]
20+
source_data_solution = solution_default_df[solution_default_df['source_id'] == 5853498713190525696]
21+
pdt.assert_frame_equal(source_data_output, source_data_solution, atol=_atol, rtol=_rtol)
22+
npt.assert_array_equal(sampling, sol_default_sampling_array)
23+
24+
25+
@pytest.mark.archive
26+
def test_missing_bp_query():
27+
query = f"SELECT * FROM gaiadr3.gaia_source WHERE source_id IN ('5853498713190525696', " \
28+
f"{missing_bp_source_id}, '5762406957886626816')"
29+
output_df, sampling = calibrate(query, save_file=False)
30+
sorted_output_df = output_df.sort_values('source_id', ignore_index=True)
31+
sorted_solution_df = with_missing_solution_df.sort_values('source_id', ignore_index=True)
32+
pdt.assert_frame_equal(sorted_output_df, sorted_solution_df, atol=_atol, rtol=_rtol)
33+
npt.assert_array_equal(sampling, sol_with_missing_sampling_array)
34+
35+
36+
@pytest.mark.archive
37+
def test_missing_bp_query_isolated():
38+
query = f"SELECT * FROM gaiadr3.gaia_source WHERE source_id IN ({missing_bp_source_id})"
39+
output_df, sampling = calibrate(query, save_file=False)
40+
pdt.assert_frame_equal(output_df, missing_solution_df, atol=_atol, rtol=_rtol)
41+
npt.assert_array_equal(sampling, sol_with_missing_sampling_array)
42+
43+
44+
@pytest.mark.archive
45+
def test_missing_bp_list():
46+
src_list = ['5853498713190525696', str(missing_bp_source_id), '5762406957886626816']
47+
output_df, sampling = calibrate(src_list, save_file=False)
48+
sorted_output_df = output_df.sort_values('source_id', ignore_index=True)
49+
sorted_solution_df = with_missing_solution_df.sort_values('source_id', ignore_index=True)
50+
pdt.assert_frame_equal(sorted_output_df, sorted_solution_df, check_dtype=False, atol=_atol, rtol=_rtol)
51+
npt.assert_array_equal(sampling, sol_with_missing_sampling_array)
52+
53+
54+
@pytest.mark.archive
55+
def test_missing_bp_isolated():
56+
src_list = [missing_bp_source_id]
57+
output_df, sampling = calibrate(src_list, save_file=False)
58+
pdt.assert_frame_equal(output_df, missing_solution_df, check_dtype=False, atol=_atol, rtol=_rtol)
59+
npt.assert_array_equal(sampling, sol_with_missing_sampling_array)

0 commit comments

Comments
 (0)