Skip to content

Commit 32d5e6f

Browse files
Fix covariance file writer
1 parent 06c5bb1 commit 32d5e6f

File tree

2 files changed

+105
-17
lines changed

2 files changed

+105
-17
lines changed

src/gwas/src/gwas/pheno.py

Lines changed: 12 additions & 7 deletions
Original file line number | Diff line number | Diff line change
@@ -10,6 +10,7 @@
1010
from .compression.arr.base import (
1111
CompressionMethod,
1212
FileArray,
13+
FileArrayWriter,
1314
)
1415
from .log import logger
1516
from .mem.arr import SharedArray
@@ -362,20 +363,24 @@ def covariance_to_txt(
362363
data_frame = pd.DataFrame(array, index=self.samples, columns=names)
363364

364365
logger.debug("Calculating covariance matrix")
365-
covariance = data_frame.cov().to_numpy(dtype=np.float64)
366+
covariance: npt.NDArray[np.float64] = np.asfortranarray(
367+
data_frame.cov().to_numpy(dtype=np.float64)
368+
)
366369

367-
file_array = FileArray.create(
370+
writer: FileArrayWriter[np.float64] = FileArray.create(
368371
path,
369372
covariance.shape,
370-
covariance.dtype,
373+
covariance.dtype.type,
371374
compression_method,
372375
num_threads=num_threads,
373376
)
377+
374378
data_frame = pd.DataFrame(dict(variable=names))
375-
with file_array:
376-
file_array.set_axis_metadata(0, data_frame)
377-
file_array.set_axis_metadata(1, names)
378-
file_array[:, :] = covariance
379+
writer.set_axis_metadata(0, data_frame)
380+
writer.set_axis_metadata(1, names)
381+
382+
with writer:
383+
writer[:, :] = covariance
379384

380385

381386
@dataclass

src/gwas/tests/score/test_pheno.py

Lines changed: 93 additions & 10 deletions
Original file line number | Diff line number | Diff line change
@@ -1,33 +1,44 @@
1+
import sys
2+
13
import numpy as np
24
import pandas as pd
5+
import pytest
6+
from numpy import typing as npt
37
from numpy.testing import assert_array_equal
48
from pytest import FixtureRequest
59
from upath import UPath
610

11+
from gwas.compression.arr.base import Blosc2CompressionMethod, compression_methods
712
from gwas.mem.wkspace import SharedWorkspace
813
from gwas.pheno import VariableCollection
14+
from gwas.utils import cpu_count
915

1016
from .simulation import missing_value_rate
1117

18+
try:
19+
import blosc2 as blosc2
20+
except ImportError:
21+
pass
22+
1223
sample_count = 100
1324
phenotype_count = 16
1425
covariate_count = 4
1526

16-
samples = [str(i) for i in range(sample_count)]
17-
permutation = np.random.permutation(sample_count)
27+
samples = [f"{i + 1:03d}" for i in range(sample_count)]
1828

1929
phenotype_names = [f"phenotype_{i + 1:02d}" for i in range(phenotype_count)]
2030
covariate_names = [f"covariate_{i + 1:02d}" for i in range(covariate_count)]
2131

2232

23-
def test_pheno(
24-
tmp_path: UPath,
25-
sw: SharedWorkspace,
26-
request: FixtureRequest,
27-
) -> None:
28-
np.random.seed(47)
29-
allocation_names = set(sw.allocations.keys())
33+
@pytest.fixture(scope="session")
34+
def permutation() -> npt.NDArray[np.int_]:
35+
np.random.seed(46)
36+
return np.random.permutation(sample_count)
3037

38+
39+
@pytest.fixture(scope="session")
40+
def phenotypes() -> npt.NDArray[np.float64]:
41+
np.random.seed(47)
3142
phenotypes = np.random.rand(sample_count, phenotype_count)
3243
phenotypes[
3344
np.random.choice(
@@ -36,8 +47,22 @@ def test_pheno(
3647
p=[1 - missing_value_rate, missing_value_rate],
3748
)
3849
] = np.nan
39-
covariates = np.random.rand(sample_count, covariate_count)
50+
return phenotypes
4051

52+
53+
@pytest.fixture(scope="session")
54+
def covariates() -> npt.NDArray[np.float64]:
55+
np.random.seed(48)
56+
return np.random.rand(sample_count, covariate_count)
57+
58+
59+
@pytest.fixture(scope="session")
60+
def phenotype_path(
61+
phenotypes: npt.NDArray[np.float64],
62+
permutation: npt.NDArray[np.int_],
63+
tmp_path_factory: pytest.TempPathFactory,
64+
) -> UPath:
65+
tmp_path = UPath(tmp_path_factory.mktemp("phenotypes"))
4166
phenotype_frame = pd.DataFrame(
4267
phenotypes[permutation, :],
4368
columns=phenotype_names,
@@ -47,7 +72,16 @@ def test_pheno(
4772
phenotype_frame.to_csv(
4873
phenotype_path, sep="\t", index=True, header=True, na_rep="n/a"
4974
)
75+
return phenotype_path
5076

77+
78+
@pytest.fixture(scope="session")
79+
def covariate_path(
80+
covariates: npt.NDArray[np.float64],
81+
permutation: npt.NDArray[np.int_],
82+
tmp_path_factory: pytest.TempPathFactory,
83+
) -> UPath:
84+
tmp_path = UPath(tmp_path_factory.mktemp("covariates"))
5185
covariate_frame = pd.DataFrame(
5286
covariates[permutation, :],
5387
columns=covariate_names,
@@ -57,6 +91,18 @@ def test_pheno(
5791
covariate_frame.to_csv(
5892
covariate_path, sep="\t", index=True, header=True, na_rep="n/a"
5993
)
94+
return covariate_path
95+
96+
97+
def test_pheno(
98+
phenotypes: npt.NDArray[np.float64],
99+
covariates: npt.NDArray[np.float64],
100+
phenotype_path: UPath,
101+
covariate_path: UPath,
102+
sw: SharedWorkspace,
103+
request: FixtureRequest,
104+
) -> None:
105+
allocation_names = set(sw.allocations.keys())
60106

61107
variable_collection0 = VariableCollection.from_txt(
62108
[phenotype_path],
@@ -135,3 +181,40 @@ def test_pheno_zero_variance(
135181
variable_collection.covariates.name,
136182
}
137183
assert set(sw.allocations.keys()) <= (allocation_names | new_allocation_names)
184+
185+
186+
@pytest.mark.parametrize("compression_method_name", compression_methods.keys())
187+
def test_covariance(
188+
compression_method_name: str,
189+
phenotype_path: UPath,
190+
covariate_path: UPath,
191+
sw: SharedWorkspace,
192+
tmp_path: UPath,
193+
request: FixtureRequest,
194+
) -> None:
195+
compression_method = compression_methods[compression_method_name]
196+
if isinstance(compression_method, Blosc2CompressionMethod):
197+
if "blosc2" not in sys.modules:
198+
pytest.skip("blosc2 not installed")
199+
200+
allocation_names = set(sw.allocations.keys())
201+
202+
variable_collection = VariableCollection.from_txt(
203+
[phenotype_path],
204+
[covariate_path],
205+
sw,
206+
samples=samples,
207+
missing_value_strategy="listwise_deletion",
208+
)
209+
request.addfinalizer(variable_collection.free)
210+
211+
covariance_path = tmp_path / "covariance.tsv"
212+
variable_collection.covariance_to_txt(
213+
covariance_path, compression_method, num_threads=cpu_count()
214+
)
215+
216+
new_allocation_names = {
217+
variable_collection.phenotypes.name,
218+
variable_collection.covariates.name,
219+
}
220+
assert set(sw.allocations.keys()) <= (allocation_names | new_allocation_names)

0 commit comments

Comments
 (0)