"""BGEN reader implementation (using bgen_reader)"""
from pathlib import Path
- from typing import Any, Dict, Tuple, Union
+ from typing import Any, Dict, Hashable, MutableMapping, Optional, Tuple, Union

import dask.array as da
import dask.dataframe as dd
import numpy as np
+ import xarray as xr
+ import zarr
from bgen_reader._bgen_file import bgen_file
from bgen_reader._bgen_metafile import bgen_metafile
from bgen_reader._metafile import create_metafile
from bgen_reader._reader import infer_metafile_filepath
from bgen_reader._samples import generate_samples, read_samples_file
from xarray import Dataset
+ from xarray.backends.zarr import ZarrStore

from sgkit import create_genotype_dosage_dataset
from sgkit.typing import ArrayLike
@@ -38,6 +41,13 @@ def _to_dict(df: dd.DataFrame, dtype: Any = None) -> Dict[str, da.Array]:
VARIANT_DF_DTYPE = dict([(f[0], f[1]) for f in VARIANT_FIELDS])
VARIANT_ARRAY_DTYPE = dict([(f[0], f[2]) for f in VARIANT_FIELDS])

+ GT_DATA_VARS = [
+     "call_genotype_probability",
+     "call_genotype_probability_mask",
+     "call_dosage",
+     "call_dosage_mask",
+ ]
+

class BgenReader:

@@ -79,15 +89,7 @@ def split(allele_row: np.ndarray) -> np.ndarray:

            return np.apply_along_axis(split, 1, alleles[:, np.newaxis])

-         variant_alleles = variant_arrs["allele_ids"].map_blocks(split_alleles)
-
-         def max_str_len(arr: ArrayLike) -> Any:
-             return arr.map_blocks(
-                 lambda s: np.char.str_len(s.astype(str)), dtype=np.int8
-             ).max()
-
-         max_allele_length = max(max_str_len(variant_alleles).compute())
-         self.variant_alleles = variant_alleles.astype(f"S{max_allele_length}")
+         self.variant_alleles = variant_arrs["allele_ids"].map_blocks(split_alleles)

        with bgen_file(self.path) as bgen:
            sample_path = self.path.with_suffix(".sample")
@@ -172,6 +174,7 @@ def read_bgen(
    chunks: Union[str, int, Tuple[int, ...]] = "auto",
    lock: bool = False,
    persist: bool = True,
+     dtype: Any = "float32",
) -> Dataset:
    """Read BGEN dataset.

@@ -194,23 +197,23 @@ def read_bgen(
        memory, by default True. This is an important performance
        consideration as the metadata file for this data will
        be read multiple times when False.
+     dtype : Any
+         Genotype probability array data type, by default float32.

    Warnings
    --------
    Only bi-allelic, diploid BGEN files are currently supported.
    """

-     bgen_reader = BgenReader(path, persist)
+     bgen_reader = BgenReader(path, persist, dtype=dtype)

    variant_contig, variant_contig_names = encode_array(bgen_reader.contig.compute())
    variant_contig_names = list(variant_contig_names)
    variant_contig = variant_contig.astype("int16")
-
-     variant_position = np.array(bgen_reader.pos, dtype=int)
-     variant_alleles = np.array(bgen_reader.variant_alleles, dtype="S1")
-     variant_id = np.array(bgen_reader.variant_id, dtype=str)
-
-     sample_id = np.array(bgen_reader.sample_id, dtype=str)
+     variant_position = np.asarray(bgen_reader.pos, dtype=int)
+     variant_alleles = np.asarray(bgen_reader.variant_alleles, dtype="S")
+     variant_id = np.asarray(bgen_reader.variant_id, dtype=str)
+     sample_id = np.asarray(bgen_reader.sample_id, dtype=str)

    call_genotype_probability = da.from_array(
        bgen_reader,
@@ -234,3 +237,116 @@ def read_bgen(
    )

    return ds
+
+
+ def encode_variables(
+     ds: Dataset,
+     compressor: Optional[Any] = zarr.Blosc(cname="zstd", clevel=7, shuffle=2),
+     probability_dtype: Optional[Any] = "uint8",
+ ) -> Dict[Hashable, Dict[str, Any]]:
+     encoding = {}
+     for v in ds:
+         e = {}
+         if compressor is not None:
+             e.update({"compressor": compressor})
+         if probability_dtype is not None and v == "call_genotype_probability":
+             dtype = np.dtype(probability_dtype)
+             # Xarray will decode into float32 so any int greater than
+             # 16 bits will cause overflow/underflow
+             # See https://en.wikipedia.org/wiki/Floating-point_arithmetic#Internal_representation
+             # *bits precision column for single precision floats
+             if dtype not in [np.uint8, np.uint16]:
+                 raise ValueError(
+                     "Probability integer dtype invalid, must "
+                     f"be uint8 or uint16 not {probability_dtype}"
+                 )
+             divisor = np.iinfo(dtype).max - 1
+             e.update(
+                 {
+                     "dtype": probability_dtype,
+                     "add_offset": -1.0 / divisor,
+                     "scale_factor": 1.0 / divisor,
+                     "_FillValue": 0,
+                 }
+             )
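+             # Worked example for illustration: with uint8, divisor = 254,
+             # so a stored byte b decodes to b * (1/254) - 1/254 = (b - 1)/254,
+             # i.e. b=1 -> 0.0 and b=255 -> 1.0, while the reserved
+             # _FillValue of 0 decodes to missing (NaN).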
+         if e:
+             encoding[v] = e
+     return encoding
+
+
+ def pack_variables(ds: Dataset) -> Dataset:
+     # Remove dosage as it is unnecessary and should be redefined
+     # based on encoded probabilities later (w/ reduced precision)
+     ds = ds.drop_vars(["call_dosage", "call_dosage_mask"], errors="ignore")
+
+     # Remove homozygous reference GP and redefine mask
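+     # For bi-allelic, diploid calls the three genotype probabilities sum
+     # to 1, so the homozygous reference probability is redundant and is
+     # reconstructed from the other two in unpack_variables below.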
+     gp = ds["call_genotype_probability"][..., 1:]
+     gp_mask = ds["call_genotype_probability_mask"].any(dim="genotypes")
+     ds = ds.drop_vars(["call_genotype_probability", "call_genotype_probability_mask"])
+     ds = ds.assign(call_genotype_probability=gp, call_genotype_probability_mask=gp_mask)
+     return ds
+
+
+ def unpack_variables(ds: Dataset, dtype: Any = "float32") -> Dataset:
+     # Restore homozygous reference GP
+     gp = ds["call_genotype_probability"].astype(dtype)
+     if gp.sizes["genotypes"] != 2:
+         raise ValueError(
+             "Expecting variable 'call_genotype_probability' to have genotypes "
+             f"dimension of size 2 (received sizes = {dict(gp.sizes)})"
+         )
+     ds = ds.drop_vars("call_genotype_probability")
+     ds["call_genotype_probability"] = xr.concat(  # type: ignore[no-untyped-call]
+         [1 - gp.sum(dim="genotypes", skipna=False), gp], dim="genotypes"
+     )
+
+     # Restore dosage
+     ds["call_dosage"] = gp[..., 0] + 2 * gp[..., 1]
+     ds["call_dosage_mask"] = ds["call_genotype_probability_mask"]
+     ds["call_genotype_probability_mask"] = ds[
+         "call_genotype_probability_mask"
+     ].broadcast_like(ds["call_genotype_probability"])
+     return ds
+
+
+ def rechunk_to_zarr(
+     ds: Dataset,
+     store: Union[PathType, MutableMapping[str, bytes]],
+     *,
+     mode: str = "w",
+     chunk_length: int = 10_000,
+     chunk_width: int = 10_000,
+     compressor: Optional[Any] = zarr.Blosc(cname="zstd", clevel=7, shuffle=2),
+     probability_dtype: Optional[Any] = "uint8",
+     pack: bool = True,
+     compute: bool = True,
+ ) -> ZarrStore:
+     if pack:
+         ds = pack_variables(ds)
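+     # Only the samples dimension is rechunked below; the variant-dimension
+     # chunking from the BGEN read is kept, and must divide chunk_length
+     # evenly, presumably so the later variant rechunk in rechunk_from_zarr
+     # aligns with the chunks written to the store.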
+     for v in set(GT_DATA_VARS) & set(ds):
+         chunk_size = da.asarray(ds[v]).chunksize[0]
+         if chunk_length % chunk_size != 0:
+             raise ValueError(
+                 f"Chunk size in variant dimension for variable '{v}' ({chunk_size}) "
+                 f"must evenly divide target chunk size {chunk_length}"
+             )
+         ds[v] = ds[v].chunk(chunks=dict(samples=chunk_width))  # type: ignore[dict-item]
+     encoding = encode_variables(
+         ds, compressor=compressor, probability_dtype=probability_dtype
+     )
+     return ds.to_zarr(store, mode=mode, encoding=encoding or None, compute=compute)  # type: ignore[arg-type]
+
+
+ def rechunk_from_zarr(
+     store: Union[PathType, MutableMapping[str, bytes]],
+     chunk_length: int = 10_000,
+     chunk_width: int = 10_000,
+     mask_and_scale: bool = True,
+ ) -> Dataset:
+     # Always use concat_characters=False to avoid https://github.com/pydata/xarray/issues/4405
+     ds = xr.open_zarr(store, mask_and_scale=mask_and_scale, concat_characters=False)  # type: ignore[no-untyped-call]
+     for v in set(GT_DATA_VARS) & set(ds):
+         ds[v] = ds[v].chunk(chunks=dict(variants=chunk_length, samples=chunk_width))
+         # Workaround for https://github.com/pydata/xarray/issues/4380
+         del ds[v].encoding["chunks"]
+     return ds  # type: ignore[no-any-return]
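+
+
+ # Example round trip, as a sketch ("example.bgen" and "example.zarr" are
+ # hypothetical paths, not part of this module):
+ #
+ #     ds = read_bgen("example.bgen")
+ #     rechunk_to_zarr(ds, "example.zarr")
+ #     ds = rechunk_from_zarr("example.zarr")
+ #     ds = unpack_variables(ds)  # restore full GP and dosage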