Skip to content

Commit 0ea276d

Browse files
authored
test multioutput for metrics.mean_squared_log_error() (#825)
1 parent 090dd27 commit 0ea276d

File tree

1 file changed

+1
-12
lines changed

1 file changed

+1
-12
lines changed

tests/metrics/test_regression.py

Lines changed: 1 addition & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111

1212
_METRICS_TO_TEST = [
1313
"mean_squared_error",
14+
"mean_squared_log_error",
1415
"mean_absolute_error",
1516
"r2_score",
1617
]
@@ -64,18 +65,6 @@ def test_mse_squared(squared):
6465
assert abs(result - expected) < 1e-5
6566

6667

67-
def test_mean_squared_log_error():
68-
m1 = dask_ml.metrics.mean_squared_log_error
69-
m2 = sklearn.metrics.mean_squared_log_error
70-
71-
a = da.random.uniform(size=(100,), chunks=(25,))
72-
b = da.random.uniform(size=(100,), chunks=(25,))
73-
74-
result = m1(a, b)
75-
expected = m2(a, b)
76-
assert abs(result - expected) < 1e-5
77-
78-
7968
@pytest.mark.parametrize("multioutput", ["uniform_average", None])
8069
def test_regression_metrics_unweighted_average_multioutput(metric_pairs, multioutput):
8170
m1, m2 = metric_pairs

0 commit comments

Comments
 (0)