Skip to content

Commit a992a9e

Browse files
authored
Removing the deprecated log_xxx_metadata calls (#28)
* removing the deprecated calls * correcting the ref * fixing the review comments * fixing the steps * new way to fetch artifacts * fixed the errors * fixed imports
1 parent a200a13 commit a992a9e

File tree

3 files changed

+28
-18
lines changed

3 files changed

+28
-18
lines changed

.github/workflows/ci.yml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -55,5 +55,5 @@ jobs:
5555
with:
5656
stack-name: ${{ matrix.stack-name }}
5757
python-version: ${{ matrix.python-version }}
58-
ref-zenml: ${{ inputs.ref-zenml || 'develop' }}
58+
ref-zenml: ${{ inputs.ref-zenml || 'feature/followup-run-metadata' }}
5959
ref-template: ${{ inputs.ref-template || github.ref }}

template/steps/data_preprocessor.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@
77
from sklearn.preprocessing import MinMaxScaler
88
from typing_extensions import Annotated
99
from utils.preprocess import ColumnsDropper, DataFrameCaster, NADropper
10-
from zenml import log_artifact_metadata, step
10+
from zenml import log_metadata, step
1111

1212

1313
@step
@@ -67,8 +67,9 @@ def data_preprocessor(
6767
dataset_tst = preprocess_pipeline.transform(dataset_tst)
6868

6969
# Log metadata so we can load it in the inference pipeline
70-
log_artifact_metadata(
71-
artifact_name="preprocess_pipeline",
70+
log_metadata(
7271
metadata={"random_state": random_state, "target": target},
72+
artifact_name="preprocess_pipeline",
73+
infer_artifact=True,
7374
)
7475
return dataset_trn, dataset_tst, preprocess_pipeline

template/steps/model_evaluator.py

Lines changed: 23 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -4,20 +4,22 @@
44

55
import pandas as pd
66
from sklearn.base import ClassifierMixin
7-
from zenml import log_artifact_metadata, step
7+
8+
from zenml import log_metadata, step
9+
from zenml.client import Client
810
from zenml.logger import get_logger
911

1012
logger = get_logger(__name__)
1113

1214

1315
@step
1416
def model_evaluator(
15-
model: ClassifierMixin,
16-
dataset_trn: pd.DataFrame,
17-
dataset_tst: pd.DataFrame,
18-
min_train_accuracy: float = 0.0,
19-
min_test_accuracy: float = 0.0,
20-
target: Optional[str] = "target",
17+
model: ClassifierMixin,
18+
dataset_trn: pd.DataFrame,
19+
dataset_tst: pd.DataFrame,
20+
min_train_accuracy: float = 0.0,
21+
min_test_accuracy: float = 0.0,
22+
target: Optional[str] = "target",
2123
) -> float:
2224
"""Evaluate a trained model.
2325
@@ -63,24 +65,31 @@ def model_evaluator(
6365
dataset_tst.drop(columns=[target]),
6466
dataset_tst[target],
6567
)
66-
logger.info(f"Train accuracy={trn_acc*100:.2f}%")
67-
logger.info(f"Test accuracy={tst_acc*100:.2f}%")
68+
logger.info(f"Train accuracy={trn_acc * 100:.2f}%")
69+
logger.info(f"Test accuracy={tst_acc * 100:.2f}%")
6870

6971
messages = []
7072
if trn_acc < min_train_accuracy:
7173
messages.append(
72-
f"Train accuracy {trn_acc*100:.2f}% is below {min_train_accuracy*100:.2f}% !"
74+
f"Train accuracy {trn_acc * 100:.2f}% is below {min_train_accuracy * 100:.2f}% !"
7375
)
7476
if tst_acc < min_test_accuracy:
7577
messages.append(
76-
f"Test accuracy {tst_acc*100:.2f}% is below {min_test_accuracy*100:.2f}% !"
78+
f"Test accuracy {tst_acc * 100:.2f}% is below {min_test_accuracy * 100:.2f}% !"
7779
)
7880
else:
7981
for message in messages:
8082
logger.warning(message)
8183

82-
log_artifact_metadata(
83-
metadata={"train_accuracy": float(trn_acc), "test_accuracy": float(tst_acc)},
84-
artifact_name="sklearn_classifier",
84+
client = Client()
85+
latest_classifier = client.get_artifact_version("sklearn_classifier")
86+
87+
log_metadata(
88+
metadata={
89+
"train_accuracy": float(trn_acc),
90+
"test_accuracy": float(tst_acc)
91+
},
92+
artifact_version_id=latest_classifier.id
8593
)
94+
8695
return float(tst_acc)

0 commit comments

Comments
 (0)