From 97d9edf3c0311fadeca72b05965a558b176dc24b Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Jon=20Haitz=20Legarreta=20Gorro=C3=B1o?= Date: Sat, 21 Dec 2024 20:27:06 -0500 Subject: [PATCH 1/4] ENH: Fix return value type annotation in GP error analysis script func Fix return value type annotation in GP error analysis script function. Fixes: ``` scripts/dwi_gp_estimation_error_analysis.py:78: error: Incompatible return value type (got "ndarray[Any, Any]", expected "dict[int, list[tuple[ndarray[Any, Any], ndarray[Any, Any], ndarray[Any, Any], ndarray[Any, Any]]]]") [return-value] ``` raised for example in: https://github.com/nipreps/nifreeze/actions/runs/12437972140/job/34728973936#step:8:111 --- scripts/dwi_gp_estimation_error_analysis.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/scripts/dwi_gp_estimation_error_analysis.py b/scripts/dwi_gp_estimation_error_analysis.py index eb393b97..2150358b 100644 --- a/scripts/dwi_gp_estimation_error_analysis.py +++ b/scripts/dwi_gp_estimation_error_analysis.py @@ -49,7 +49,7 @@ def cross_validate( cv: int, n_repeats: int, gpr: DiffusionGPR, -) -> dict[int, list[tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray]]]: +) -> np.ndarray: """ Perform the experiment by estimating the dMRI signal using a Gaussian process model. @@ -68,7 +68,7 @@ def cross_validate( Returns ------- - :obj:`dict` + :obj:`~numpy.ndarray` Data for the predicted signal and its error. """ From 5c3e7d32b0602be8c10ffd3b5fab2d62f37fc960 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Jon=20Haitz=20Legarreta=20Gorro=C3=B1o?= Date: Sat, 21 Dec 2024 20:30:12 -0500 Subject: [PATCH 2/4] ENH: Use a boolean to tell whether to serialize the pandas df index Use a boolean to tell whether to serialize the pandas dataframe index. Fixes: ``` scripts/dwi_gp_estimation_error_analysis.py:220: error: No overload variant of "to_csv" of "NDFrame" matches argument types "Any", "str", "None", "str" [call-overload] scripts/dwi_gp_estimation_error_analysis.py:220: note: Possible overload variants: scripts/dwi_gp_estimation_error_analysis.py:220: note: def (...) ``` raised for example in: https://github.com/nipreps/nifreeze/actions/runs/12437972140/job/34728973936#step:8:121 --- scripts/dwi_gp_estimation_error_analysis.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/scripts/dwi_gp_estimation_error_analysis.py b/scripts/dwi_gp_estimation_error_analysis.py index 2150358b..a4b33d51 100644 --- a/scripts/dwi_gp_estimation_error_analysis.py +++ b/scripts/dwi_gp_estimation_error_analysis.py @@ -217,7 +217,7 @@ def main() -> None: print(f"Finished {n}-fold cross-validation") scores_df = pd.DataFrame(scores) - scores_df.to_csv(args.output_scores, sep="\t", index=None, na_rep="n/a") + scores_df.to_csv(args.output_scores, sep="\t", index=False, na_rep="n/a") grouped = scores_df.groupby(["n_folds"]) print(grouped[["rmse"]].mean()) From d9dafbe197530acff4aae9ae0cf98a3bc6232ac1 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Jon=20Haitz=20Legarreta=20Gorro=C3=B1o?= Date: Sat, 21 Dec 2024 20:36:34 -0500 Subject: [PATCH 3/4] BUG: Provide missing position arg to local cross validation function Provide missing position argument to local cross validation function in GP estimation error analysis script. Fixes ``` scripts/dwi_gp_estimation_error_analysis.py:210: error: Missing positional argument "gpr" in call to "cross_validate" [call-arg] scripts/dwi_gp_estimation_error_analysis.py:210: error: Unsupported operand types for * ("float" and "dict[int, list[tuple[ndarray[Any, Any], ndarray[Any, Any], ndarray[Any, Any], ndarray[Any, Any]]]]") [operator] scripts/dwi_gp_estimation_error_analysis.py:210: error: Argument 4 to "cross_validate" has incompatible type "DiffusionGPR"; expected "int" [arg-type] scripts/dwi_gp_estimation_error_analysis.py:211: error: "float" has no attribute "tolist" [attr-defined] scripts/dwi_gp_estimation_error_analysis.py:212: error: Argument 1 to "len" has incompatible type "float"; expected "Sized" [arg-type] scripts/dwi_gp_estimation_error_analysis.py:213: error: Argument 1 to "len" has incompatible type "float"; expected "Sized" [arg-type] scripts/dwi_gp_estimation_error_analysis.py:214: error: Argument 1 to "len" has incompatible type "float"; expected "Sized" [arg-type] scripts/dwi_gp_estimation_error_analysis.py:215: error: Argument 1 to "len" has incompatible type "float"; expected "Sized" [arg-type] ``` raised for example in: https://github.com/nipreps/nifreeze/actions/runs/12437972140/job/34728973936#step:8:113 --- scripts/dwi_gp_estimation_error_analysis.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/scripts/dwi_gp_estimation_error_analysis.py b/scripts/dwi_gp_estimation_error_analysis.py index a4b33d51..936931a5 100644 --- a/scripts/dwi_gp_estimation_error_analysis.py +++ b/scripts/dwi_gp_estimation_error_analysis.py @@ -202,12 +202,14 @@ def main() -> None: # max_iter=2e5, ) + n_repeats = 10 + if args.kfold: # Use Scikit-learn cross validation scores = defaultdict(list, {}) for n in args.kfold: for i in range(args.repeats): - cv_scores = -1.0 * cross_validate(X, y.T, n, gpr) + cv_scores = -1.0 * cross_validate(X, y.T, n, n_repeats, gpr) scores["rmse"] += cv_scores.tolist() scores["repeat"] += [i] * len(cv_scores) scores["n_folds"] += [n] * len(cv_scores) From 6e8311fe920f05cb07409896a23d8f9495a126ed Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Jon=20Haitz=20Legarreta=20Gorro=C3=B1o?= Date: Sat, 21 Dec 2024 20:48:15 -0500 Subject: [PATCH 4/4] ENH: Add type annotation for variable in GP estimation error analysis Add type annotation for local variable `scores` in GP estimation error analysis script. Fixes: ``` scripts/dwi_gp_estimation_error_analysis.py:207: error: Need type annotation for "scores" [var-annotated] ``` raised for example in: https://github.com/nipreps/nifreeze/actions/runs/12437972140/job/34728973936#step:8:112 --- scripts/dwi_gp_estimation_error_analysis.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/scripts/dwi_gp_estimation_error_analysis.py b/scripts/dwi_gp_estimation_error_analysis.py index 936931a5..a0616493 100644 --- a/scripts/dwi_gp_estimation_error_analysis.py +++ b/scripts/dwi_gp_estimation_error_analysis.py @@ -31,6 +31,7 @@ import argparse from collections import defaultdict from pathlib import Path +from typing import DefaultDict, List import numpy as np import pandas as pd @@ -206,7 +207,7 @@ def main() -> None: if args.kfold: # Use Scikit-learn cross validation - scores = defaultdict(list, {}) + scores: DefaultDict[str, List[float | str]] = defaultdict(list) for n in args.kfold: for i in range(args.repeats): cv_scores = -1.0 * cross_validate(X, y.T, n, n_repeats, gpr)