Commit 1542165

[AGENT-3621] Adds support for extending DRUM server mode (#698)
* Add new hook file for flask extension
  If this file is present in the custom model dir then we will use it to extend our Flask application.
* Add new example for extending Flask server
  I took the sklearn example and just added the flask customizer code to it. sklearn is pretty simple, so it should not distract from the main part of the example, which is the `custom_flask.py` code. I'm using a 3rd party Flask extension to help with some of the authentication code to show that this solution is highly customizable.
* Bump version 1.9.11
* Fix lint
* Add new line
* Change print to logging
* Improve README formatting
* Make it a dev release
* Small renaming tweaks to flask hook code
* Improve comments for sample custom flask hook code
* Add test case for custom flask hook
* Fix changelog
* Fix black
* Bump to 1.9.12
* Fix comment
* Improve readme for new sample
* Rename and update docs
  I don't want people confusing the flask extension with a real model template.
* Don't use assert for user driven logic
1 parent f3e25b3 commit 1542165

14 files changed: +258 -11 lines changed

custom_model_runner/CHANGELOG.md

Lines changed: 15 additions & 8 deletions
@@ -4,6 +4,13 @@ All notable changes to this project will be documented in this file.
 The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/),
 and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html).

+#### [1.9.12] - 2022-10-25
+##### Added
+- Add support for a new hook (`custom_flask.py`) in the model-dir to allow extending the Flask
+  application when drum is running in server mode.
+- Add a new model template sample (`flask_extension_httpauth`) to illustrate a potential
+  authentication use-case using the new `custom_flask.py` hook.
+
 #### [1.9.11] - 2022-10-24
 ##### Changed
 - Add support to initialize DR Python client in order to allow DR API access.
@@ -18,7 +25,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
 - Bump com.datarobot.datarobot-prediction package to 2.2.1
 ##### Fixed
 - Pin `datarobot==2.27.0`
-- Handle missing values in image typeschema validator
+- Handle missing values in image typeschema validator

 #### [1.9.8] - 2022-08-04
 ##### Changed
@@ -62,7 +69,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0

 #### [1.8.0] - 2022-02-18
 ##### Added
-- Built-in support for ONNX models
+- Built-in support for ONNX models
 - Support for new custom (training) task templates

 #### [1.7.2dev1] - 2022-02-16
@@ -164,7 +171,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
 #### [1.5.11] - 2021-08-19
 ##### Fixed
 - Apply default schema to transforms
-##### Added
+##### Added
 - type schema to pipeline examples

 #### [1.5.10] - 2021-07-30
@@ -306,7 +313,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0

 #### [1.4.5] - 2020-12-02
 ##### Added
-- **/transform** endpoint added to prediction server
+- **/transform** endpoint added to prediction server
 ##### Changes
 - Allow multiclass to function with only 2 labels

@@ -375,12 +382,12 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
 #### [1.1.4] - 2020-08-04
 ##### Added
 - the docker flag now takes a directory, and will build a docker image
-- the `push` verb lets you add your code into DataRobot.
+- the `push` verb lets you add your code into DataRobot.
 - H2O models support
 - r_lang fit component, pipeline, and template
 ##### Changed
 - search custom.py recursively in the code dir
-- set rpy2 dependcy <= 3.2.7 to avoid pandas import error
+- set rpy2 dependcy <= 3.2.7 to avoid pandas import error

 ## [1.1.3] - 2020-07-17
 ### Added
@@ -427,15 +434,15 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
 - change command `cmrun predict` to `cmrun score`

 ### Added
-- `new` subcommand for model templates creation
+- `new` subcommand for model templates creation

 ## [1.0.16] - 2020-05-05
 ### Changed
 - unpin rpy2 dependency version

 ## [1.0.15] - 2020-05-04
 ### Changed
-- require to use sub-command, e.g. `cmrun predict`
+- require to use sub-command, e.g. `cmrun predict`

 ## [1.0.14] - 2020-04-30
 ### Changed

custom_model_runner/datarobot_drum/drum/description.py

Lines changed: 1 addition & 1 deletion
@@ -4,6 +4,6 @@
 This is proprietary source code of DataRobot, Inc. and its affiliates.
 Released under the terms of DataRobot Tool and Utility Agreement.
 """
-version = "1.9.11"
+version = "1.9.12"
 __version__ = version
 project_name = "datarobot-drum"

custom_model_runner/datarobot_drum/drum/enum.py

Lines changed: 3 additions & 0 deletions
@@ -21,6 +21,9 @@
 CUSTOM_FILE_NAME = "custom"


+FLASK_EXT_FILE_NAME = "custom_flask"
+
+
 POSITIVE_CLASS_LABEL_ARG_KEYWORD = "positive_class_label"

custom_model_runner/datarobot_drum/drum/model_adapter.py

Lines changed: 2 additions & 1 deletion
@@ -138,7 +138,8 @@ def _load_custom_hooks_for_legacy_drum(self, custom_module):

     def load_custom_hooks(self):
         custom_file_paths = list(Path(self._model_dir).rglob("{}.py".format(CUSTOM_FILE_NAME)))
-        assert len(custom_file_paths) <= 1
+        if len(custom_file_paths) > 1:
+            raise RuntimeError("Found too many custom hook files: {}".format(custom_file_paths))

         if len(custom_file_paths) == 0:
             print("No {}.py file detected in {}".format(CUSTOM_FILE_NAME, self._model_dir))
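The hunk above swaps an `assert` for an explicit `RuntimeError`. One reason this matters is that Python skips `assert` statements when run with `-O`, so a user-facing check written as an assertion can silently disappear. Below is a minimal sketch of the same check as a standalone function; `find_single_hook` and `demo` are hypothetical names for illustration and are not part of DRUM.

```python
import tempfile
from pathlib import Path


def find_single_hook(code_dir, hook_name):
    """Return the path of `<hook_name>.py` under code_dir, or None if absent.

    Raises RuntimeError (instead of asserting) when more than one match
    exists, so the check still runs under `python -O`.
    """
    matches = sorted(Path(code_dir).rglob("{}.py".format(hook_name)))
    if len(matches) > 1:
        raise RuntimeError("Found too many custom hook files: {}".format(matches))
    return matches[0] if matches else None


def demo():
    """Exercise the three cases: no hook, one hook, duplicate hooks."""
    with tempfile.TemporaryDirectory() as tmp:
        root = Path(tmp)
        none_found = find_single_hook(root, "custom")

        (root / "custom.py").write_text("# hook\n")
        one_found = find_single_hook(root, "custom").name

        (root / "nested").mkdir()
        (root / "nested" / "custom.py").write_text("# duplicate\n")
        try:
            find_single_hook(root, "custom")
            duplicate_error = False
        except RuntimeError:
            duplicate_error = True
    return none_found, one_found, duplicate_error
```

Calling `demo()` walks through an empty directory, a directory with one hook file, and a directory with two, which is the case the new `RuntimeError` guards against.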

custom_model_runner/datarobot_drum/resource/components/Python/prediction_server/prediction_server.py

Lines changed: 25 additions & 0 deletions
@@ -6,6 +6,7 @@
 """
 import logging
 import sys
+from pathlib import Path
 from mlpiper.components.connectable_component import ConnectableComponent

 from datarobot_drum.drum.common import (
@@ -18,6 +19,7 @@
     ModelInfoKeys,
     RunLanguage,
     TargetType,
+    FLASK_EXT_FILE_NAME,
 )
 from datarobot_drum.drum.description import version as drum_version
 from datarobot_drum.drum.exceptions import DrumCommonException
@@ -192,6 +194,7 @@ def handle_exception(e):
         cli.show_server_banner = lambda *x: None

         app = get_flask_app(model_api)
+        self.load_flask_extensions(app)

         host = self._params.get("host", None)
         port = self._params.get("port", None)
@@ -209,3 +212,25 @@ def terminate(self):
         terminate_op = getattr(self._predictor, "terminate", None)
         if callable(terminate_op):
             terminate_op()
+
+    def load_flask_extensions(self, app):
+        custom_file_paths = list(Path(self._code_dir).rglob("{}.py".format(FLASK_EXT_FILE_NAME)))
+        if len(custom_file_paths) > 1:
+            raise RuntimeError("Found too many custom hook files: {}".format(custom_file_paths))
+
+        if len(custom_file_paths) == 0:
+            logger.info("No %s.py file detected in %s", FLASK_EXT_FILE_NAME, self._code_dir)
+            return
+
+        custom_file_path = custom_file_paths[0]
+        logger.info("Detected %s .. trying to load Flask extensions", custom_file_path)
+        sys.path.insert(0, str(custom_file_path.parent))
+
+        try:
+            custom_module = __import__(FLASK_EXT_FILE_NAME)
+            custom_module.init_app(app)
+        except ImportError as e:
+            logger.error("Could not load hooks", exc_info=True)
+            raise DrumCommonException(
+                "Failed to extend Flask app from [{}] : {}".format(custom_file_path, e)
+            )
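The discovery-and-load flow in `load_flask_extensions` can be tried without DRUM or Flask installed. The sketch below follows the same steps (search the code dir, reject duplicates, import the module, call `init_app`), with two labeled differences: it loads the file by path through `importlib.util` rather than the commit's `sys.path.insert` plus `__import__`, and it uses a hypothetical `FakeApp` stand-in instead of a real Flask application. `load_extension` and `demo` are illustrative names only.

```python
import importlib.util
import tempfile
from pathlib import Path

HOOK_NAME = "custom_flask"  # mirrors FLASK_EXT_FILE_NAME from the diff


class FakeApp:
    """Hypothetical stand-in for a Flask app; records which hooks ran."""

    def __init__(self):
        self.hooks_run = []


def load_extension(code_dir, app):
    """Find `<HOOK_NAME>.py` under code_dir and call its init_app(app).

    Returns True if a hook was loaded, False if none was found.
    """
    paths = list(Path(code_dir).rglob("{}.py".format(HOOK_NAME)))
    if len(paths) > 1:
        raise RuntimeError("Found too many custom hook files: {}".format(paths))
    if not paths:
        return False
    # Assumption: load by file path instead of mutating sys.path as the commit does.
    spec = importlib.util.spec_from_file_location(HOOK_NAME, paths[0])
    module = importlib.util.module_from_spec(spec)
    spec.loader.exec_module(module)
    module.init_app(app)
    return True


def demo():
    """Write a tiny hook file to a temp dir and load it."""
    with tempfile.TemporaryDirectory() as tmp:
        hook = Path(tmp) / "custom_flask.py"
        hook.write_text("def init_app(app):\n    app.hooks_run.append('auth')\n")
        app = FakeApp()
        loaded = load_extension(tmp, app)
    return loaded, app.hooks_run
```

Loading by path avoids leaving the hook's directory on `sys.path`, at the cost of the hook not being able to import sibling modules by bare name; the commit's approach makes the opposite trade.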
Lines changed: 20 additions & 0 deletions
@@ -0,0 +1,20 @@
+## Extending Web Server Behavior
+
+This sample is meant to illustrate how you can add third-party or your own custom Flask extensions
+when drum is in server mode. To demonstrate one potential use case, the `custom_flask.py` file in
+this model directory will extend the HTTP server to require a specific [Bearer Token](https://swagger.io/docs/specification/authentication/bearer-authentication/)
+when making any requests to it.
+
+For completeness, we also include all the model related files from the [Python Sklearn Inference Model Template](../python3_sklearn/).
+
+Note: it is **not** necessary (nor recommended) to add authentication to custom models that are created in DataRobot MLOps.
+This example is simply to demonstrate the flexibility of the `custom_flask.py` hook.
+
+## Instructions
+Create a new custom model with these files and use the Python Drop-In Environment with it
+
+### To run locally using 'drum'
+Paths are relative to `./datarobot-user-models`:
+```
+drum server --docker public_dropin_environments/python3_sklearn --code-dir model_templates/python3_sklearn_flask_ext/ --target-type regression --address localhost:8080
+```
Lines changed: 33 additions & 0 deletions
@@ -0,0 +1,33 @@
+"""
+Copyright 2021 DataRobot, Inc. and its affiliates.
+All rights reserved.
+This is proprietary source code of DataRobot, Inc. and its affiliates.
+Released under the terms of DataRobot Tool and Utility Agreement.
+"""
+import pandas as pd
+
+
+def transform(data, model):
+    """
+    Note: This hook may not have to be implemented for your model.
+    In this case implemented for the model used in the example.
+
+    Modify this method to add data transformation before scoring calls. For example, this can be
+    used to implement one-hot encoding for models that don't include it on their own.
+
+    Parameters
+    ----------
+    data: pd.DataFrame
+    model: object, the deserialized model
+
+    Returns
+    -------
+    pd.DataFrame
+    """
+    # Execute any steps you need to do before scoring
+    # Remove target columns if they're in the dataset
+    for target_col in ["Grade 2014", "Species"]:
+        if target_col in data:
+            data.pop(target_col)
+    data = data.fillna(0)
+    return data
Lines changed: 57 additions & 0 deletions
@@ -0,0 +1,57 @@
+"""
+Copyright 2022 DataRobot, Inc. and its affiliates.
+All rights reserved.
+This is proprietary source code of DataRobot, Inc. and its affiliates.
+Released under the terms of DataRobot Tool and Utility Agreement.
+"""
+import logging
+
+from flask import request
+from flask_httpauth import HTTPTokenAuth
+
+logger = logging.getLogger(__name__)
+
+token_auth = HTTPTokenAuth()
+
+AUTHENTICATION_TOKEN = "DUMMY_TOKEN_123"
+
+
+@token_auth.verify_token
+def verify(token):
+    # Hard-code users for demo purposes but in a real setup this data would be
+    # fetched from a database or secure key vault, for example.
+    if token == AUTHENTICATION_TOKEN:
+        # flask_httpauth requires this function to return a username on successful authentication
+        # so it can be used for authorization (but we aren't implementing that for this sample).
+        return "dummy_user"
+
+
+def init_app(app):
+    """
+    Below is a sample hook that illustrates how to add simple token based
+    authentication to most of the routes served by the custom model runner.
+
+    Parameters
+    ----------
+    app: Flask
+    """
+    # Health check endpoints shouldn't require auth
+    no_auth_endpoints = {"model_api.ping", "model_api.health"}
+    logger.info("Setting up authentication on all routes except ping routes")
+
+    @app.before_request
+    def check_auth():
+        auth = token_auth.get_auth()
+
+        # Flask normally handles OPTIONS requests on its own, but in the case it is configured to
+        # forward those to the application, we need to ignore authentication headers and let the
+        # request through to avoid unwanted interactions with CORS.
+        if request.method != "OPTIONS" and request.endpoint not in no_auth_endpoints:
+            user = token_auth.authenticate(auth, None)
+            if user in (False, None):
+                return token_auth.auth_error_callback(401)
+
+    logger.info(
+        'Please authenticate to the server:\n\t`curl -H "Authorization: Bearer %s" ...`',
+        AUTHENTICATION_TOKEN,
+    )
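The sample hook logs a `curl` hint for sending the Bearer token. For a Python client, the same header can be built with the standard library, as in the sketch below. It assumes the server address from the README's `--address localhost:8080` and the sample's hard-coded `DUMMY_TOKEN_123`; the `/info/` and `/ping/` routes are the ones exercised by the test file in this commit, and `build_request` is a hypothetical helper name. Nothing is sent over the network here.

```python
import urllib.request

# Assumptions: address taken from the README command, token from the sample hook.
BASE_URL = "http://localhost:8080"
TOKEN = "DUMMY_TOKEN_123"


def build_request(path, token=None):
    """Build a GET request, adding a Bearer token header when one is given."""
    req = urllib.request.Request(BASE_URL + path)
    if token is not None:
        req.add_header("Authorization", "Bearer {}".format(token))
    return req


# Protected route: carries the Authorization header.
authed = build_request("/info/", TOKEN)
# Health check route: the sample hook lets it through without a token.
anonymous = build_request("/ping/")

# To send a request against a running server:
#   with urllib.request.urlopen(authed) as resp:
#       print(resp.status)
```

Without the header, the sample hook is written to answer protected routes with a 401 via `token_auth.auth_error_callback(401)`.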
Lines changed: 10 additions & 0 deletions
@@ -0,0 +1,10 @@
+name: drumpush-regression
+type: inference
+targetType: regression
+modelID: 5f1f15a4d6111f01cb7f91fd
+environmentID: 5e8c889607389fe0f466c72d
+inferenceModel:
+  targetName: Grade 2014
+validation:
+  # Path is relative to this file
+  input: ../../../tests/testdata/juniors_3_year_stats_regression.csv
Lines changed: 1 addition & 0 deletions
@@ -0,0 +1 @@
+Flask-HTTPAuth==4.7.0
Binary file not shown.

requirements_dev.txt

Lines changed: 1 addition & 1 deletion
@@ -2,4 +2,4 @@
 -r requirements_lint.txt
 -r requirements_test.txt

-datarobot-oss-java-jdk11==1.11.0.13.post1+dr
+datarobot-oss-java-jdk11==1.11.0.13.post1+dr
Lines changed: 69 additions & 0 deletions
@@ -0,0 +1,69 @@
+"""
+Copyright 2022 DataRobot, Inc. and its affiliates.
+All rights reserved.
+This is proprietary source code of DataRobot, Inc. and its affiliates.
+Released under the terms of DataRobot Tool and Utility Agreement.
+"""
+from pathlib import Path
+import shutil
+
+import pytest
+import requests
+
+from .constants import (
+    PYTHON_UNSTRUCTURED,
+    UNSTRUCTURED,
+    TESTS_FIXTURES_PATH,
+)
+from datarobot_drum.resource.drum_server_utils import DrumServerRun
+from datarobot_drum.resource.utils import _create_custom_model_dir
+from datarobot_drum.drum.utils.drum_utils import unset_drum_supported_env_vars
+
+
+class TestDrumServerCustomAuth:
+    @pytest.fixture(scope="class")
+    def custom_flask_script(self):
+        return (Path(TESTS_FIXTURES_PATH) / "custom_flask_demo_auth.py", "custom_flask.py")
+
+    @pytest.fixture(scope="class")
+    def custom_model_dir(self, custom_flask_script, resources, tmp_path_factory):
+        tmp_dir = tmp_path_factory.mktemp("model_dir")
+        custom_model_dir = _create_custom_model_dir(
+            resources, tmp_dir, None, UNSTRUCTURED, PYTHON_UNSTRUCTURED,
+        )
+        fixture_filename, target_name = custom_flask_script
+        shutil.copy2(fixture_filename, custom_model_dir / target_name)
+        return custom_model_dir
+
+    @pytest.fixture(scope="class")
+    def drum_server(self, resources, custom_model_dir):
+        unset_drum_supported_env_vars()
+        with DrumServerRun(
+            resources.target_types(UNSTRUCTURED),
+            resources.class_labels(None, UNSTRUCTURED),
+            custom_model_dir,
+        ) as run:
+            yield run
+
+    def test_auth_passthrough(self, drum_server):
+        response = requests.get(drum_server.url_server_address + "/ping/")
+        assert response.ok
+
+    def test_missing_auth_header(self, drum_server):
+        response = requests.get(drum_server.url_server_address + "/info/")
+        assert response.status_code == 401
+        assert response.json()["message"] == "Missing X-Auth header"
+
+    def test_bad_auth_token(self, drum_server):
+        response = requests.get(
+            drum_server.url_server_address + "/info/", headers={"X-Auth": "token"}
+        )
+        assert response.status_code == 401
+        assert response.json()["message"] == "Auth token is invalid"
+
+    def test_successful_auth(self, drum_server):
+        response = requests.get(
+            drum_server.url_server_address + "/info/", headers={"X-Auth": "t0k3n"}
+        )
+        assert response.ok
+        assert response.json()["drumServer"] == "flask"
Lines changed: 21 additions & 0 deletions
@@ -0,0 +1,21 @@
+"""
+Copyright 2022 DataRobot, Inc. and its affiliates.
+All rights reserved.
+This is proprietary source code of DataRobot, Inc. and its affiliates.
+Released under the terms of DataRobot Tool and Utility Agreement.
+"""
+from flask import request, jsonify
+
+
+def init_app(app):
+    @app.before_request
+    def check_header():
+        # Allow ping route with no Auth otherwise test setup would fail
+        if request.endpoint != "model_api.ping":
+            try:
+                token = request.headers["X-Auth"]
+            except KeyError:
+                return jsonify({"message": "Missing X-Auth header"}), 401
+            else:
+                if token != "t0k3n":
+                    return jsonify({"message": "Auth token is invalid"}), 401
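The decision logic in this test fixture is small enough to restate as a pure function, which makes the four test cases in the test class above easy to trace by hand. The sketch below is a Flask-free restatement under that assumption: `check_header` here takes the endpoint name and a plain dict of headers (both hypothetical parameters) and returns `None` to let the request through or a `(body, status)` pair to reject it.

```python
def check_header(endpoint, headers):
    """Return None to allow the request, or a (body, status) rejection.

    Mirrors the fixture's before_request logic without needing Flask.
    """
    # The ping route stays open so health checks work without a token.
    if endpoint == "model_api.ping":
        return None
    try:
        token = headers["X-Auth"]
    except KeyError:
        return {"message": "Missing X-Auth header"}, 401
    if token != "t0k3n":
        return {"message": "Auth token is invalid"}, 401
    return None


# The same four situations the test class exercises:
ping_result = check_header("model_api.ping", {})
missing_result = check_header("model_api.info", {})
bad_result = check_header("model_api.info", {"X-Auth": "token"})
good_result = check_header("model_api.info", {"X-Auth": "t0k3n"})
```

In Flask, returning `None` from a `before_request` function lets normal routing continue, while any other return value is sent as the response, which is why the fixture only returns on the rejection paths. The endpoint name `model_api.info` is an assumption used for illustration; the fixture itself only names `model_api.ping`.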
