Skip to content

Commit e683f5b

Browse files
authored
Merge pull request #958 from int-brain-lab/ssl_revision_2
adding revision as a kwarg to the spike sorting loader
2 parents 88d39b1 + 1e456da commit e683f5b

File tree

1 file changed

+11
-5
lines changed

1 file changed

+11
-5
lines changed

brainbox/io/one.py

+11-5
Original file line numberDiff line numberDiff line change
@@ -808,6 +808,7 @@ class SpikeSortingLoader:
808808
spike_sorter: str = 'pykilosort'
809809
spike_sorting_path: Path = None
810810
_sync: dict = None
811+
revision: str = None
811812

812813
def __post_init__(self):
813814
# pid gets precedence
@@ -886,7 +887,7 @@ def _get_spike_sorting_collection(self, spike_sorter=None):
886887
_logger.debug(f"selecting: {collection} to load amongst candidates: {self.collections}")
887888
return collection
888889

889-
def load_spike_sorting_object(self, obj, *args, **kwargs):
890+
def load_spike_sorting_object(self, obj, *args, revision=None, **kwargs):
890891
"""
891892
Loads an ALF object
892893
:param obj: object name, str between 'spikes', 'clusters' or 'channels'
@@ -895,8 +896,10 @@ def load_spike_sorting_object(self, obj, *args, **kwargs):
895896
:param collection: string specifiying the collection, for example 'alf/probe01/pykilosort'
896897
:param kwargs: additional arguments to be passed to one.api.One.load_object
897898
:param missing: 'raise' (default) or 'ignore'
899+
:param revision: the dataset revision to load
898900
:return:
899901
"""
902+
revision = revision if revision is not None else self.revision
900903
self.download_spike_sorting_object(obj, *args, **kwargs)
901904
return self._load_object(self.files[obj])
902905

@@ -907,7 +910,7 @@ def get_version(self, spike_sorter=None):
907910
return dset[0]['version'] if len(dset) else 'unknown'
908911

909912
def download_spike_sorting_object(self, obj, spike_sorter=None, dataset_types=None, collection=None,
910-
attribute=None, missing='raise', **kwargs):
913+
attribute=None, missing='raise', revision=None, **kwargs):
911914
"""
912915
Downloads an ALF object
913916
:param obj: object name, str between 'spikes', 'clusters' or 'channels'
@@ -917,8 +920,10 @@ def download_spike_sorting_object(self, obj, spike_sorter=None, dataset_types=No
917920
:param kwargs: additional arguments to be passed to one.api.One.load_object
918921
:param attribute: list of attributes to load for the object
919922
:param missing: 'raise' (default) or 'ignore'
923+
:param revision: the dataset revision to load
920924
:return:
921925
"""
926+
revision = revision if revision is not None else self.revision
922927
if spike_sorter is None:
923928
spike_sorter = self.spike_sorter if self.spike_sorter is not None else 'iblsorter'
924929
if len(self.collections) == 0:
@@ -1170,12 +1175,13 @@ def url(self):
11701175
webclient = getattr(self.one, '_web_client', None)
11711176
return webclient.rel_path2url(get_alf_path(self.session_path)) if webclient else None
11721177

1173-
def _get_probe_info(self):
1178+
def _get_probe_info(self, revision=None):
1179+
revision = revision if revision is not None else self.revision
11741180
if self._sync is None:
11751181
timestamps = self.one.load_dataset(
1176-
self.eid, dataset='_spikeglx_*.timestamps.npy', collection=f'raw_ephys_data/{self.pname}')
1182+
self.eid, dataset='_spikeglx_*.timestamps.npy', collection=f'raw_ephys_data/{self.pname}', revision=revision)
11771183
_ = self.one.load_dataset( # this is not used here but we want to trigger the download for potential tasks
1178-
self.eid, dataset='_spikeglx_*.sync.npy', collection=f'raw_ephys_data/{self.pname}')
1184+
self.eid, dataset='_spikeglx_*.sync.npy', collection=f'raw_ephys_data/{self.pname}', revision=revision)
11791185
try:
11801186
ap_meta = spikeglx.read_meta_data(self.one.load_dataset(
11811187
self.eid, dataset='_spikeglx_*.ap.meta', collection=f'raw_ephys_data/{self.pname}'))

0 commit comments

Comments
 (0)