Commit 2ce93f5

[RAPTOR-3979][RAPTOR-4016][RAPTOR-4017] Transform endpoint using arrow/mtx (#236)
* do the things
* add transformer
* fix some things
* refactor
* fix test
* mixin comments
* add uwsgi endpoint
* run black
* add nginx test
* update key name, use enum values
* maybe fix a test
* fix enum thing
* do a bunch of things
* do a bunch of things?
* add arrow support
* do a bunch more stuff
* remove a thing
* black
* get rid of duplicate thing
* try to fix some tests
* black
* change key name
* run black
* move some things, address some comments
* unused import
* add sparse indicator to payload
* add csv fallback
* add more things to tests
* run black
* missed a digit
* update other stuff
* deal with arrow versions maybe
1 parent 1a0cfac commit 2ce93f5

14 files changed (+358, -34 lines)

custom_model_runner/datarobot_drum/drum/common.py

Lines changed: 1 addition & 0 deletions

```diff
@@ -15,6 +15,7 @@
 NEGATIVE_CLASS_LABEL_ARG_KEYWORD = "negative_class_label"
 CLASS_LABELS_ARG_KEYWORD = "class_labels"
 TARGET_TYPE_ARG_KEYWORD = "target_type"
+X_TRANSFORM_KEY = "X.transformed"
 
 
 URL_PREFIX_ENV_VAR_NAME = "URL_PREFIX"
```

custom_model_runner/datarobot_drum/resource/components/Python/prediction_server/prediction_server.py

Lines changed: 17 additions & 2 deletions

```diff
@@ -5,7 +5,7 @@
 from datarobot_drum.drum.exceptions import DrumCommonException
 from datarobot_drum.profiler.stats_collector import StatsCollector, StatsOperation
 from datarobot_drum.drum.memory_monitor import MemoryMonitor
-from datarobot_drum.drum.common import RunLanguage, TARGET_TYPE_ARG_KEYWORD
+from datarobot_drum.drum.common import RunLanguage, TARGET_TYPE_ARG_KEYWORD, TargetType
 from datarobot_drum.resource.predict_mixin import PredictMixin
 
 from datarobot_drum.drum.server import (
@@ -32,7 +32,7 @@ def configure(self, params):
         super(PredictionServer, self).configure(params)
         self._show_perf = self._params.get("show_perf")
         self._run_language = RunLanguage(params.get("run_language"))
-        self._target_type = params[TARGET_TYPE_ARG_KEYWORD]
+        self._target_type = TargetType(params[TARGET_TYPE_ARG_KEYWORD])
 
         self._stats_collector = StatsCollector(disable_instance=not self._show_perf)
 
@@ -91,6 +91,21 @@ def predict():
                 self._stats_collector.disable()
             return response, response_status
 
+        @model_api.route("/transform/", methods=["POST"])
+        def transform():
+
+            logger.debug("Entering transform() endpoint")
+
+            self._stats_collector.enable()
+            self._stats_collector.mark("start")
+
+            try:
+                response, response_status = self.do_transform(logger=logger)
+            finally:
+                self._stats_collector.mark("finish")
+                self._stats_collector.disable()
+            return response, response_status
+
         @model_api.route("/predictUnstructured/", methods=["POST"])
         def predict_unstructured():
             logger.debug("Entering predict() endpoint")
```
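Usage note (not part of the commit): the new route delegates straight to `PredictMixin.do_transform()`, so it accepts the same multipart upload as `/predict/`. A minimal client sketch, assuming a local DRUM server address and a hypothetical `input.csv`:

```python
# Sketch under assumptions: host/port and input.csv are illustrative only.
import requests

with open("input.csv", "rb") as f:
    # samples must be uploaded under the `X` file key (see predict_mixin.py)
    resp = requests.post("http://localhost:6789/transform/", files={"X": f})

print(resp.status_code)  # 422 if the deployed model is not a transform
```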

custom_model_runner/datarobot_drum/resource/components/Python/uwsgi_component/uwsgi_serving.py

Lines changed: 26 additions & 1 deletion

```diff
@@ -11,6 +11,7 @@
     URL_PREFIX_ENV_VAR_NAME,
     TARGET_TYPE_ARG_KEYWORD,
     make_predictor_capabilities,
+    TargetType,
 )
 from datarobot_drum.profiler.stats_collector import StatsCollector, StatsOperation
 
@@ -61,7 +62,7 @@ def configure(self, params):
         super(UwsgiServing, self).configure(params)
         self._show_perf = self._params.get("show_perf")
         self._run_language = RunLanguage(params.get("run_language"))
-        self._target_type = params[TARGET_TYPE_ARG_KEYWORD]
+        self._target_type = TargetType(params[TARGET_TYPE_ARG_KEYWORD])
 
         self._stats_collector = StatsCollector(disable_instance=not self._show_perf)
 
@@ -179,6 +180,30 @@ def predict_unstructured(self, url_params, form_params):
             self._stats_collector.disable()
         return response_status, response
 
+    @FlaskRoute(
+        "{}/transform/".format(os.environ.get(URL_PREFIX_ENV_VAR_NAME, "")), methods=["POST"]
+    )
+    def transform(self, url_params, form_params):
+        if self._error_response:
+            return HTTP_513_DRUM_PIPELINE_ERROR, self._error_response
+
+        self._stats_collector.enable()
+        self._stats_collector.mark("start")
+
+        try:
+            response, response_status = self.do_transform()
+
+            if response_status == HTTP_200_OK:
+                # this counter is managed by uwsgi
+                self._total_predict_requests.increase()
+                self._predict_calls_count += 1
+        except Exception as ex:
+            response_status, response = self._handle_exception(ex)
+        finally:
+            self._stats_collector.mark("finish")
+            self._stats_collector.disable()
+        return response_status, response
+
     def _handle_exception(self, ex):
         self._logger.error(ex)
         response_status = HTTP_500_INTERNAL_SERVER_ERROR
```
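The uWSGI route mirrors the Flask one, with the optional `URL_PREFIX` prepended. A hedged client sketch for opting into the Arrow response path by also uploading an `arrow_version` field (the mixin `eval()`s it, so the client sends a Python literal; this simple form assumes client and server run the same pyarrow version):

```python
# Sketch under assumptions: address and input file are illustrative only.
import pyarrow as pa
import requests

files = {
    "X": open("input.csv", "rb"),
    # repr() so the server-side eval() recovers the version string
    "arrow_version": repr(pa.__version__).encode("utf-8"),
}
resp = requests.post("http://localhost:6789/transform/", files=files)
```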

custom_model_runner/datarobot_drum/resource/predict_mixin.py

Lines changed: 73 additions & 2 deletions

```diff
@@ -9,6 +9,13 @@
     TargetType,
     UnstructuredDtoKeys,
     PredictionServerMimetypes,
+    X_TRANSFORM_KEY,
+)
+from datarobot_drum.resource.transform_helpers import (
+    make_arrow_payload,
+    is_sparse,
+    make_mtx_payload,
+    make_csv_payload,
 )
 from datarobot_drum.resource.unstructured_helpers import (
     _resolve_incoming_unstructured_data,
@@ -29,12 +36,19 @@ class PredictMixin:
 
     """
 
-    def do_predict(self, logger=None):
+    def _predict_or_transform(self, logger=None):
         response_status = HTTP_200_OK
 
         file_key = "X"
         filestorage = request.files.get(file_key)
 
+        if self._target_type == TargetType.TRANSFORM:
+            arrow_key = "arrow_version"
+            arrow_version = request.files.get(arrow_key)
+            if arrow_version is not None:
+                arrow_version = eval(arrow_version.getvalue())
+            use_arrow = arrow_version is not None
+
         if not filestorage:
             wrong_key_error_message = (
                 "Samples should be provided as a csv, mtx, or arrow file under `{}` key.".format(
@@ -53,10 +67,38 @@
         with tempfile.NamedTemporaryFile(suffix=file_ext) as f:
             filestorage.save(f)
             f.flush()
-            out_data = self._predictor.predict(f.name)
+            if self._target_type == TargetType.TRANSFORM:
+                out_data = self._predictor.transform(f.name)
+            else:
+                out_data = self._predictor.predict(f.name)
 
         if self._target_type == TargetType.UNSTRUCTURED:
             response = out_data
+        elif self._target_type == TargetType.TRANSFORM:
+            if is_sparse(out_data):
+                mtx_payload = make_mtx_payload(out_data)
+                response = (
+                    '{{"{transform_key}":{mtx_payload}, "out.format":"{out_format}"}}'.format(
+                        transform_key=X_TRANSFORM_KEY, mtx_payload=mtx_payload, out_format="sparse"
+                    )
+                )
+            else:
+                if use_arrow:
+                    arrow_payload = make_arrow_payload(out_data, arrow_version)
+                    response = (
+                        '{{"{transform_key}":{arrow_payload}, "out.format":"{out_format}"}}'.format(
+                            transform_key=X_TRANSFORM_KEY,
+                            arrow_payload=arrow_payload,
+                            out_format="arrow",
+                        )
+                    )
+                else:
+                    csv_payload = make_csv_payload(out_data)
+                    response = (
+                        '{{"{transform_key}":{csv_payload}, "out.format":"{out_format}"}}'.format(
+                            transform_key=X_TRANSFORM_KEY, csv_payload=csv_payload, out_format="csv"
+                        )
+                    )
         else:
             num_columns = len(out_data.columns)
             # float32 is not JSON serializable, so cast to float, which is float64
@@ -76,6 +118,19 @@
 
         return response, response_status
 
+    def do_predict(self, logger=None):
+        if self._target_type == TargetType.TRANSFORM:
+            wrong_target_type_error_message = (
+                "This project has target type {}, "
+                "use the /transform/ endpoint.".format(self._target_type)
+            )
+            if logger is not None:
+                logger.error(wrong_target_type_error_message)
+            response_status = HTTP_422_UNPROCESSABLE_ENTITY
+            return {"message": "ERROR: " + wrong_target_type_error_message}, response_status
+
+        return self._predict_or_transform(logger=logger)
+
     def do_predict_unstructured(self, logger=None):
         def _validate_content_type_header(header):
             ret_mimetype, content_type_params_dict = werkzeug.http.parse_options_header(header)
@@ -115,3 +170,19 @@ def _validate_content_type_header(header):
             response.headers["Content-Type"] = content_type
 
         return response, response_status
+
+    def do_transform(self, logger=None):
+        if self._target_type != TargetType.TRANSFORM:
+            endpoint = (
+                "predictUnstructured" if self._target_type == TargetType.UNSTRUCTURED else "predict"
+            )
+            wrong_target_type_error_message = (
+                "This project has target type {}, "
+                "use the /{}/ endpoint.".format(self._target_type, endpoint)
+            )
+            if logger is not None:
+                logger.error(wrong_target_type_error_message)
+            response_status = HTTP_422_UNPROCESSABLE_ENTITY
+            return {"message": "ERROR: " + wrong_target_type_error_message}, response_status
+
+        return self._predict_or_transform(logger=logger)
```
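Client code can branch on the `out.format` field to pick the matching reader from `transform_helpers` (defined below). A sketch, assuming the response has already been decoded into a dict keyed by `X_TRANSFORM_KEY`; that decoding step is outside this commit:

```python
from datarobot_drum.drum.common import X_TRANSFORM_KEY
from datarobot_drum.resource.transform_helpers import (
    read_arrow_payload,
    read_csv_payload,
    read_mtx_payload,
)


def parse_transform_response(response_dict):
    # returns a pandas DataFrame, or a scipy csr_matrix for sparse output
    out_format = response_dict["out.format"]
    if out_format == "arrow":
        return read_arrow_payload(response_dict)
    if out_format == "csv":
        return read_csv_payload(response_dict)
    if out_format == "sparse":
        return read_mtx_payload(response_dict)
    raise ValueError("unexpected out.format: {}".format(out_format))
```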
custom_model_runner/datarobot_drum/resource/transform_helpers.py

Lines changed: 70 additions & 0 deletions

```diff
@@ -0,0 +1,70 @@
+import pyarrow as pa
+import pandas as pd
+
+from io import BytesIO, StringIO
+
+from scipy.io import mmwrite, mmread
+from scipy.sparse.csr import csr_matrix
+from scipy.sparse import vstack
+
+from datarobot_drum.drum.common import X_TRANSFORM_KEY
+
+
+def is_sparse(df):
+    return hasattr(df, "sparse") or type(df.iloc[0].values[0]) == csr_matrix
+
+
+def make_arrow_payload(df, arrow_version):
+    if arrow_version != pa.__version__ and arrow_version < 0.2:
+        batch = pa.RecordBatch.from_pandas(df, nthreads=None, preserve_index=False)
+        sink = pa.BufferOutputStream()
+        options = pa.ipc.IpcWriteOptions(
+            metadata_version=pa.MetadataVersion.V4, use_legacy_format=True
+        )
+        with pa.RecordBatchStreamWriter(sink, batch.schema, options=options) as writer:
+            writer.write_batch(batch)
+        return sink.getvalue().to_pybytes()
+    else:
+        return pa.ipc.serialize_pandas(df, preserve_index=False).to_pybytes()
+
+
+def make_csv_payload(df):
+    s_buf = StringIO()
+    df.to_csv(s_buf, index=False)
+    return s_buf.getvalue().encode("utf-8")
+
+
+def read_arrow_payload(response_dict):
+    bytes = response_dict[X_TRANSFORM_KEY]
+    df = pa.ipc.deserialize_pandas(bytes)
+    return df
+
+
+def read_csv_payload(response_dict):
+    bytes = response_dict[X_TRANSFORM_KEY]
+    return pd.read_csv(BytesIO(bytes))
+
+
+def make_mtx_payload(df):
+    if hasattr(df, "sparse"):
+        sparse_mat = csr_matrix(df.sparse.to_coo())
+    else:
+        sparse_mat = vstack(x[0] for x in df.values)
+    sink = BytesIO()
+    mmwrite(sink, sparse_mat)
+    return sink.getvalue()
+
+
+def read_mtx_payload(response_dict):
+    bytes = response_dict[X_TRANSFORM_KEY]
+    sparse_mat = mmread(BytesIO(bytes))
+    return csr_matrix(sparse_mat)
+
+
+def validate_transformed_output(transformed_output, should_be_sparse=False):
+    if should_be_sparse:
+        assert type(transformed_output) == csr_matrix
+        assert transformed_output.shape[1] == 714
+    else:
+        assert type(transformed_output) == pd.DataFrame
+        assert transformed_output.shape[1] == 10
```
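A round-trip check for the Arrow helpers (a sketch, not part of the commit):

```python
import pandas as pd
import pyarrow as pa

from datarobot_drum.drum.common import X_TRANSFORM_KEY
from datarobot_drum.resource.transform_helpers import make_arrow_payload, read_arrow_payload

df = pd.DataFrame({"a": [1, 2, 3], "b": [0.1, 0.2, 0.3]})
# matching client/server versions short-circuit the legacy branch,
# so this serializes via pa.ipc.serialize_pandas
payload = make_arrow_payload(df, pa.__version__)
# the read_* helpers all expect a dict keyed by X_TRANSFORM_KEY
assert read_arrow_payload({X_TRANSFORM_KEY: payload}).equals(df)
```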

model_templates/training/python3_sklearn_transform/README.md

Lines changed: 1 addition & 5 deletions

```diff
@@ -30,11 +30,7 @@ Categoricals:
 - Impute missing values with the string "missing"
 - One hot encode the data (ignoring new categorical levels at prediction time)
 
-SVD:
-After all the above is done, run SVD to reduce the dimensionality of the dataset to 10.
-
-This makes a dataset that can be used with basically any sklearn model. This step could be removed for models that support sparse data.
-
+This makes a dataset that can be used with any sklearn model that supports sparse data.
 
 ### To run locally using 'drum'
 Paths are relative to `datarobot-user-models` root:
```

model_templates/training/python3_sklearn_transform/create_transform_pipeline.py

Lines changed: 1 addition & 20 deletions

```diff
@@ -2,7 +2,6 @@
 import pandas as pd
 from sagemaker_sklearn_extension.feature_extraction.text import MultiColumnTfidfVectorizer
 from sklearn.compose import ColumnTransformer, make_column_selector
-from sklearn.decomposition import TruncatedSVD
 from sklearn.impute import SimpleImputer
 from sklearn.pipeline import Pipeline
 from sklearn.preprocessing import OneHotEncoder, StandardScaler
@@ -34,24 +33,6 @@
     ]
 )
 
-# Modified TruncatedSVD that doesn't fail if n_components > ncols
-class MyTruncatedSVD(TruncatedSVD):
-    def fit_transform(self, X, y=None):
-        if X.shape[1] <= self.n_components:
-            self.n_components = X.shape[1] - 1
-        return TruncatedSVD.fit_transform(self, X=X, y=y)
-
-
-# Dense preprocessing pipeline, for models such as XGboost that do not do well with
-# extremely wide, sparse data
-# This preprocessing will work with linear models such as Ridge too
-dense_preprocessing_pipeline = Pipeline(
-    steps=[
-        ("preprocessing", sparse_preprocessing_pipeline),
-        ("SVD", MyTruncatedSVD(n_components=10, random_state=42, algorithm="randomized")),
-    ]
-)
-
 
 def make_pipeline():
-    return dense_preprocessing_pipeline
+    return sparse_preprocessing_pipeline
```
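With the dense branch gone, `fit_transform` output stays sparse end to end. A smoke-test sketch with toy data (hypothetical column names; assumes the column selectors defined above this hunk route numeric and categorical columns as intended):

```python
import pandas as pd
from scipy import sparse

from create_transform_pipeline import make_pipeline

X = pd.DataFrame({"num": [1.0, 2.0, None, 4.0], "cat": ["a", "b", None, "a"]})
out = make_pipeline().fit_transform(X)
assert sparse.issparse(out)  # no SVD densification step anymore
```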

model_templates/training/python3_sklearn_transform/custom.py

Lines changed: 6 additions & 1 deletion

```diff
@@ -1,5 +1,6 @@
 import pickle
 import pandas as pd
+from scipy.sparse.csr import csr_matrix
 
 from create_transform_pipeline import make_pipeline
 
@@ -54,4 +55,8 @@ def transform(data, transformer):
     -------
     transformed DataFrame resulting from applying transform to incoming data
     """
-    return pd.DataFrame(transformer.transform(data))
+    transformed = transformer.transform(data)
+    if type(transformed) == csr_matrix:
+        return pd.DataFrame.sparse.from_spmatrix(transformed)
+    else:
+        return pd.DataFrame(transformed)
```
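Wrapping sparse transformer output with pandas' sparse accessor is what lets the server-side `is_sparse()` check (which tests `hasattr(df, "sparse")`) pick the mtx payload path. A standalone illustration (sketch, not part of the commit):

```python
import pandas as pd
from scipy.sparse import csr_matrix

mat = csr_matrix([[0.0, 1.0], [2.0, 0.0]])
sparse_df = pd.DataFrame.sparse.from_spmatrix(mat)
# the .sparse accessor only exists on sparse-backed frames
assert hasattr(sparse_df, "sparse")
assert not hasattr(pd.DataFrame(mat.toarray()), "sparse")
```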
