Skip to content

Commit d9dafbe

Browse files
committed
BUG: Provide missing position arg to local cross validation function
Provide missing position argument to local cross validation function in GP estimation error analysis script. Fixes ``` scripts/dwi_gp_estimation_error_analysis.py:210: error: Missing positional argument "gpr" in call to "cross_validate" [call-arg] scripts/dwi_gp_estimation_error_analysis.py:210: error: Unsupported operand types for * ("float" and "dict[int, list[tuple[ndarray[Any, Any], ndarray[Any, Any], ndarray[Any, Any], ndarray[Any, Any]]]]") [operator] scripts/dwi_gp_estimation_error_analysis.py:210: error: Argument 4 to "cross_validate" has incompatible type "DiffusionGPR"; expected "int" [arg-type] scripts/dwi_gp_estimation_error_analysis.py:211: error: "float" has no attribute "tolist" [attr-defined] scripts/dwi_gp_estimation_error_analysis.py:212: error: Argument 1 to "len" has incompatible type "float"; expected "Sized" [arg-type] scripts/dwi_gp_estimation_error_analysis.py:213: error: Argument 1 to "len" has incompatible type "float"; expected "Sized" [arg-type] scripts/dwi_gp_estimation_error_analysis.py:214: error: Argument 1 to "len" has incompatible type "float"; expected "Sized" [arg-type] scripts/dwi_gp_estimation_error_analysis.py:215: error: Argument 1 to "len" has incompatible type "float"; expected "Sized" [arg-type] ``` raised for example in: https://github.com/nipreps/nifreeze/actions/runs/12437972140/job/34728973936#step:8:113
1 parent 5c3e7d3 commit d9dafbe

File tree

1 file changed

+3
-1
lines changed

1 file changed

+3
-1
lines changed

scripts/dwi_gp_estimation_error_analysis.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -202,12 +202,14 @@ def main() -> None:
202202
# max_iter=2e5,
203203
)
204204

205+
n_repeats = 10
206+
205207
if args.kfold:
206208
# Use Scikit-learn cross validation
207209
scores = defaultdict(list, {})
208210
for n in args.kfold:
209211
for i in range(args.repeats):
210-
cv_scores = -1.0 * cross_validate(X, y.T, n, gpr)
212+
cv_scores = -1.0 * cross_validate(X, y.T, n, n_repeats, gpr)
211213
scores["rmse"] += cv_scores.tolist()
212214
scores["repeat"] += [i] * len(cv_scores)
213215
scores["n_folds"] += [n] * len(cv_scores)

0 commit comments

Comments
 (0)