Skip to content

Commit 6a0de80

Browse files
authored
Add write_dir argument to csv_to_wfdb. Fixes #67. (#492)
As discussed in #490, https://github.com/MIT-LCP/wfdb-python/blob/34b989e08435c1a82d31bdd2800c4c14147e3e93/wfdb/io/convert/csv.py#L10 currently "strips the path from the input .csv, then writes the output to .dat and .hea". It's inconvenient not to be able to specify the output directory. This pull request adds a new `output_dir` argument to the `csv_to_wfdb` function. By default `output_dir` is set to None, which will maintain backwards compatibility. Setting `output_dir` to a directory will mean that output files are saved to this directory. I have set this to a WIP, because I haven't tested the new behaviour (other than running `pytest`). @jshaffer94247, if you have an opportunity to test the fix, I'd appreciate your feedback.
2 parents edb5f12 + f3d633d commit 6a0de80

File tree

2 files changed

+84
-11
lines changed

2 files changed

+84
-11
lines changed

tests/io/test_convert.py

Lines changed: 72 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,14 +1,22 @@
1+
import os
2+
import shutil
3+
import unittest
4+
15
import numpy as np
26

37
from wfdb.io.record import rdrecord
48
from wfdb.io.convert.edf import read_edf
9+
from wfdb.io.convert.csv import csv_to_wfdb
10+
511

12+
class TestEdfToWfdb:
13+
"""
14+
Tests for the io.convert.edf module.
15+
"""
616

7-
class TestConvert:
817
def test_edf_uniform(self):
918
"""
1019
EDF format conversion to MIT for uniform sample rates.
11-
1220
"""
1321
# Uniform sample rates
1422
record_MIT = rdrecord("sample-data/n16").__dict__
@@ -60,7 +68,6 @@ def test_edf_uniform(self):
6068
def test_edf_non_uniform(self):
6169
"""
6270
EDF format conversion to MIT for non-uniform sample rates.
63-
6471
"""
6572
# Non-uniform sample rates
6673
record_MIT = rdrecord("sample-data/wave_4").__dict__
@@ -108,3 +115,65 @@ def test_edf_non_uniform(self):
108115

109116
target_results = len(fields) * [True]
110117
assert np.array_equal(test_results, target_results)
118+
119+
120+
class TestCsvToWfdb(unittest.TestCase):
121+
"""
122+
Tests for the io.convert.csv module.
123+
"""
124+
125+
def setUp(self):
126+
"""
127+
Create a temporary directory containing data for testing.
128+
129+
Load 100.dat file for comparison to 100.csv file.
130+
"""
131+
self.test_dir = "test_output"
132+
os.makedirs(self.test_dir, exist_ok=True)
133+
134+
self.record_100_csv = "sample-data/100.csv"
135+
self.record_100_dat = rdrecord("sample-data/100", physical=True)
136+
137+
def tearDown(self):
138+
"""
139+
Remove the temporary directory after the test.
140+
"""
141+
if os.path.exists(self.test_dir):
142+
shutil.rmtree(self.test_dir)
143+
144+
def test_write_dir(self):
145+
"""
146+
Call the function with the write_dir argument.
147+
"""
148+
csv_to_wfdb(
149+
file_name=self.record_100_csv,
150+
fs=360,
151+
units="mV",
152+
write_dir=self.test_dir,
153+
)
154+
155+
# Check if the output files are created in the specified directory
156+
base_name = os.path.splitext(os.path.basename(self.record_100_csv))[0]
157+
expected_dat_file = os.path.join(self.test_dir, f"{base_name}.dat")
158+
expected_hea_file = os.path.join(self.test_dir, f"{base_name}.hea")
159+
160+
self.assertTrue(os.path.exists(expected_dat_file))
161+
self.assertTrue(os.path.exists(expected_hea_file))
162+
163+
# Check that newly written file matches the 100.dat file
164+
record_write = rdrecord(os.path.join(self.test_dir, base_name))
165+
166+
self.assertEqual(record_write.fs, 360)
167+
self.assertEqual(record_write.fs, self.record_100_dat.fs)
168+
self.assertEqual(record_write.units, ["mV", "mV"])
169+
self.assertEqual(record_write.units, self.record_100_dat.units)
170+
self.assertEqual(record_write.sig_name, ["MLII", "V5"])
171+
self.assertEqual(record_write.sig_name, self.record_100_dat.sig_name)
172+
self.assertEqual(record_write.p_signal.size, 1300000)
173+
self.assertEqual(
174+
record_write.p_signal.size, self.record_100_dat.p_signal.size
175+
)
176+
177+
178+
if __name__ == "__main__":
179+
unittest.main()

wfdb/io/convert/csv.py

Lines changed: 12 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,7 @@ def csv_to_wfdb(
3333
header=True,
3434
delimiter=",",
3535
verbose=False,
36+
write_dir="",
3637
):
3738
"""
3839
Read a WFDB header file and return either a `Record` object with the
@@ -235,6 +236,10 @@ def csv_to_wfdb(
235236
verbose : bool, optional
236237
Whether to print all the information read about the file (True) or
237238
not (False).
239+
write_dir : str, optional
240+
The directory where the output files will be saved. If write_dir is not
241+
provided, the output files will be saved in the same directory as the
242+
input file.
238243
239244
Returns
240245
-------
@@ -291,6 +296,7 @@ def csv_to_wfdb(
291296
df_CSV = pd.read_csv(file_name, delimiter=delimiter, header=None)
292297
if verbose:
293298
print("Successfully read CSV")
299+
294300
# Extract the entire signal from the dataframe
295301
p_signal = df_CSV.values
296302
# The dataframe should be in (`sig_len`, `n_sig`) dimensions
@@ -300,10 +306,11 @@ def csv_to_wfdb(
300306
n_sig = p_signal.shape[1]
301307
if verbose:
302308
print("Number of signals: {}".format(n_sig))
309+
303310
# Check if signal names are valid and set defaults
304311
if not sig_name:
305312
if header:
306-
sig_name = df_CSV.columns.to_list()
313+
sig_name = df_CSV.columns.tolist()
307314
if any(map(str.isdigit, sig_name)):
308315
print(
309316
"WARNING: One or more of your signal names are numbers, this "
@@ -318,15 +325,12 @@ def csv_to_wfdb(
318325
if verbose:
319326
print("Signal names: {}".format(sig_name))
320327

321-
# Set the output header file name to be the same, remove path
322-
if os.sep in file_name:
323-
file_name = file_name.split(os.sep)[-1]
324-
record_name = file_name.replace(".csv", "")
328+
record_name = os.path.splitext(os.path.basename(file_name))[0]
325329
if verbose:
326-
print("Output header: {}.hea".format(record_name))
330+
print("Record name: {}.hea".format(record_name))
327331

328332
# Replace the CSV file tag with DAT
329-
dat_file_name = file_name.replace(".csv", ".dat")
333+
dat_file_name = record_name + ".dat"
330334
dat_file_name = [dat_file_name] * n_sig
331335
if verbose:
332336
print("Output record: {}".format(dat_file_name[0]))
@@ -450,7 +454,6 @@ def csv_to_wfdb(
450454
if verbose:
451455
print("Record generated successfully")
452456
return record
453-
454457
else:
455458
# Write the information to a record and header file
456459
wrsamp(
@@ -465,6 +468,7 @@ def csv_to_wfdb(
465468
comments=comments,
466469
base_time=base_time,
467470
base_date=base_date,
471+
write_dir=write_dir,
468472
)
469473
if verbose:
470474
print("File generated successfully")

0 commit comments

Comments
 (0)