1
- import os
2
1
import argparse
2
+ import random
3
+ import string
3
4
from typing import List , Optional , Dict , Any
4
- from src .feature_computation import compute_features
5
- from src .region_inference import infer_regions
6
- from one .api import ONE
5
+ from pathlib import Path
6
+
7
7
import numpy as np
8
8
import yaml
9
- from pathlib import Path
10
9
import pandas as pd
11
- from src .logger_config import setup_logger
12
- from src .plots import plot_results
13
- from src import decoding
14
- import random
15
- import string
16
10
17
- # Set up logger
18
- logger = setup_logger (__name__ )
11
+ from iblutil .util import setup_logger
12
+ from one .api import ONE
13
+
14
+ from ephysatlas .feature_computation import compute_features
15
+ from ephysatlas .region_inference import infer_regions
16
+ from ephysatlas .plots import plot_results
17
+ from ephysatlas import decoding
18
+
19
19
20
20
def load_config(config_path: str) -> Dict[str, Any]:
    """Load the run configuration from a YAML file.

    Args:
        config_path: Path to the YAML configuration file.

    Returns:
        The parsed top-level mapping. An empty file yields an empty dict
        (``yaml.safe_load`` returns ``None`` for it) so callers can test
        for keys without a ``None`` check.

    Raises:
        OSError: If the file cannot be opened.
        yaml.YAMLError: If the content is not valid YAML.
        ValueError: If the top-level YAML node is not a mapping.
    """
    # Explicit encoding: do not depend on the platform's default codec.
    with open(config_path, "r", encoding="utf-8") as f:
        config = yaml.safe_load(f)

    if config is None:
        return {}
    if not isinstance(config, dict):
        raise ValueError(
            f"Configuration file {config_path} must contain a YAML mapping, "
            f"got {type(config).__name__}"
        )
    return config
25
24
26
25
27
26
def parse_arguments(args: List[str]) -> argparse.Namespace:
    """Parse command line arguments.

    Args:
        args: Argument strings, without the program name
            (e.g. ``sys.argv[1:]``).

    Returns:
        Namespace with a single attribute, ``config``: the path to the
        YAML configuration file.

    Raises:
        SystemExit: Raised by argparse when ``--config`` is missing or an
            unknown option is given.
    """
    parser = argparse.ArgumentParser(
        description="Electrophysiology feature computation and region inference"
    )
    parser.add_argument(
        "--config", required=True, help="Path to YAML configuration file"
    )
    return parser.parse_args(args)
33
35
34
36
35
37
def get_parameters(args: argparse.Namespace) -> Dict[str, Any]:
    """Read the config file named by ``args.config`` and return run parameters.

    The presence of a ``pid`` key selects PID-based configuration, which
    requires ``pid``, ``t_start`` and ``duration``. Otherwise the
    configuration is file-based and requires ``ap_file`` and ``lf_file``.

    Args:
        args: Parsed command line arguments; only ``args.config`` is read.

    Returns:
        Dict with the keys ``pid``, ``ap_file``, ``lf_file``, ``t_start``,
        ``duration``, ``mode``, ``features_path``, ``model_path``,
        ``traj_dict`` and ``log_path``. Keys absent from the config are
        ``None``, except ``t_start`` (0.0) and ``mode`` ("both").

    Raises:
        ValueError: If a parameter required by the selected configuration
            style is missing from the config file.
    """
    # An empty YAML file parses to None; treat it as an empty mapping so the
    # missing-parameter error below is raised instead of a TypeError.
    config = load_config(args.config) or {}

    if "pid" in config:
        # PID-based configuration
        required_params = ["pid", "t_start", "duration"]
    else:
        # File-based configuration
        required_params = ["ap_file", "lf_file"]

    missing_params = [param for param in required_params if param not in config]
    if missing_params:
        raise ValueError(
            f"Missing required parameters in config file: {', '.join(missing_params)}"
        )

    return {
        "pid": config.get("pid"),
        "ap_file": config.get("ap_file"),
        "lf_file": config.get("lf_file"),
        "t_start": config.get("t_start", 0.0),  # Default to 0.0 if not specified
        "duration": config.get("duration"),  # Default to None if not specified
        "mode": config.get("mode", "both"),
        "features_path": config.get("features_path"),
        "model_path": config.get("model_path"),
        "traj_dict": config.get("traj_dict"),
        "log_path": config.get("log_path"),  # Get log path from config
    }
65
71
66
72
67
73
def _resolve_features_path(params: Dict[str, Any], logger) -> Path:
    """Return the parquet path used to store or load features for this run."""
    features_path = params.get("features_path")
    if features_path is None:
        if params["pid"] is not None:
            features_path = Path(f"features_{params['pid']}.parquet")
        else:
            # Generate 8 character alphanumeric filename
            filename = "".join(
                random.choices(string.ascii_letters + string.digits, k=8)
            )
            features_path = Path(f"features_{filename}.parquet")
        logger.info(f"Generated features filename: {features_path}")
    else:
        features_path = Path(features_path)

    # Ensure the file has .parquet extension
    if features_path.suffix != ".parquet":
        features_path = features_path.with_suffix(".parquet")
    return features_path


def _resolve_model_path(params: Dict[str, Any], logger) -> Path:
    """Return the model directory from the config, or the legacy fallback."""
    model_path = params.get("model_path")
    if model_path is not None:
        return Path(model_path)

    # HACK: developer-machine path kept only for backward compatibility;
    # set `model_path` in the config file instead of relying on it.
    fallback = Path(
        "/Users/pranavrai/Downloads/models/2024_W50_Cosmos_voter-snap-pudding/"
    )
    logger.warning(
        f"No model_path in config; falling back to hard-coded path {fallback}"
    )
    return fallback


def main(args: Optional[List[str]] = None) -> int:
    """Run feature computation and/or region inference from a config file.

    Args:
        args: Command line arguments without the program name. When None,
            ``sys.argv[1:]`` is used.

    Returns:
        0 on success.

    Raises:
        ValueError: If the config is missing required parameters, if
            ``mode`` is not one of "features", "inference" or "both", or
            if inference is requested but the features file does not exist.
    """
    if args is None:
        import sys

        args = sys.argv[1:]

    # Parse arguments
    parsed_args = parse_arguments(args)

    # Get parameters from config file
    params = get_parameters(parsed_args)

    # The logger is created only once the config is read, because the
    # config carries the optional log file path.
    logger = setup_logger(__name__, file=params.get("log_path"))
    logger.info("Starting main function")

    # Reject unknown modes explicitly; otherwise both stages would be
    # skipped and the run would report success without doing anything.
    mode = params["mode"]
    if mode not in ("features", "inference", "both"):
        raise ValueError(
            f"Invalid mode {mode!r}; expected 'features', 'inference' or 'both'."
        )

    # NOTE(review): ONE is instantiated for file-based runs too, as in the
    # original code — confirm whether compute_features needs it there.
    one = ONE()
    if params["pid"] is not None:
        logger.info("ONE client initialized")
        logger.info(f"Processing probe ID: {params['pid']}")
    else:
        logger.info(f"Processing files: AP={params['ap_file']}, LF={params['lf_file']}")

    df_features = None
    # Determine features file path
    features_path = _resolve_features_path(params, logger)

    # Compute features if mode is 'features' or 'both'
    if mode in ("features", "both"):
        logger.info("Starting feature computation")
        df_features = compute_features(
            pid=params.get("pid"),
            t_start=params["t_start"],
            duration=params["duration"],
            one=one,
            ap_file=params.get("ap_file"),
            lf_file=params.get("lf_file"),
            traj_dict=params.get("traj_dict"),
        )
        logger.info(f"Feature computation completed. Shape: {df_features.shape}")

        # Save features to parquet file
        logger.info(f"Saving features to {features_path}")
        df_features.to_parquet(features_path, index=True)

    # Infer regions if mode is 'inference' or 'both'
    if mode in ("inference", "both"):
        logger.info("Starting region inference")
        # Get model path from parameters or use default
        model_path = _resolve_model_path(params, logger)

        # Features are only missing here in 'inference' mode: load them
        # from the file written by an earlier 'features' run.
        if df_features is None:
            if not features_path.exists():
                raise ValueError(
                    f"Features file not found at {features_path}. Please compute features first."
                )
            logger.info(f"Loading features from {features_path}")
            df_features = pd.read_parquet(features_path)

        predicted_probas, predicted_regions = infer_regions(df_features, model_path)
        logger.info(f"Predicted regions shape: {predicted_regions.shape}")
        logger.info(f"Prediction probabilities shape: {predicted_probas.shape}")

        # NOTE(review): the diff this block was reconstructed from elides two
        # original lines at this point (hunk boundary) — confirm nothing
        # functional is missing.
        output_dir = features_path.parent
        # File-based runs have no pid; name outputs after the features file
        # instead of writing "probas_None.npy" and overwriting earlier runs.
        run_id = params["pid"] if params["pid"] is not None else features_path.stem
        np_probas_path = output_dir / f"probas_{run_id}.npy"
        np_regions_path = output_dir / f"regions_{run_id}.npy"

        logger.info(
            f"Saving prediction probabilities as numpy array to {np_probas_path}"
        )
        np.save(np_probas_path, predicted_probas)

        logger.info(f"Saving predicted regions as numpy array to {np_regions_path}")
        np.save(np_regions_path, predicted_regions)

        # Plot the results
        # TODO: need a better interface than loading dict_model here just for plotting.
        dict_model = decoding.load_model(model_path.joinpath("FOLD04"))
        fig, ax = plot_results(df_features, predicted_probas, dict_model)
        import matplotlib.pyplot as plt

        plt.savefig(output_dir / f"results_{run_id}.png")

    return 0
167
184
168
185
169
186
# Script entry point: run main() and use its return value as the process
# exit status.
if __name__ == "__main__":
    import sys

    sys.exit(main())
0 commit comments