|
4 | 4 |
|
5 | 5 | import pandas as pd
|
6 | 6 | from sklearn.base import ClassifierMixin
|
7 |
| -from zenml import log_artifact_metadata, step |
| 7 | + |
| 8 | +from zenml import log_metadata, step |
| 9 | +from zenml.client import Client |
8 | 10 | from zenml.logger import get_logger
|
9 | 11 |
|
10 | 12 | logger = get_logger(__name__)
|
11 | 13 |
|
12 | 14 |
|
13 | 15 | @step
|
14 | 16 | def model_evaluator(
|
15 |
| - model: ClassifierMixin, |
16 |
| - dataset_trn: pd.DataFrame, |
17 |
| - dataset_tst: pd.DataFrame, |
18 |
| - min_train_accuracy: float = 0.0, |
19 |
| - min_test_accuracy: float = 0.0, |
20 |
| - target: Optional[str] = "target", |
| 17 | + model: ClassifierMixin, |
| 18 | + dataset_trn: pd.DataFrame, |
| 19 | + dataset_tst: pd.DataFrame, |
| 20 | + min_train_accuracy: float = 0.0, |
| 21 | + min_test_accuracy: float = 0.0, |
| 22 | + target: Optional[str] = "target", |
21 | 23 | ) -> float:
|
22 | 24 | """Evaluate a trained model.
|
23 | 25 |
|
@@ -63,24 +65,31 @@ def model_evaluator(
|
63 | 65 | dataset_tst.drop(columns=[target]),
|
64 | 66 | dataset_tst[target],
|
65 | 67 | )
|
66 |
| - logger.info(f"Train accuracy={trn_acc*100:.2f}%") |
67 |
| - logger.info(f"Test accuracy={tst_acc*100:.2f}%") |
| 68 | + logger.info(f"Train accuracy={trn_acc * 100:.2f}%") |
| 69 | + logger.info(f"Test accuracy={tst_acc * 100:.2f}%") |
68 | 70 |
|
69 | 71 | messages = []
|
70 | 72 | if trn_acc < min_train_accuracy:
|
71 | 73 | messages.append(
|
72 |
| - f"Train accuracy {trn_acc*100:.2f}% is below {min_train_accuracy*100:.2f}% !" |
| 74 | + f"Train accuracy {trn_acc * 100:.2f}% is below {min_train_accuracy * 100:.2f}% !" |
73 | 75 | )
|
74 | 76 | if tst_acc < min_test_accuracy:
|
75 | 77 | messages.append(
|
76 |
| - f"Test accuracy {tst_acc*100:.2f}% is below {min_test_accuracy*100:.2f}% !" |
| 78 | + f"Test accuracy {tst_acc * 100:.2f}% is below {min_test_accuracy * 100:.2f}% !" |
77 | 79 | )
|
78 | 80 | else:
|
79 | 81 | for message in messages:
|
80 | 82 | logger.warning(message)
|
81 | 83 |
|
82 |
| - log_artifact_metadata( |
83 |
| - metadata={"train_accuracy": float(trn_acc), "test_accuracy": float(tst_acc)}, |
84 |
| - artifact_name="sklearn_classifier", |
| 84 | + client = Client() |
| 85 | + latest_classifier = client.get_artifact_version("sklearn_classifier") |
| 86 | + |
| 87 | + log_metadata( |
| 88 | + metadata={ |
| 89 | + "train_accuracy": float(trn_acc), |
| 90 | + "test_accuracy": float(tst_acc) |
| 91 | + }, |
| 92 | + artifact_version_id=latest_classifier.id |
85 | 93 | )
|
| 94 | + |
86 | 95 | return float(tst_acc)
|
0 commit comments