Skip to content

Commit 1c78b21

Browse files
author
Gregory Meltzer
authored
[RAPTOR-3968] Allow DRUM multiclass with only 2 labels (#237), release 1.4.5
* [RAPTOR-3968] Allow DRUM multiclass with only 2 labels * fix transform test * remove un-needed print * change to full release instead of RC
1 parent 2ce93f5 commit 1c78b21

32 files changed

+139
-55
lines changed

custom_model_runner/CHANGELOG.md

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,10 @@ All notable changes to this project will be documented in this file.
44
The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/),
55
and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html).
66

7+
#### [1.4.5] - 2020-12-02
8+
##### Changes
9+
- Allow multiclass to function with only 2 labels
10+
711
#### [1.4.4] - 2020-11-24
812
##### Added
913
- New `transform` target type for performing pre-/post- processing on features/targets

custom_model_runner/datarobot_drum/drum/args_parser.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -211,8 +211,8 @@ def are_both_labels_present(arg):
211211
@staticmethod
212212
def _reg_arg_multiclass_labels(*parsers):
213213
class RequiredLength(argparse.Action):
214-
ERROR_MESSAGE = "Multiclass classification requires at least 3 labels."
215-
MIN_LABELS = 3
214+
ERROR_MESSAGE = "Multiclass classification requires at least 2 labels."
215+
MIN_LABELS = 2
216216

217217
def __call__(self, parser, namespace, values, option_string=None):
218218
if len(values) < self.MIN_LABELS:

custom_model_runner/datarobot_drum/drum/artifact_predictors/keras_predictor.py

Lines changed: 12 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -73,12 +73,18 @@ def predict(self, data, model, **kwargs):
7373
if self.target_type.value in TargetType.CLASSIFICATION.value:
7474
if predictions.shape[1] == 1:
7575
if self.target_type == TargetType.MULTICLASS:
76-
raise DrumCommonException(
77-
"Target type '{}' predictions must return the "
78-
"probability distribution for all class labels".format(self.target_type)
79-
)
80-
predictions = pd.DataFrame(predictions, columns=[self.positive_class_label])
81-
predictions[self.negative_class_label] = 1 - predictions[self.positive_class_label]
76+
if len(self.class_labels) > 2:
77+
raise DrumCommonException(
78+
"Target type '{}' predictions must return the "
79+
"probability distribution for all class labels".format(self.target_type)
80+
)
81+
pos_label = self.class_labels[1]
82+
neg_label = self.class_labels[0]
83+
else:
84+
pos_label = self.positive_class_label
85+
neg_label = self.negative_class_label
86+
predictions = pd.DataFrame(predictions, columns=[pos_label])
87+
predictions[neg_label] = 1 - predictions[pos_label]
8288
else:
8389
predictions = pd.DataFrame(predictions, columns=self.class_labels)
8490
elif self.target_type in [TargetType.REGRESSION, TargetType.ANOMALY]:

custom_model_runner/datarobot_drum/drum/artifact_predictors/torch_predictor.py

Lines changed: 12 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -77,12 +77,18 @@ def predict(self, data, model, **kwargs):
7777
if self.target_type.value in TargetType.CLASSIFICATION.value:
7878
if predictions.shape[1] == 1:
7979
if self.target_type == TargetType.MULTICLASS:
80-
raise DrumCommonException(
81-
"Target type '{}' predictions must return the "
82-
"probability distribution for all class labels".format(self.target_type)
83-
)
84-
predictions = pd.DataFrame(predictions, columns=[self.positive_class_label])
85-
predictions[self.negative_class_label] = 1 - predictions[self.positive_class_label]
80+
if len(self.class_labels) > 2:
81+
raise DrumCommonException(
82+
"Target type '{}' predictions must return the "
83+
"probability distribution for all class labels".format(self.target_type)
84+
)
85+
pos_label = self.class_labels[1]
86+
neg_label = self.class_labels[0]
87+
else:
88+
pos_label = self.positive_class_label
89+
neg_label = self.negative_class_label
90+
predictions = pd.DataFrame(predictions, columns=[pos_label])
91+
predictions[neg_label] = 1 - predictions[pos_label]
8692
else:
8793
predictions = pd.DataFrame(predictions, columns=self.class_labels)
8894
elif self.target_type in [TargetType.REGRESSION, TargetType.ANOMALY]:

custom_model_runner/datarobot_drum/drum/artifact_predictors/xgboost_predictor.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -83,7 +83,7 @@ def predict(self, data, model, **kwargs):
8383
if self.target_type.value in TargetType.CLASSIFICATION.value:
8484
if xgboost_native:
8585
predictions = model.predict(data)
86-
if self.target_type == TargetType.BINARY:
86+
if self.target_type == TargetType.BINARY or len(self.class_labels) == 2:
8787
negative_preds = 1 - predictions
8888
predictions = np.concatenate(
8989
(negative_preds.reshape(-1, 1), predictions.reshape(-1, 1)), axis=1
Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,3 @@
1-
version = "1.4.4"
1+
version = "1.4.5"
22
__version__ = version
33
project_name = "datarobot-drum"

custom_model_runner/datarobot_drum/drum/drum.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -455,7 +455,7 @@ def _prepare_fit_pipeline(self, run_language):
455455
self.options.negative_class_label,
456456
) = possible_class_labels
457457
elif self.target_type == TargetType.MULTICLASS:
458-
if len(possible_class_labels) <= 2:
458+
if len(possible_class_labels) < 2:
459459
raise DrumCommonException(
460460
"Target type {} requires more than 2 class labels. Detected {}: {}".format(
461461
TargetType.MULTICLASS,
Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,2 +1,2 @@
11
pyarrow==0.14.1
2-
datarobot-drum==1.4.4
2+
datarobot-drum==1.4.5

public_dropin_environments/java_codegen/env_info.json

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,6 @@
33
"name": "[DataRobot] Java Drop-In (DR Codegen, H2O)",
44
"description": "This template can be used as an environment for DataRobot generated scoring code or models that implement the either the IClassificationPredictor or IRegressionPredictor interface from the datarobot-prediction package and for H2O models exported as POJO or MOJO.",
55
"programmingLanguage": "java",
6-
"environmentVersionId": "5fc33e141d41c812de88c04d",
6+
"environmentVersionId": "5fc7f2e5e9790c6fd6032869",
77
"isPublic": true
88
}
Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,2 +1,2 @@
11
pyarrow==0.14.1
2-
datarobot-drum==1.4.4
2+
datarobot-drum==1.4.5

0 commit comments

Comments
 (0)