
Commit 3b7c30b

Merge pull request sgkit-dev#15 from eric-czech/add_gp
Return genotype probabilities and dosages
2 parents f80601b + 93f9708

File tree

2 files changed: +124 −17 lines


sgkit_bgen/bgen_reader.py

Lines changed: 36 additions & 16 deletions
@@ -94,28 +94,38 @@ def max_str_len(arr: ArrayLike) -> Any:
             else:
                 self.sample_id = generate_samples(bgen.nsamples)
 
-        self.shape = (self.n_variants, len(self.sample_id))
+        self.shape = (self.n_variants, len(self.sample_id), 3)
         self.dtype = dtype
-        self.ndim = 2
+        self.ndim = 3
 
     def __getitem__(self, idx):
         if not isinstance(idx, tuple):
-            raise IndexError(  # pragma: no cover
-                f"Indexer must be tuple (received {type(idx)})"
-            )
+            raise IndexError(f"Indexer must be tuple (received {type(idx)})")
         if len(idx) != self.ndim:
-            raise IndexError(  # pragma: no cover
-                f"Indexer must be two-item tuple (received {len(idx)} slices)"
+            raise IndexError(
+                f"Indexer must have {self.ndim} items (received {len(idx)} slices)"
+            )
+        if not all(isinstance(i, slice) or isinstance(i, int) for i in idx):
+            raise IndexError(
+                f"Indexer must contain only slices or ints (received types {[type(i) for i in idx]})"
             )
+        # Determine which dims should have unit size in result
+        squeeze_dims = tuple(i for i in range(len(idx)) if isinstance(idx[i], int))
+        # Convert all indexers to slices
+        idx = tuple(slice(i, i + 1) if isinstance(i, int) else i for i in idx)
 
         if idx[0].start == idx[0].stop:
-            return np.empty((0, 0), dtype=self.dtype)
+            return np.empty((0,) * self.ndim, dtype=self.dtype)
 
+        # Determine start and end partitions that correspond to the
+        # given variant dimension indexer
         start_partition = idx[0].start // self.partition_size
         start_partition_offset = idx[0].start % self.partition_size
         end_partition = (idx[0].stop - 1) // self.partition_size
         end_partition_offset = (idx[0].stop - 1) % self.partition_size
 
+        # Create a list of all offsets into the underlying file at which
+        # data for each variant begins
         all_vaddr = []
         with bgen_metafile(self.metafile_filepath) as mf:
             for i in range(start_partition, end_partition + 1):
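
The indexer normalization added above is easy to replay in isolation. Below is a minimal sketch (not from this diff) that traces one request through the same logic; the `partition_size` of 100 is an arbitrary illustrative value, whereas the real reader derives it from the BGEN metafile:

# Sketch: how reader[5, 0:500, :] would be normalized
idx = (5, slice(0, 500), slice(None))

# Integer indexers mark dimensions to drop from the result ...
squeeze_dims = tuple(i for i in range(len(idx)) if isinstance(idx[i], int))
assert squeeze_dims == (0,)

# ... then every indexer is widened to a unit slice so the body only sees slices
idx = tuple(slice(i, i + 1) if isinstance(i, int) else i for i in idx)
assert idx[0] == slice(5, 6)

# Partition arithmetic for the variant dimension: slice [5, 6) with an
# assumed partition_size of 100 touches only partition 0, at offset 5
partition_size = 100
start_partition = idx[0].start // partition_size            # 0
start_partition_offset = idx[0].start % partition_size      # 5
end_partition = (idx[0].stop - 1) // partition_size         # 0
end_partition_offset = (idx[0].stop - 1) % partition_size   # 5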
@@ -129,21 +139,27 @@ def __getitem__(self, idx):
                 vaddr = partition["vaddr"].tolist()
                 all_vaddr.extend(vaddr[start_offset:end_offset])
 
+        # Read the probabilities for each variant, apply indexer for
+        # samples dimension to give probabilities for all genotypes,
+        # and then apply final genotype dimension indexer
         with bgen_file(self.path) as bgen:
             res = None
             for i, vaddr in enumerate(all_vaddr):
                 probs = bgen.read_genotype(vaddr)["probs"][idx[1]]
-                dosage = _to_dosage(probs)
+                assert len(probs.shape) == 2 and probs.shape[1] == 3
                 if res is None:
-                    res = np.zeros((len(all_vaddr), len(dosage)), dtype=self.dtype)
-                res[i] = dosage
-            return res
+                    res = np.zeros((len(all_vaddr), len(probs), 3), dtype=self.dtype)
+                res[i] = probs
+            res = res[..., idx[2]]
+            return np.squeeze(res, axis=squeeze_dims)
 
 
 def _to_dosage(probs: ArrayLike):
     """Calculate the dosage from genotype likelihoods (probabilities)"""
-    assert len(probs.shape) == 2 and probs.shape[1] == 3
-    return 2 * probs[:, -1] + probs[:, 1]
+    assert (
+        probs.shape[-1] == 3
+    ), f"Expecting genotype (trailing) dimension of size 3, got array of shape {probs.shape}"
+    return probs[..., 1] + 2 * probs[..., 2]
 
 
 def read_bgen(
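
As a sanity check on the reworked `_to_dosage` (a worked example, not part of the diff): with trailing probabilities ordered [P(hom ref), P(het), P(hom alt)], the dosage is the expected count of the alternate allele, so the new `probs[..., 1] + 2 * probs[..., 2]` agrees with the old 2-D expression while also handling extra leading dimensions:

import numpy as np

# One call with P(hom ref)=0.1, P(het)=0.2, P(hom alt)=0.7
probs = np.array([[0.1, 0.2, 0.7]])
# Expected ALT allele count: 0*0.1 + 1*0.2 + 2*0.7 = 1.6
dosage = probs[..., 1] + 2 * probs[..., 2]
np.testing.assert_allclose(dosage, [1.6])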
@@ -162,7 +178,8 @@ def read_bgen(
     path : PathType
         Path to BGEN file.
     chunks : Union[str, int, tuple], optional
-        Chunk size for genotype data, by default "auto"
+        Chunk size for genotype probability data (3 dimensions),
+        by default "auto".
     lock : bool, optional
         Whether or not to synchronize concurrent reads of
         file blocks, by default False. This is passed through to
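
Since the probabilities are now 3-dimensional (variants, samples, genotypes), tuple chunk sizes need three entries, mirroring the CHUNKS cases in the tests below. A hypothetical call with a placeholder path:

from sgkit_bgen import read_bgen

# Trailing chunk entry covers the genotype dimension of size 3
ds = read_bgen("example.bgen", chunks=(100, 200, 3))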
@@ -190,13 +207,15 @@ def read_bgen(
 
     sample_id = np.array(bgen_reader.sample_id, dtype=str)
 
-    call_dosage = da.from_array(
+    call_genotype_probability = da.from_array(
         bgen_reader,
         chunks=chunks,
         lock=lock,
+        fancy=False,
         asarray=False,
         name=f"{bgen_reader.name}:read_bgen:{path}",
     )
+    call_dosage = _to_dosage(call_genotype_probability)
 
     ds = create_genotype_dosage_dataset(
         variant_contig_names=variant_contig_names,
@@ -205,6 +224,7 @@
         variant_alleles=variant_alleles,
         sample_id=sample_id,
         call_dosage=call_dosage,
+        call_genotype_probability=call_genotype_probability,
         variant_id=variant_id,
     )

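One note on the new `fancy=False` flag: it tells dask that the wrapped object only supports slice-based indexing, so chunk reads are routed through plain `__getitem__` calls like the one implemented above. A minimal sketch under that assumption, with `DummyReader` as a hypothetical stand-in for `BgenReader`:

import dask.array as da
import numpy as np

class DummyReader:
    """Hypothetical stand-in exposing the same array protocol as BgenReader."""

    shape = (199, 500, 3)
    ndim = 3
    dtype = np.dtype("float32")

    def __getitem__(self, idx):
        # With fancy=False, dask passes tuples of slices here
        sizes = tuple(len(range(*s.indices(n))) for s, n in zip(idx, self.shape))
        return np.zeros(sizes, dtype=self.dtype)

arr = da.from_array(DummyReader(), chunks=(100, 200, 3), fancy=False, asarray=False)
assert arr.shape == (199, 500, 3)
arr[:5, :2].compute()  # each chunk read goes through DummyReader.__getitem__
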
sgkit_bgen/tests/test_bgen_reader.py

Lines changed: 88 additions & 1 deletion
@@ -1,9 +1,41 @@
+import numpy as np
 import numpy.testing as npt
 import pytest
 from sgkit_bgen import read_bgen
+from sgkit_bgen.bgen_reader import BgenReader
 
+CHUNKS = [
+    (100, 200, 3),
+    (100, 200, 1),
+    (100, 500, 3),
+    (199, 500, 3),
+    ((100, 99), 500, 2),
+    "auto",
+]
+INDEXES = [0, 10, 20, 100, -1]
 
-@pytest.mark.parametrize("chunks", [(100, 200), (100, 500), (199, 500), "auto"])
+# Expectations below generated using bgen-reader directly, ex:
+# > from bgen_reader import open_bgen
+# > bgen = open_bgen('sgkit_bgen/tests/data/example.bgen', verbose=False)
+# > bgen.read(-1)[0]  # Probabilities for last variant, first sample
+# array([[0.0133972 , 0.98135378, 0.00524902]]
+# > bgen.allele_expectation(-1)[0, 0, -1]  # Dosage for last variant, first sample
+# 0.9918518217727197
+EXPECTED_PROBABILITIES = np.array(
+    [  # Generated using bgen-reader directly
+        [np.nan, np.nan, np.nan],
+        [0.007, 0.966, 0.0259],
+        [0.993, 0.002, 0.003],
+        [0.916, 0.007, 0.0765],
+        [0.013, 0.981, 0.0052],
+    ]
+)
+EXPECTED_DOSAGES = np.array(
+    [np.nan, 1.018, 0.010, 0.160, 0.991]  # Generated using bgen-reader directly
+)
+
+
+@pytest.mark.parametrize("chunks", CHUNKS)
 def test_read_bgen(shared_datadir, chunks):
     path = shared_datadir / "example.bgen"
     ds = read_bgen(path, chunks=chunks)
@@ -12,6 +44,21 @@ def test_read_bgen(shared_datadir, chunks):
     assert ds["call_dosage"].shape == (199, 500)
     npt.assert_almost_equal(ds["call_dosage"].values[1][0], 1.987, decimal=3)
     npt.assert_almost_equal(ds["call_dosage"].values[100][0], 0.160, decimal=3)
+    npt.assert_array_equal(ds["call_dosage_mask"].values[0, 0], [True])
+    npt.assert_array_equal(ds["call_dosage_mask"].values[0, 1], [False])
+    assert ds["call_genotype_probability"].shape == (199, 500, 3)
+    npt.assert_almost_equal(
+        ds["call_genotype_probability"].values[1][0], [0.005, 0.002, 0.992], decimal=3
+    )
+    npt.assert_almost_equal(
+        ds["call_genotype_probability"].values[100][0], [0.916, 0.007, 0.076], decimal=3
+    )
+    npt.assert_array_equal(
+        ds["call_genotype_probability_mask"].values[0, 0], [True] * 3
+    )
+    npt.assert_array_equal(
+        ds["call_genotype_probability_mask"].values[0, 1], [False] * 3
+    )
 
 
 def test_read_bgen_with_sample_file(shared_datadir):
@@ -38,3 +85,43 @@ def test_read_bgen_with_no_samples(shared_datadir):
         "sample_3",
         "sample_4",
     ]
+
+
+@pytest.mark.parametrize("chunks", CHUNKS)
+def test_read_bgen_fancy_index(shared_datadir, chunks):
+    path = shared_datadir / "example.bgen"
+    ds = read_bgen(path, chunks=chunks)
+    npt.assert_almost_equal(
+        ds["call_genotype_probability"][INDEXES, 0], EXPECTED_PROBABILITIES, decimal=3
+    )
+    npt.assert_almost_equal(ds["call_dosage"][INDEXES, 0], EXPECTED_DOSAGES, decimal=3)
+
+
+@pytest.mark.parametrize("chunks", CHUNKS)
+def test_read_bgen_scalar_index(shared_datadir, chunks):
+    path = shared_datadir / "example.bgen"
+    ds = read_bgen(path, chunks=chunks)
+    for i, ix in enumerate(INDEXES):
+        npt.assert_almost_equal(
+            ds["call_genotype_probability"][ix, 0], EXPECTED_PROBABILITIES[i], decimal=3
+        )
+        npt.assert_almost_equal(
+            ds["call_dosage"][ix, 0], EXPECTED_DOSAGES[i], decimal=3
+        )
+        for j in range(3):
+            npt.assert_almost_equal(
+                ds["call_genotype_probability"][ix, 0, j],
+                EXPECTED_PROBABILITIES[i, j],
+                decimal=3,
+            )
+
+
+def test_read_bgen_raise_on_invalid_indexers(shared_datadir):
+    path = shared_datadir / "example.bgen"
+    reader = BgenReader(path)
+    with pytest.raises(IndexError, match="Indexer must be tuple"):
+        reader[[0]]
+    with pytest.raises(IndexError, match="Indexer must have 3 items"):
+        reader[(slice(None),)]
+    with pytest.raises(IndexError, match="Indexer must contain only slices or ints"):
+        reader[([0], [0], [0])]
