Skip to content

Commit 1c78b21

Browse files
author
Gregory Meltzer
authored
[RAPTOR-3968] Allow DRUM multiclass with only 2 labels (#237), release 1.4.5
* [RAPTOR-3968] Allow DRUM multiclass with only 2 labels * fix transform test * remove un-needed print * change to full release instead of RC
1 parent 2ce93f5 commit 1c78b21

32 files changed

+139
-55
lines changed

custom_model_runner/CHANGELOG.md

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,10 @@ All notable changes to this project will be documented in this file.
44
The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/),
55
and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html).
66

7+
#### [1.4.5] - 2020-12-02
8+
##### Changes
9+
- Allow multiclass to function with only 2 labels
10+
711
#### [1.4.4] - 2020-11-24
812
##### Added
913
- New `transform` target type for performing pre-/post- processing on features/targets

custom_model_runner/datarobot_drum/drum/args_parser.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -211,8 +211,8 @@ def are_both_labels_present(arg):
211211
@staticmethod
212212
def _reg_arg_multiclass_labels(*parsers):
213213
class RequiredLength(argparse.Action):
214-
ERROR_MESSAGE = "Multiclass classification requires at least 3 labels."
215-
MIN_LABELS = 3
214+
ERROR_MESSAGE = "Multiclass classification requires at least 2 labels."
215+
MIN_LABELS = 2
216216

217217
def __call__(self, parser, namespace, values, option_string=None):
218218
if len(values) < self.MIN_LABELS:

custom_model_runner/datarobot_drum/drum/artifact_predictors/keras_predictor.py

Lines changed: 12 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -73,12 +73,18 @@ def predict(self, data, model, **kwargs):
7373
if self.target_type.value in TargetType.CLASSIFICATION.value:
7474
if predictions.shape[1] == 1:
7575
if self.target_type == TargetType.MULTICLASS:
76-
raise DrumCommonException(
77-
"Target type '{}' predictions must return the "
78-
"probability distribution for all class labels".format(self.target_type)
79-
)
80-
predictions = pd.DataFrame(predictions, columns=[self.positive_class_label])
81-
predictions[self.negative_class_label] = 1 - predictions[self.positive_class_label]
76+
if len(self.class_labels) > 2:
77+
raise DrumCommonException(
78+
"Target type '{}' predictions must return the "
79+
"probability distribution for all class labels".format(self.target_type)
80+
)
81+
pos_label = self.class_labels[1]
82+
neg_label = self.class_labels[0]
83+
else:
84+
pos_label = self.positive_class_label
85+
neg_label = self.negative_class_label
86+
predictions = pd.DataFrame(predictions, columns=[pos_label])
87+
predictions[neg_label] = 1 - predictions[pos_label]
8288
else:
8389
predictions = pd.DataFrame(predictions, columns=self.class_labels)
8490
elif self.target_type in [TargetType.REGRESSION, TargetType.ANOMALY]:

custom_model_runner/datarobot_drum/drum/artifact_predictors/torch_predictor.py

Lines changed: 12 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -77,12 +77,18 @@ def predict(self, data, model, **kwargs):
7777
if self.target_type.value in TargetType.CLASSIFICATION.value:
7878
if predictions.shape[1] == 1:
7979
if self.target_type == TargetType.MULTICLASS:
80-
raise DrumCommonException(
81-
"Target type '{}' predictions must return the "
82-
"probability distribution for all class labels".format(self.target_type)
83-
)
84-
predictions = pd.DataFrame(predictions, columns=[self.positive_class_label])
85-
predictions[self.negative_class_label] = 1 - predictions[self.positive_class_label]
80+
if len(self.class_labels) > 2:
81+
raise DrumCommonException(
82+
"Target type '{}' predictions must return the "
83+
"probability distribution for all class labels".format(self.target_type)
84+
)
85+
pos_label = self.class_labels[1]
86+
neg_label = self.class_labels[0]
87+
else:
88+
pos_label = self.positive_class_label
89+
neg_label = self.negative_class_label
90+
predictions = pd.DataFrame(predictions, columns=[pos_label])
91+
predictions[neg_label] = 1 - predictions[pos_label]
8692
else:
8793
predictions = pd.DataFrame(predictions, columns=self.class_labels)
8894
elif self.target_type in [TargetType.REGRESSION, TargetType.ANOMALY]:

custom_model_runner/datarobot_drum/drum/artifact_predictors/xgboost_predictor.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -83,7 +83,7 @@ def predict(self, data, model, **kwargs):
8383
if self.target_type.value in TargetType.CLASSIFICATION.value:
8484
if xgboost_native:
8585
predictions = model.predict(data)
86-
if self.target_type == TargetType.BINARY:
86+
if self.target_type == TargetType.BINARY or len(self.class_labels) == 2:
8787
negative_preds = 1 - predictions
8888
predictions = np.concatenate(
8989
(negative_preds.reshape(-1, 1), predictions.reshape(-1, 1)), axis=1
Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,3 @@
1-
version = "1.4.4"
1+
version = "1.4.5"
22
__version__ = version
33
project_name = "datarobot-drum"

custom_model_runner/datarobot_drum/drum/drum.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -455,7 +455,7 @@ def _prepare_fit_pipeline(self, run_language):
455455
self.options.negative_class_label,
456456
) = possible_class_labels
457457
elif self.target_type == TargetType.MULTICLASS:
458-
if len(possible_class_labels) <= 2:
458+
if len(possible_class_labels) < 2:
459459
raise DrumCommonException(
460460
"Target type {} requires more than 2 class labels. Detected {}: {}".format(
461461
TargetType.MULTICLASS,
Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,2 +1,2 @@
11
pyarrow==0.14.1
2-
datarobot-drum==1.4.4
2+
datarobot-drum==1.4.5

public_dropin_environments/java_codegen/env_info.json

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,6 @@
33
"name": "[DataRobot] Java Drop-In (DR Codegen, H2O)",
44
"description": "This template can be used as an environment for DataRobot generated scoring code or models that implement the either the IClassificationPredictor or IRegressionPredictor interface from the datarobot-prediction package and for H2O models exported as POJO or MOJO.",
55
"programmingLanguage": "java",
6-
"environmentVersionId": "5fc33e141d41c812de88c04d",
6+
"environmentVersionId": "5fc7f2e5e9790c6fd6032869",
77
"isPublic": true
88
}
Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,2 +1,2 @@
11
pyarrow==0.14.1
2-
datarobot-drum==1.4.4
2+
datarobot-drum==1.4.5

public_dropin_environments/python3_keras/env_info.json

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,6 @@
33
"name": "[DataRobot] Python 3 Keras Drop-In",
44
"description": "This template environment can be used to create artifact-only keras custom models. This environment contains keras backed by tensorflow and only requires your model artifact as a .h5 file and optionally a custom.py file.",
55
"programmingLanguage": "python",
6-
"environmentVersionId": "5fc33e141d41c812de88c04e",
6+
"environmentVersionId": "5fc7f2e5e9790c6fd6032868",
77
"isPublic": true
88
}
Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,2 +1,2 @@
11
pyarrow==0.14.1
2-
datarobot-drum==1.4.4
2+
datarobot-drum==1.4.5

public_dropin_environments/python3_pmml/env_info.json

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,6 @@
33
"name": "[DataRobot] Python 3 PMML Drop-In",
44
"description": "This template environment can be used to create artifact-only PMML custom models. This environment contains PyPMML and only requires your model artifact as a .pmml file and optionally a custom.py file.",
55
"programmingLanguage": "python",
6-
"environmentVersionId": "5fc33e141d41c812de88c04f",
6+
"environmentVersionId": "5fc7f2e5e9790c6fd603286a",
77
"isPublic": true
88
}
Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,2 +1,2 @@
11
pyarrow==0.14.1
2-
datarobot-drum==1.4.4
2+
datarobot-drum==1.4.5

public_dropin_environments/python3_pytorch/env_info.json

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,6 @@
33
"name": "[DataRobot] Python 3 PyTorch Drop-In",
44
"description": "This template environment can be used to create artifact-only PyTorch custom models. This environment contains PyTorch and requires only your model artifact as a .pth file, any other code needed to deserialize your model, and optionally a custom.py file.",
55
"programmingLanguage": "python",
6-
"environmentVersionId": "5fc33e141d41c812de88c04a",
6+
"environmentVersionId": "5fc7f2e5e9790c6fd6032864",
77
"isPublic": true
88
}
Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,2 +1,2 @@
11
pyarrow==0.14.1
2-
datarobot-drum==1.4.4
2+
datarobot-drum==1.4.5

public_dropin_environments/python3_sklearn/env_info.json

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,6 @@
33
"name": "[DataRobot] Python 3 Scikit-Learn Drop-In",
44
"description": "This template environment can be used to create artifact-only scikit-learn custom models. This environment contains scikit-learn and only requires your model artifact as a .pkl file and optionally a custom.py file.",
55
"programmingLanguage": "python",
6-
"environmentVersionId": "5fc33e141d41c812de88c04b",
6+
"environmentVersionId": "5fc7f2e5e9790c6fd6032867",
77
"isPublic": true
88
}
Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,2 +1,2 @@
11
pyarrow==0.14.1
2-
datarobot-drum==1.4.4
2+
datarobot-drum==1.4.5

public_dropin_environments/python3_xgboost/env_info.json

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,6 @@
33
"name": "[DataRobot] Python 3 XGBoost Drop-In",
44
"description": "This template environment can be used to create artifact-only xgboost custom models. This environment contains xgboost and only requires your model artifact as a .pkl file and optionally a custom.py file.",
55
"programmingLanguage": "python",
6-
"environmentVersionId": "5fc33e141d41c812de88c04c",
6+
"environmentVersionId": "5fc7f2e5e9790c6fd6032865",
77
"isPublic": true
88
}

public_dropin_environments/r_lang/dr_requirements.txt

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,4 +2,4 @@ numpy>=1.16.0,<1.19.0
22
pandas==1.1.0
33
rpy2<=3.3.6
44
pyarrow==0.14.1
5-
datarobot-drum[R]==1.4.4
5+
datarobot-drum[R]==1.4.5

public_dropin_environments/r_lang/env_info.json

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,6 @@
33
"name": "[DataRobot] R Drop-In",
44
"description": "This template environment can be used to create artifact-only R custom models that use the caret library. Your custom model archive need only contain your model artifacts if you use the environment correctly.",
55
"programmingLanguage": "r",
6-
"environmentVersionId": "5fc33e141d41c812de88c049",
6+
"environmentVersionId": "5fc7f2e5e9790c6fd6032866",
77
"isPublic": true
88
}

tests/conftest.py

Lines changed: 31 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515
MULTI_ARTIFACT,
1616
MULTICLASS,
1717
MULTICLASS_NUM_LABELS,
18+
MULTICLASS_BINARY,
1819
NO_CUSTOM,
1920
POJO,
2021
PYPMML,
@@ -72,6 +73,7 @@
7273
(None, MULTICLASS_NUM_LABELS): os.path.join(
7374
TESTS_DATA_PATH, "skyserver_sql2_27_2018_6_51_39_pm_num_class.csv"
7475
),
76+
(None, MULTICLASS_BINARY): os.path.join(TESTS_DATA_PATH, "iris_binary_training.csv"),
7577
(None, SPARSE): os.path.join(TESTS_DATA_PATH, "sparse.mtx"),
7678
(None, SPARSE_TARGET): os.path.join(TESTS_DATA_PATH, "sparse_target.csv"),
7779
(None, BINARY_BOOL): os.path.join(TESTS_DATA_PATH, "10k_diabetes_sample.csv"),
@@ -106,6 +108,7 @@
106108
REGRESSION: "MEDV",
107109
BINARY_TEXT: "Churn",
108110
MULTICLASS: "class",
111+
MULTICLASS_BINARY: "Species",
109112
MULTICLASS_NUM_LABELS: "class",
110113
SPARSE: "my_target",
111114
BINARY_BOOL: "readmitted",
@@ -120,6 +123,8 @@
120123
ANOMALY: "anomaly",
121124
UNSTRUCTURED: "unstructured",
122125
MULTICLASS: "multiclass",
126+
MULTICLASS_BINARY: "multiclass",
127+
MULTICLASS_NUM_LABELS: "multiclass",
123128
BINARY_BOOL: "binary",
124129
TRANSFORM: "transform",
125130
}
@@ -150,6 +155,16 @@
150155
(KERAS, BINARY_TEXT): ["False", "True"],
151156
(POJO, MULTICLASS): ["GALAXY", "QSO", "STAR"],
152157
(MOJO, MULTICLASS): ["GALAXY", "QSO", "STAR"],
158+
(SKLEARN_BINARY, MULTICLASS_BINARY): ["Iris-setosa", "Iris-versicolor"],
159+
(SKLEARN, MULTICLASS_BINARY): ["Iris-setosa", "Iris-versicolor"],
160+
(XGB, MULTICLASS_BINARY): ["Iris-setosa", "Iris-versicolor"],
161+
(KERAS, MULTICLASS_BINARY): ["Iris-setosa", "Iris-versicolor"],
162+
(RDS, MULTICLASS_BINARY): ["Iris-setosa", "Iris-versicolor"],
163+
(PYPMML, MULTICLASS_BINARY): ["Iris-setosa", "Iris-versicolor"],
164+
(PYTORCH, MULTICLASS_BINARY): ["Iris-setosa", "Iris-versicolor"],
165+
(CODEGEN, MULTICLASS_BINARY): ["yes", "no"],
166+
(MOJO, MULTICLASS_BINARY): ["yes", "no"],
167+
(POJO, MULTICLASS_BINARY): ["yes", "no"],
153168
}
154169

155170
_artifacts = {
@@ -215,6 +230,21 @@
215230
(SKLEARN_TRANSFORM_DENSE, TRANSFORM): os.path.join(
216231
TESTS_ARTIFACTS_PATH, "sklearn_transform_dense.pkl"
217232
),
233+
(SKLEARN, MULTICLASS_BINARY): os.path.join(TESTS_ARTIFACTS_PATH, "sklearn_bin.pkl"),
234+
(KERAS, MULTICLASS_BINARY): os.path.join(TESTS_ARTIFACTS_PATH, "keras_bin.h5"),
235+
(XGB, MULTICLASS_BINARY): os.path.join(TESTS_ARTIFACTS_PATH, "xgb_bin.pkl"),
236+
(PYTORCH, MULTICLASS_BINARY): [
237+
os.path.join(TESTS_ARTIFACTS_PATH, "torch_bin.pth"),
238+
os.path.join(TESTS_ARTIFACTS_PATH, "PyTorch.py"),
239+
],
240+
(RDS, MULTICLASS_BINARY): os.path.join(TESTS_ARTIFACTS_PATH, "r_bin.rds"),
241+
(CODEGEN, MULTICLASS_BINARY): os.path.join(TESTS_ARTIFACTS_PATH, "java_bin.jar"),
242+
(POJO, MULTICLASS_BINARY): os.path.join(
243+
TESTS_ARTIFACTS_PATH,
244+
"XGBoost_grid__1_AutoML_20200717_163214_model_159.java",
245+
),
246+
(MOJO, MULTICLASS_BINARY): os.path.join(TESTS_ARTIFACTS_PATH, "mojo_bin.zip"),
247+
(PYPMML, MULTICLASS_BINARY): os.path.join(TESTS_ARTIFACTS_PATH, "iris_bin.pmml"),
218248
}
219249

220250
_custom_filepaths = {
@@ -306,7 +336,7 @@ def _foo(problem):
306336
@pytest.fixture(scope="session")
307337
def get_target_type():
308338
def _foo(problem):
309-
return _target_types[problem]
339+
return _target_types.get(problem, problem)
310340

311341
return _foo
312342

tests/drum/constants.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -46,6 +46,7 @@
4646
MULTICLASS = "multiclass"
4747
TRANSFORM = "transform"
4848
MULTICLASS_NUM_LABELS = "multiclass_num_labels"
49+
MULTICLASS_BINARY = "multiclass_binary" # special case for testing multiclass with only 2 classes
4950
SPARSE = "sparse"
5051
SPARSE_TARGET = "sparse_target"
5152

tests/drum/drum_server_utils.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -78,7 +78,7 @@ def __init__(
7878
ArgumentsOptions.MAIN_COMMAND, custom_model_dir, target_type, self.server_address
7979
)
8080
if labels:
81-
cmd = _cmd_add_class_labels(cmd, labels)
81+
cmd = _cmd_add_class_labels(cmd, labels, target_type=target_type)
8282
if docker:
8383
cmd += " --docker {}".format(docker)
8484
if memory:

tests/drum/test_args_parser.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -45,7 +45,7 @@ def test_valid_class_labels(self, request, valid_labels, as_file, parser):
4545

4646
@pytest.mark.parametrize("as_file", [True, False])
4747
def test_too_few_labels(self, as_file, parser):
48-
labels = list("AB")
48+
labels = list("A")
4949
with NamedTemporaryFile() as f:
5050
if as_file:
5151
f.write("\n".join(labels).encode("utf-8"))
@@ -54,5 +54,5 @@ def test_too_few_labels(self, as_file, parser):
5454
else:
5555
args = ["dummy", "--class-labels", *labels]
5656

57-
with pytest.raises(argparse.ArgumentTypeError, match="at least 3"):
57+
with pytest.raises(argparse.ArgumentTypeError, match="at least 2"):
5858
parser.parse_args(args)

tests/drum/test_fit.py

Lines changed: 16 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -17,14 +17,14 @@
1717
KERAS,
1818
MULTICLASS,
1919
MULTICLASS_NUM_LABELS,
20+
MULTICLASS_BINARY,
2021
PYTHON,
2122
PYTORCH,
2223
PYTORCH_MULTICLASS,
2324
R_FIT,
2425
RDS,
2526
REGRESSION,
2627
BINARY_TEXT,
27-
TRANSFORM,
2828
SIMPLE,
2929
SKLEARN,
3030
SKLEARN_BINARY,
@@ -37,6 +37,7 @@
3737
WEIGHTS_CSV,
3838
XGB,
3939
BINARY_BOOL,
40+
TRANSFORM,
4041
)
4142
from .utils import _cmd_add_class_labels, _create_custom_model_dir, _exec_shell_cmd
4243

@@ -106,7 +107,9 @@ def test_fit_for_use_output_and_nested(
106107
if use_output:
107108
cmd += " --output {}".format(output)
108109
if problem == BINARY:
109-
cmd = _cmd_add_class_labels(cmd, resources.class_labels(framework, problem))
110+
cmd = _cmd_add_class_labels(
111+
cmd, resources.class_labels(framework, problem), target_type=problem
112+
)
110113
if docker:
111114
cmd += " --docker {} ".format(docker)
112115

@@ -123,6 +126,7 @@ def test_fit_for_use_output_and_nested(
123126
(RDS, BINARY_TEXT, None),
124127
(RDS, REGRESSION, None),
125128
(RDS, MULTICLASS, None),
129+
(RDS, MULTICLASS_BINARY, None),
126130
(SKLEARN_BINARY, BINARY_TEXT, DOCKER_PYTHON_SKLEARN),
127131
(SKLEARN_REGRESSION, REGRESSION, DOCKER_PYTHON_SKLEARN),
128132
(SKLEARN_ANOMALY, ANOMALY, DOCKER_PYTHON_SKLEARN),
@@ -131,18 +135,22 @@ def test_fit_for_use_output_and_nested(
131135
(SKLEARN_REGRESSION, REGRESSION, None),
132136
(SKLEARN_ANOMALY, ANOMALY, None),
133137
(SKLEARN_MULTICLASS, MULTICLASS, None),
138+
(SKLEARN_MULTICLASS, MULTICLASS_BINARY, None),
134139
(SKLEARN_MULTICLASS, MULTICLASS_NUM_LABELS, None),
135140
(SKLEARN_TRANSFORM, REGRESSION, None),
136141
(SKLEARN_TRANSFORM, BINARY, None),
137142
(XGB, BINARY_TEXT, None),
138143
(XGB, REGRESSION, None),
139144
(XGB, MULTICLASS, None),
145+
(XGB, MULTICLASS_BINARY, None),
140146
(KERAS, BINARY_TEXT, None),
141147
(KERAS, REGRESSION, None),
142148
(KERAS, MULTICLASS, None),
149+
(KERAS, MULTICLASS_BINARY, None),
143150
(PYTORCH, BINARY_TEXT, None),
144151
(PYTORCH, REGRESSION, None),
145152
(PYTORCH_MULTICLASS, MULTICLASS, None),
153+
(PYTORCH_MULTICLASS, MULTICLASS_BINARY, None),
146154
],
147155
)
148156
@pytest.mark.parametrize("weights", [WEIGHTS_CSV, WEIGHTS_ARGS, None])
@@ -175,14 +183,9 @@ def test_fit(
175183
weights, input_dataset, r_fit=language == R_FIT
176184
)
177185

178-
if problem in [BINARY_TEXT, BINARY_BOOL]:
179-
target_type = BINARY
180-
elif problem == MULTICLASS_NUM_LABELS:
181-
target_type = MULTICLASS
182-
elif framework == SKLEARN_TRANSFORM:
183-
target_type = TRANSFORM
184-
else:
185-
target_type = problem
186+
target_type = (
187+
resources.target_types(problem) if framework != SKLEARN_TRANSFORM else TRANSFORM
188+
)
186189

187190
cmd = "{} fit --target-type {} --code-dir {} --input {} --verbose ".format(
188191
ArgumentsOptions.MAIN_COMMAND, target_type, custom_model_dir, input_dataset
@@ -193,7 +196,9 @@ def test_fit(
193196
cmd += " --unsupervised"
194197

195198
if problem in [BINARY, MULTICLASS]:
196-
cmd = _cmd_add_class_labels(cmd, resources.class_labels(framework, problem))
199+
cmd = _cmd_add_class_labels(
200+
cmd, resources.class_labels(framework, problem), target_type=target_type
201+
)
197202
if docker:
198203
cmd += " --docker {} ".format(docker)
199204

tests/drum/test_fit_variety.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -47,7 +47,7 @@ def test_fit_variety(framework, variety_resources, resources, variety_data_names
4747
cmd += " --target {}".format(target)
4848

4949
if problem == BINARY:
50-
cmd = _cmd_add_class_labels(cmd, class_labels)
50+
cmd = _cmd_add_class_labels(cmd, class_labels, target_type=problem)
5151

5252
p, _, err = _exec_shell_cmd(
5353
cmd,

0 commit comments

Comments
 (0)