Skip to content

Commit 0e9c278

Browse files
authored
feat: use pzmm to register pytorch models
* feat: re-added model info * style: replace 'is' with '==' to eliminate warnings * bug: allow binary file to be passed * feat: allow passing inline json * feat: simple parsing of pytorch models * chore: directly expose get_model_info() * fix: handle case where scikit model output not provided * fix: close zip file before reading bytes * feat: handle file-like or raw bytes * chore: use pzmm for open source models * fix: do not print unless running in a notebook * fix: remove hooks * feat: pass additional info when registering model * chore: misc cleanup * feat: generate score code * chore: misc cleanup * fix: skip for Viya 4+ * fix: ignore spaces in env var * fix: use sanitized model names in file names. * fix: use instance methods & variables when manipulating model data * fix: score code generation creates score() not predict() * chore: remove obsolete code * fix: updated for instance methods * fix: new pandas behavior * feat: require dill * test: remove obsolete tests * feat: reshape 3+d tensors * fix: update for tree-based models * fix: update for tree-based models * feat: use model_info. rename input to X. * test: update for changes to viya & pandas * fix: convert non str/bytes to str. * feat: allow passing files using pathlib * test: test case update for pzmm and cassette refresh * chore: black formatting * chore: black formatting
1 parent 1316c09 commit 0e9c278

File tree

508 files changed

+1644
-880
lines changed

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

508 files changed

+1644
-880
lines changed

examples/register_scikit_classification_model.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,7 @@
2525
# Register the model in Model Manager
2626
register_model(model,
2727
model_name,
28-
input=X, # Use X to determine model inputs
28+
X=X, # Use X to determine model inputs
2929
project='Iris', # Register in "Iris" project
3030
force=True) # Create project if it doesn't exist
3131

@@ -36,5 +36,5 @@
3636
x = X.iloc[0, :]
3737

3838
# Call the published module and score the record
39-
result = module.predict(x)
39+
result = module.score(x)
4040
print(result)

examples/register_scikit_regression_model.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,7 @@
2828
project_name = 'Boston Housing'
2929

3030
# Register the model in SAS Model Manager
31-
register_model(model, model_name, project_name, input=X, force=True)
31+
register_model(model, model_name, project_name, X=X, force=True)
3232

3333
# Publish the model to the real-time scoring engine
3434
module = publish_model(model_name, 'maslocal', replace=True)
@@ -37,5 +37,5 @@
3737
x = X.iloc[0, :]
3838

3939
# Call the published module and score the record
40-
result = module.predict(x)
40+
result = module.score(x)
4141
print(result)

setup.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -42,11 +42,10 @@ def get_file(filename):
4242
packages=find_packages(where="src"),
4343
package_dir={"": "src"},
4444
python_requires=">=3.6",
45-
install_requires=["pandas>=0.24.0", "requests", "pyyaml", "packaging"],
45+
install_requires=["dill", "pandas>=0.24.0", "requests", "pyyaml", "packaging"],
4646
extras_require={
4747
"swat": ["swat"],
4848
"GitPython": ["GitPython"],
49-
"numpy": ["numpy"],
5049
"scikit-learn": ["scikit-learn"],
5150
"kerberos": [
5251
'kerberos ; platform_system != "Windows"',

src/sasctl/_services/files.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
# SPDX-License-Identifier: Apache-2.0
66

77
import os
8+
from pathlib import Path
89

910
from sasctl.utils.cli import sasctl_command
1011

@@ -40,7 +41,7 @@ def create_file(cls, file, folder=None, filename=None, expiration=None):
4041
4142
Parameters
4243
----------
43-
file : str or file_like
44+
file : str, pathlib.Path, or file_like
4445
Path to the file to upload or a file-like object.
4546
folder : str or dict, optional
4647
Name or folder information as returned by :func:`.get_folder`.
@@ -55,8 +56,8 @@ def create_file(cls, file, folder=None, filename=None, expiration=None):
5556
A dictionary containing the file attributes.
5657
5758
"""
58-
if isinstance(file, str):
59-
filename = filename or os.path.splitext(os.path.split(file)[1])[0]
59+
if isinstance(file, (str, Path)):
60+
filename = filename or Path(file).name
6061

6162
with open(file, "rb") as f:
6263
file = f.read()

src/sasctl/_services/model_repository.py

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -524,7 +524,7 @@ def import_model_from_zip(
524524
project : str or dict
525525
The name or id of the model project, or a dictionary
526526
representation of the project.
527-
file : bytes
527+
file : bytes or file-like object
528528
The ZIP file containing the model and contents.
529529
description : str
530530
The description of the model.
@@ -551,9 +551,14 @@ def import_model_from_zip(
551551
}
552552
params = "&".join("{}={}".format(k, v) for k, v in params.items())
553553

554+
if not isinstance(file, bytes):
555+
if file.seekable():
556+
file.seek(0)
557+
file = file.read()
558+
554559
r = cls.post(
555560
"/models#octetStream",
556-
data=file.read(),
561+
data=file,
557562
params=params,
558563
headers={"Content-Type": "application/octet-stream"},
559564
)

src/sasctl/pzmm/import_model.py

Lines changed: 20 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -22,9 +22,21 @@ def get_model_properties(
2222
model_files: Union[str, Path, None] = None,
2323
):
2424
if type(model_files) is dict:
25-
model = model_files["ModelProperties.json"]
26-
input_var = model_files["inputVar.json"]
27-
output_var = model_files["outputVar.json"]
25+
try:
26+
model = json.loads(model_files["ModelProperties.json"])
27+
except (json.JSONDecodeError, TypeError):
28+
model = model_files["ModelProperties.json"]
29+
30+
try:
31+
input_var = json.loads(model_files["inputVar.json"])
32+
except (json.JSONDecodeError, TypeError):
33+
input_var = model_files["inputVar.json"]
34+
35+
try:
36+
output_var = json.loads(model_files["outputVar.json"])
37+
except (json.JSONDecodeError, TypeError):
38+
output_var = model_files["outputVar.json"]
39+
2840
else:
2941
with open(Path(model_files) / "ModelProperties.json") as f:
3042
model = json.load(f)
@@ -99,7 +111,9 @@ def project_exists(
99111
response = _create_project(project, model, repo, input_var, output_var)
100112
else:
101113
response = mr.create_project(project, repo)
102-
print(f"A new project named {response.name} was created.")
114+
115+
if check_if_jupyter():
116+
print(f"A new project named {response.name} was created.")
103117
return response
104118
else:
105119
model, input_var, output_var = get_model_properties(target_values, model_files)
@@ -348,7 +362,7 @@ def import_model(
348362
# For SAS Viya 4, the score code can be written beforehand and imported with
349363
# all the model files
350364
elif current_session().version_info() == 4:
351-
score_code_dict = sc.write_score_code(
365+
score_code_dict = sc().write_score_code(
352366
model_prefix,
353367
input_data,
354368
predict_method,
@@ -447,7 +461,7 @@ def import_model(
447461
except AttributeError:
448462
print("Model failed to import to SAS Model Manager.")
449463

450-
score_code_dict = sc.write_score_code(
464+
score_code_dict = sc().write_score_code(
451465
model_prefix,
452466
input_data,
453467
predict_method,

src/sasctl/pzmm/pickle_model.py

Lines changed: 13 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,6 @@
22
# SPDX-License-Identifier: Apache-2.0
33
# %%
44
import codecs
5-
import gzip
65
import pickle
76
import shutil
87
from pathlib import Path
@@ -77,6 +76,10 @@ def pickle_trained_model(
7776
models.
7877
7978
"""
79+
from .write_score_code import ScoreCode
80+
81+
sanitized_prefix = ScoreCode.sanitize_model_prefix(model_prefix)
82+
8083
if is_binary_string:
8184
# For models that use a binary string representation
8285
binary_string = codecs.encode(
@@ -91,25 +94,25 @@ def pickle_trained_model(
9194
# For models imported from MLFlow
9295
shutil.copy(ml_pickle_path, pickle_path)
9396
pzmm_pickle_path = Path(pickle_path) / mlflow_details["model_path"]
94-
pzmm_pickle_path.rename(Path(pickle_path) / (model_prefix + PICKLE))
97+
pzmm_pickle_path.rename(Path(pickle_path) / (sanitized_prefix + PICKLE))
9598
else:
9699
with open(ml_pickle_path, "rb") as pickle_file:
97-
return {model_prefix + PICKLE: pickle.load(pickle_file)}
100+
return {sanitized_prefix + PICKLE: pickle.load(pickle_file)}
98101
else:
99102
# For all other model types
100103
if not is_h2o_model:
101104
if pickle_path:
102105
with open(
103-
Path(pickle_path) / (model_prefix + PICKLE), "wb"
106+
Path(pickle_path) / (sanitized_prefix + PICKLE), "wb"
104107
) as pickle_file:
105108
pickle.dump(trained_model, pickle_file)
106109
if cls.notebook_output:
107110
print(
108111
f"Model {model_prefix} was successfully pickled and saved "
109-
f"to {Path(pickle_path) / (model_prefix + PICKLE)}."
112+
f"to {Path(pickle_path) / (sanitized_prefix + PICKLE)}."
110113
)
111114
else:
112-
return {model_prefix + PICKLE: pickle.dumps(trained_model)}
115+
return {sanitized_prefix + PICKLE: pickle.dumps(trained_model)}
113116
# For binary H2O models, save the binary file as a "pickle" file
114117
elif is_h2o_model and is_binary_model and pickle_path:
115118
if not h2o:
@@ -121,7 +124,7 @@ def pickle_trained_model(
121124
model=trained_model,
122125
force=True,
123126
path=str(pickle_path),
124-
filename=f"{model_prefix}.pickle",
127+
filename=f"{sanitized_prefix}.pickle",
125128
)
126129
# For MOJO H2O models, save as a mojo file and adjust the extension to .mojo
127130
elif is_h2o_model and pickle_path:
@@ -130,7 +133,9 @@ def pickle_trained_model(
130133
"The h2o package is required to save the model as a mojo model."
131134
)
132135
trained_model.save_mojo(
133-
force=True, path=str(pickle_path), filename=f"{model_prefix}.mojo"
136+
force=True,
137+
path=str(pickle_path),
138+
filename=f"{sanitized_prefix}.mojo",
134139
)
135140
elif is_binary_model or is_h2o_model:
136141
raise ValueError(

src/sasctl/pzmm/write_json_files.py

Lines changed: 15 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -498,18 +498,27 @@ def write_file_metadata_json(
498498
Dictionary containing a key-value pair representing the file name and json
499499
dump respectively.
500500
"""
501+
502+
from .write_score_code import ScoreCode
503+
504+
sanitized_prefix = ScoreCode.sanitize_model_prefix(model_prefix)
505+
501506
dict_list = [
502507
{"role": "inputVariables", "name": INPUT},
503508
{"role": "outputVariables", "name": OUTPUT},
504-
{"role": "score", "name": f"score_{model_prefix}.py"},
509+
{"role": "score", "name": f"score_{sanitized_prefix}.py"},
505510
]
506511
if is_h2o_model:
507-
dict_list.append({"role": "scoreResource", "name": model_prefix + ".mojo"})
512+
dict_list.append(
513+
{"role": "scoreResource", "name": sanitized_prefix + ".mojo"}
514+
)
508515
elif is_tf_keras_model:
509-
dict_list.append({"role": "scoreResource", "name": model_prefix + ".h5"})
516+
dict_list.append(
517+
{"role": "scoreResource", "name": sanitized_prefix + ".h5"}
518+
)
510519
else:
511520
dict_list.append(
512-
{"role": "scoreResource", "name": model_prefix + ".pickle"}
521+
{"role": "scoreResource", "name": sanitized_prefix + ".pickle"}
513522
)
514523

515524
if json_path:
@@ -2314,9 +2323,9 @@ def generate_model_card(
23142323
"Only classification and prediction target types are currently accepted."
23152324
)
23162325
if selection_statistic is None:
2317-
if target_type is "classification":
2326+
if target_type == "classification":
23182327
selection_statistic = "_KS_"
2319-
elif target_type is "prediction":
2328+
elif target_type == "prediction":
23202329
selection_statistic = "_ASE_"
23212330
if selection_statistic not in cls.valid_params:
23222331
raise RuntimeError(

0 commit comments

Comments
 (0)