
Commit d132756

bgen_to_zarr implementation sgkit-dev#16

1 parent 0153804
5 files changed: +269 -26 lines

setup.cfg (+3 -1)

@@ -54,7 +54,7 @@ ignore =
 [isort]
 default_section = THIRDPARTY
 known_first_party = sgkit
-known_third_party = bgen_reader,dask,numpy,pytest,setuptools,xarray
+known_third_party = bgen_reader,dask,numpy,pytest,setuptools,xarray,zarr
 multi_line_output = 3
 include_trailing_comma = True
 force_grid_wrap = 0
@@ -71,5 +71,7 @@ ignore_missing_imports = True
 ignore_missing_imports = True
 [mypy-sgkit.*]
 ignore_missing_imports = True
+[mypy-zarr.*]
+ignore_missing_imports = True
 [mypy-sgkit_bgen.tests.*]
 disallow_untyped_defs = False

sgkit_bgen/__init__.py (+2 -2)

@@ -1,3 +1,3 @@
-from .bgen_reader import read_bgen  # noqa: F401
+from .bgen_reader import read_bgen, rechunk_from_zarr, rechunk_to_zarr  # noqa: F401
 
-__all__ = ["read_bgen"]
+__all__ = ["read_bgen", "rechunk_from_zarr", "rechunk_to_zarr"]

sgkit_bgen/bgen_reader.py (+133 -17)

@@ -1,16 +1,19 @@
 """BGEN reader implementation (using bgen_reader)"""
 from pathlib import Path
-from typing import Any, Dict, Tuple, Union
+from typing import Any, Dict, Hashable, MutableMapping, Optional, Tuple, Union
 
 import dask.array as da
 import dask.dataframe as dd
 import numpy as np
+import xarray as xr
+import zarr
 from bgen_reader._bgen_file import bgen_file
 from bgen_reader._bgen_metafile import bgen_metafile
 from bgen_reader._metafile import create_metafile
 from bgen_reader._reader import infer_metafile_filepath
 from bgen_reader._samples import generate_samples, read_samples_file
 from xarray import Dataset
+from xarray.backends.zarr import ZarrStore
 
 from sgkit import create_genotype_dosage_dataset
 from sgkit.typing import ArrayLike
@@ -38,6 +41,13 @@ def _to_dict(df: dd.DataFrame, dtype: Any = None) -> Dict[str, da.Array]:
 VARIANT_DF_DTYPE = dict([(f[0], f[1]) for f in VARIANT_FIELDS])
 VARIANT_ARRAY_DTYPE = dict([(f[0], f[2]) for f in VARIANT_FIELDS])
 
+GT_DATA_VARS = [
+    "call_genotype_probability",
+    "call_genotype_probability_mask",
+    "call_dosage",
+    "call_dosage_mask",
+]
+
 
 class BgenReader:
 
@@ -79,15 +89,7 @@ def split(allele_row: np.ndarray) -> np.ndarray:
 
             return np.apply_along_axis(split, 1, alleles[:, np.newaxis])
 
-        variant_alleles = variant_arrs["allele_ids"].map_blocks(split_alleles)
-
-        def max_str_len(arr: ArrayLike) -> Any:
-            return arr.map_blocks(
-                lambda s: np.char.str_len(s.astype(str)), dtype=np.int8
-            ).max()
-
-        max_allele_length = max(max_str_len(variant_alleles).compute())
-        self.variant_alleles = variant_alleles.astype(f"S{max_allele_length}")
+        self.variant_alleles = variant_arrs["allele_ids"].map_blocks(split_alleles)
 
         with bgen_file(self.path) as bgen:
             sample_path = self.path.with_suffix(".sample")
@@ -172,6 +174,7 @@ def read_bgen(
     chunks: Union[str, int, Tuple[int, ...]] = "auto",
     lock: bool = False,
     persist: bool = True,
+    dtype: Any = "float32",
 ) -> Dataset:
     """Read BGEN dataset.
 
@@ -194,23 +197,23 @@ def read_bgen(
         memory, by default True. This is an important performance
         consideration as the metadata file for this data will
        be read multiple times when False.
+    dtype : Any
+        Genotype probability array data type, by default float32.
 
     Warnings
     --------
     Only bi-allelic, diploid BGEN files are currently supported.
     """
 
-    bgen_reader = BgenReader(path, persist)
+    bgen_reader = BgenReader(path, persist, dtype=dtype)
 
     variant_contig, variant_contig_names = encode_array(bgen_reader.contig.compute())
     variant_contig_names = list(variant_contig_names)
     variant_contig = variant_contig.astype("int16")
-
-    variant_position = np.array(bgen_reader.pos, dtype=int)
-    variant_alleles = np.array(bgen_reader.variant_alleles, dtype="S1")
-    variant_id = np.array(bgen_reader.variant_id, dtype=str)
-
-    sample_id = np.array(bgen_reader.sample_id, dtype=str)
+    variant_position = np.asarray(bgen_reader.pos, dtype=int)
+    variant_alleles = np.asarray(bgen_reader.variant_alleles, dtype="S")
+    variant_id = np.asarray(bgen_reader.variant_id, dtype=str)
+    sample_id = np.asarray(bgen_reader.sample_id, dtype=str)
 
     call_genotype_probability = da.from_array(
         bgen_reader,
@@ -234,3 +237,116 @@ def read_bgen(
     )
 
     return ds
+
+
+def encode_variables(
+    ds: Dataset,
+    compressor: Optional[Any] = zarr.Blosc(cname="zstd", clevel=7, shuffle=2),
+    probability_dtype: Optional[Any] = "uint8",
+) -> Dict[Hashable, Dict[str, Any]]:
+    encoding = {}
+    for v in ds:
+        e = {}
+        if compressor is not None:
+            e.update({"compressor": compressor})
+        if probability_dtype is not None and v == "call_genotype_probability":
+            dtype = np.dtype(probability_dtype)
+            # Xarray will decode into float32 so any int greater than
+            # 16 bits will cause overflow/underflow
+            # See https://en.wikipedia.org/wiki/Floating-point_arithmetic#Internal_representation
+            # *bits precision column for single precision floats
+            if dtype not in [np.uint8, np.uint16]:
+                raise ValueError(
+                    "Probability integer dtype invalid, must "
+                    f"be uint8 or uint16 not {probability_dtype}"
+                )
+            divisor = np.iinfo(dtype).max - 1
+            e.update(
+                {
+                    "dtype": probability_dtype,
+                    "add_offset": -1.0 / divisor,
+                    "scale_factor": 1.0 / divisor,
+                    "_FillValue": 0,
+                }
+            )
+        if e:
+            encoding[v] = e
+    return encoding
+
+
+def pack_variables(ds: Dataset) -> Dataset:
+    # Remove dosage as it is unnecessary and should be redefined
+    # based on encoded probabilities later (w/ reduced precision)
+    ds = ds.drop_vars(["call_dosage", "call_dosage_mask"], errors="ignore")
+
+    # Remove homozygous reference GP and redefine mask
+    gp = ds["call_genotype_probability"][..., 1:]
+    gp_mask = ds["call_genotype_probability_mask"].any(dim="genotypes")
+    ds = ds.drop_vars(["call_genotype_probability", "call_genotype_probability_mask"])
+    ds = ds.assign(call_genotype_probability=gp, call_genotype_probability_mask=gp_mask)
+    return ds
+
+
+def unpack_variables(ds: Dataset, dtype: Any = "float32") -> Dataset:
+    # Restore homozygous reference GP
+    gp = ds["call_genotype_probability"].astype(dtype)
+    if gp.sizes["genotypes"] != 2:
+        raise ValueError(
+            "Expecting variable 'call_genotype_probability' to have genotypes "
+            f"dimension of size 2 (received sizes = {dict(gp.sizes)})"
+        )
+    ds = ds.drop_vars("call_genotype_probability")
+    ds["call_genotype_probability"] = xr.concat(  # type: ignore[no-untyped-call]
+        [1 - gp.sum(dim="genotypes", skipna=False), gp], dim="genotypes"
+    )
+
+    # Restore dosage
+    ds["call_dosage"] = gp[..., 0] + 2 * gp[..., 1]
+    ds["call_dosage_mask"] = ds["call_genotype_probability_mask"]
+    ds["call_genotype_probability_mask"] = ds[
+        "call_genotype_probability_mask"
+    ].broadcast_like(ds["call_genotype_probability"])
+    return ds
+
+
+def rechunk_to_zarr(
+    ds: Dataset,
+    store: Union[PathType, MutableMapping[str, bytes]],
+    *,
+    mode: str = "w",
+    chunk_length: int = 10_000,
+    chunk_width: int = 10_000,
+    compressor: Optional[Any] = zarr.Blosc(cname="zstd", clevel=7, shuffle=2),
+    probability_dtype: Optional[Any] = "uint8",
+    pack: bool = True,
+    compute: bool = True,
+) -> ZarrStore:
+    if pack:
+        ds = pack_variables(ds)
+    for v in set(GT_DATA_VARS) & set(ds):
+        chunk_size = da.asarray(ds[v]).chunksize[0]
+        if chunk_length % chunk_size != 0:
+            raise ValueError(
+                f"Chunk size in variant dimension for variable '{v}' ({chunk_size}) "
+                f"must evenly divide target chunk size {chunk_length}"
+            )
+        ds[v] = ds[v].chunk(chunks=dict(samples=chunk_width))  # type: ignore[dict-item]
+    encoding = encode_variables(
+        ds, compressor=compressor, probability_dtype=probability_dtype
+    )
+    return ds.to_zarr(store, mode=mode, encoding=encoding or None, compute=compute)  # type: ignore[arg-type]
+
+
+def rechunk_from_zarr(
+    store: Union[PathType, MutableMapping[str, bytes]],
+    chunk_length: int = 10_000,
+    chunk_width: int = 10_000,
+    mask_and_scale: bool = True,
+) -> Dataset:
+    # Always use concat_characters=False to avoid https://github.com/pydata/xarray/issues/4405
+    ds = xr.open_zarr(store, mask_and_scale=mask_and_scale, concat_characters=False)  # type: ignore[no-untyped-call]
+    for v in set(GT_DATA_VARS) & set(ds):
+        ds[v] = ds[v].chunk(chunks=dict(variants=chunk_length, samples=chunk_width))
+        # Workaround for https://github.com/pydata/xarray/issues/4380
+        del ds[v].encoding["chunks"]
+    return ds  # type: ignore[no-any-return]
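Worth noting in encode_variables above: with dtype=uint8 the divisor is np.iinfo(uint8).max - 1 = 254 rather than 255, because the encoded value 0 is reserved as _FillValue. Bytes 1..255 then map linearly onto probabilities 0.0..1.0. A minimal sketch of that round trip (illustration only, not part of the commit; it assumes xarray's usual CF convention, decoded = encoded * scale_factor + add_offset):

import numpy as np

divisor = np.iinfo(np.uint8).max - 1           # 254
scale, offset = 1.0 / divisor, -1.0 / divisor  # matches encode_variables

p = np.array([0.0, 0.25, 0.5, 1.0, np.nan])
# Encode: NaN -> fill value 0; otherwise shift/scale into 1..255
enc = np.where(np.isnan(p), 0, np.round((p - offset) / scale)).astype("uint8")
# enc == [1, 64, 128, 255, 0]

# Decode: fill value 0 -> NaN; otherwise invert the shift/scale
dec = np.where(enc == 0, np.nan, enc * scale + offset)
# dec ~= [0.0, 0.248, 0.5, 1.0, nan]; worst-case error is half a step (~0.002)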
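Similarly, pack_variables and unpack_variables rely on the fact that for a bi-allelic, diploid call the three genotype probabilities sum to 1, so neither the homozygous-reference probability nor the dosage needs to be stored. A toy illustration of that reconstruction (the probability values are made up for the example):

import numpy as np

p = np.array([0.1, 0.7, 0.2])         # P(hom ref), P(het), P(hom alt)

packed = p[1:]                        # pack: keep only (het, hom alt)
p0 = 1 - packed.sum()                 # unpack: hom ref = 1 - (p1 + p2)
dosage = packed[0] + 2 * packed[1]    # expected ALT allele count = p1 + 2 * p2

assert np.isclose(p0, p[0])
assert np.isclose(dosage, 1.1)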

sgkit_bgen/tests/data/.gitignore (+2 -0)

@@ -0,0 +1,2 @@
+*.metadata2.mmm
+*.metafile
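Taken together, the new functions support a read -> rechunk -> reopen workflow. A hypothetical end-to-end sketch (the file names and chunk sizes are placeholders, and unpack_variables is imported from the submodule since only the rechunk functions are re-exported):

from sgkit_bgen import read_bgen, rechunk_from_zarr, rechunk_to_zarr
from sgkit_bgen.bgen_reader import unpack_variables

ds = read_bgen("example.bgen", chunks=(1000, -1, -1))  # variant-major chunks
rechunk_to_zarr(ds, "example.zarr", chunk_length=10_000, chunk_width=10_000)
packed = rechunk_from_zarr("example.zarr")  # sample-rechunked, still packed
ds2 = unpack_variables(packed)              # restore hom-ref GP and dosage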
