Skip to content

Commit 7918793

Browse files
committed
fix: upload machine learning models code sources
1 parent 571d548 commit 7918793

File tree

12 files changed

+548
-5
lines changed

12 files changed

+548
-5
lines changed

src/ano_detection/config/__init__.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,8 @@
44
LoggerArgumentsConfig,
55
DataIngestionConfig,
66
DataProcessingConfig,
7-
DataTrainingConfig
7+
DataTrainingConfig,
8+
ModelArgumentsConfig
89
)
910

1011
__all__ = [
@@ -13,5 +14,6 @@
1314
"LoggerArgumentsConfig",
1415
"DataIngestionConfig",
1516
"DataProcessingConfig",
16-
"DataTrainingConfig"
17+
"DataTrainingConfig",
18+
"ModelArgumentsConfig",
1719
]

src/ano_detection/config/arguments_config.py

Lines changed: 38 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -222,6 +222,44 @@ class DataTrainingConfig:
222222
}
223223
)
224224

225+
@dataclass(frozen=True)
226+
class ModelArgumentsConfig:
227+
root_dir: str = field(
228+
default="src/artifacts/models",
229+
metadata={
230+
"help": "Root directory of the dataset.",
231+
"type": "string",
232+
}
233+
)
234+
model_name: str = field(
235+
default="name_model",
236+
metadata={
237+
"help": "Name of the model.",
238+
"type": "string",
239+
}
240+
)
241+
model_path: str = field(
242+
default="model.pkl",
243+
metadata={
244+
"help": "Path of the model.",
245+
"type": "string",
246+
}
247+
)
248+
model_params: dict = field(
249+
default={"n_estimators": 100, "max_depth": 5, "random_state": 42},
250+
metadata={
251+
"help": "Parameters of the model.",
252+
"type": "dict",
253+
}
254+
)
255+
model_description: str = field(
256+
default="Model Description",
257+
metadata={
258+
"help": "Description of the model.",
259+
"type": "string",
260+
}
261+
)
262+
225263

226264

227265

src/ano_detection/config/manager_config.py

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010
DataIngestionConfig,
1111
DataProcessingConfig,
1212
DataTrainingConfig,
13+
ModelArgumentsConfig
1314

1415
)
1516

@@ -104,4 +105,20 @@ def get_data_training_arguments_config(self) -> DataTrainingConfig:
104105
)
105106

106107
return data_training_config
108+
109+
### MODEL TRAINING CONFIG PARAMS PHASE ###
110+
def get_model_training_arguments_config(self) -> ModelArgumentsConfig:
111+
config = self.config.model.stages.models
112+
113+
create_directories([config.root_dir])
114+
115+
model_training_config = ModelArgumentsConfig(
116+
root_dir=config.root_dir,
117+
model_name=config.model_name,
118+
model_path=config.model_path,
119+
model_params=config.model_params,
120+
model_description=config.model_description,
121+
)
122+
123+
return model_training_config
107124

src/ano_detection/models/base_model.py

Lines changed: 61 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,20 @@
44
from abc import ABC, abstractmethod # ( Abstract Base Classes Library )
55
from src.ano_detection.logger import logger
66
from src.ano_detection.exception import MyException
7+
from src.ano_detection.config import ModelArgumentsConfig
8+
from src.ano_detection.visualization import VILModel
9+
from src.config_params import ROOT_PROJECT
710

11+
from sklearn.metrics import (
12+
accuracy_score,
13+
recall_score,
14+
precision_score,
15+
f1_score,
16+
roc_curve,
17+
auc,
18+
precision_recall_curve
19+
)
20+
from sklearn.preprocessing import label_binarize
821

922
class BaseModel(ABC):
1023
"""
@@ -19,9 +32,18 @@ class BaseModel(ABC):
1932
"""
2033
@abstractmethod
2134
def __init__(self,
22-
config: object,
35+
config: ModelArgumentsConfig,
2336
**kwargs):
24-
pass
37+
super(BaseModel, self).__init__(**kwargs)
38+
self.config = config
39+
self.root_dir = self.config.root_dir
40+
self.model_name = self.config.model_name
41+
self.model_path = ROOT_PROJECT / self.root_dir / self.config.model_path
42+
self.model_params = self.config.model_params
43+
self.model_description = self.config.model_description
44+
45+
self.visualization = VILModel()
46+
2547

2648
@abstractmethod
2749
def __repr__(self):
@@ -47,4 +69,41 @@ def save_model(self):
4769
def load_model(self):
4870
pass
4971

72+
@abstractmethod
73+
def get_score(self, y_true, y_pred, y_scores, n_classes):
74+
accuracy = accuracy_score(y_true, y_pred)
75+
recall = recall_score(y_true, y_pred, average='weighted')
76+
precision = precision_score(y_true, y_pred, average='weighted')
77+
f1 = f1_score(y_true, y_pred, average='weighted')
78+
79+
y_test_bin = label_binarize(y_true, classes=[0, 1, 2]) # Convert y_test to binary format
80+
81+
# Compute ROC curve and AUC for each class
82+
fpr = {}
83+
tpr = {}
84+
roc_auc = {}
85+
for i in range(n_classes):
86+
fpr[i], tpr[i], _ = roc_curve(y_test_bin[:, i], y_scores[:, i])
87+
roc_auc[i] = auc(fpr[i], tpr[i])
88+
89+
# Compute Precision-Recall curve and AUC for each class
90+
precision_curve_pr = {}
91+
recall_curve_pr = {}
92+
pr_auc = {}
93+
for i in range(n_classes):
94+
precision_curve_pr[i], recall_curve_pr[i], _ = precision_recall_curve(y_test_bin[:, i], y_scores[:, i])
95+
pr_auc[i] = auc(recall_curve_pr[i], precision_curve_pr[i])
96+
97+
98+
print(f"Accuracy: {accuracy:.2f}")
99+
print(f"Recall: {recall:.2f}")
100+
print(f"Precision: {precision:.2f}")
101+
print(f"F1-score: {f1:.2f}")
102+
103+
self.visualization.auc_roc_curve(fpr, tpr, roc_auc, n_classes)
104+
self.visualization.auc_pr_curve(precision_curve_pr, recall_curve_pr, pr_auc, n_classes)
105+
print("Complete!")
106+
107+
return accuracy, recall, precision, f1
108+
50109

Lines changed: 110 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,110 @@
1+
import os
2+
import sys
3+
4+
from .base_model import BaseModel
5+
# from src.ano_detection.config import ModelArgumentsConfig
6+
from src.ano_detection.logger import logger
7+
from src.ano_detection.exception import MyException
8+
9+
from sklearn.ensemble import GradientBoostingClassifier
10+
from sklearn.metrics import classification_report
11+
12+
from src.ano_detection.utils import save_pickle, load_pickle
13+
14+
15+
16+
class GradientBoostingModel(BaseModel):
17+
"""
18+
Configuration for Gradient Boosting Model
19+
20+
Params:
21+
config: object
22+
23+
Returns:
24+
model for training and prediction tasks
25+
26+
"""
27+
def __init__(self, **kwargs):
28+
super(GradientBoostingModel, self).__init__(**kwargs)
29+
30+
# self.config = kwargs.get("config")
31+
32+
self.gradient_boosting_model = GradientBoostingClassifier(random_state=2024,
33+
n_estimators=100,
34+
learning_rate=0.1,
35+
max_depth=10,
36+
min_samples_leaf=1,
37+
min_weight_fraction_leaf=0.0,
38+
max_features=1.0)
39+
40+
def __repr__(self):
41+
print(f"GradientBoostingModel config: {self.config}")
42+
logger.log_message("info", f"GradientBoostingModel config: {self.config}")
43+
return f"{self.__class__.__name__}"
44+
45+
def fit(self, X_train, y_train):
46+
try:
47+
logger.log_message("info", "Fitting Gradient Boosting Model....")
48+
self.gradient_boosting_model.fit(X_train, y_train)
49+
50+
logger.log_message("info", "Fitted Gradient Boosting Model successfully....")
51+
52+
except Exception as e:
53+
logger.log_message("error", f"Error in fitting Gradient Boosting Model: {e}")
54+
my_exception = MyException(
55+
error_message="Error in fitting Gradient Boosting Model",
56+
error_details=sys
57+
)
58+
print(my_exception)
59+
60+
def predict(self, X_test, y_test):
61+
try:
62+
logger.log_message("info", "Predicting Gradient Boosting Model....")
63+
y_pred_gbc = self.gradient_boosting_model.predict(X_test)
64+
print(classification_report(y_test, y_pred_gbc))
65+
66+
logger.log_message("info", "Predicted Gradient Boosting Model successfully....")
67+
68+
return y_pred_gbc
69+
70+
except Exception as e:
71+
logger.log_message("error", f"Error in predicting Gradient Boosting Model: {e}")
72+
my_exception = MyException(
73+
error_message="Error in predicting Gradient Boosting Model",
74+
error_details=sys
75+
)
76+
print(my_exception)
77+
78+
79+
def save_model(self):
80+
try:
81+
logger.log_message("info", "Saving Gradient Boosting Model....")
82+
save_pickle(self.gradient_boosting_model, self.model_path)
83+
84+
logger.log_message("info", "Saved Gradient Boosting Model successfully....")
85+
86+
except Exception as e:
87+
logger.log_message("error", f"Error in saving Gradient Boosting Model: {e}")
88+
my_exception = MyException(
89+
error_message="Error in saving Gradient Boosting Model",
90+
error_details=sys
91+
)
92+
print(my_exception)
93+
94+
def load_model(self):
95+
try:
96+
logger.log_message("info", "Loading Gradient Boosting Model....")
97+
self.gradient_boosting_model = load_pickle(self.model_path)
98+
99+
logger.log_message("info", "Loaded Gradient Boosting Model successfully....")
100+
101+
except Exception as e:
102+
logger.log_message("error", f"Error in loading Gradient Boosting Model: {e}")
103+
my_exception = MyException(
104+
error_message="Error in loading Gradient Boosting Model",
105+
error_details=sys
106+
107+
)
108+
print(my_exception)
109+
110+
Lines changed: 107 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,107 @@
1+
import os
2+
import sys
3+
4+
from .base_model import BaseModel
5+
# from src.ano_detection.config import ModelArgumentsConfig
6+
from src.ano_detection.logger import logger
7+
from src.ano_detection.exception import MyException
8+
9+
from sklearn.linear_model import LogisticRegression
10+
from sklearn.metrics import classification_report
11+
12+
from src.ano_detection.utils import save_pickle, load_pickle
13+
14+
15+
16+
class LogisticRegressionModel(BaseModel):
17+
"""
18+
Configuration for Logistic Regression Model
19+
20+
Params:
21+
config: object
22+
23+
Returns:
24+
model for training and prediction tasks
25+
26+
"""
27+
def __init__(self, **kwargs):
28+
super(LogisticRegressionModel, self).__init__(**kwargs)
29+
30+
# self.config = kwargs.get("config")
31+
32+
self.lr_model = LogisticRegression(random_state=2024,
33+
max_iter=1000,
34+
solver='lbfgs',
35+
multi_class='multinomial')
36+
37+
def __repr__(self):
38+
print(f"LogisticRegressionModel config: {self.config}")
39+
logger.log_message("info", f"LLogisticRegressionModel config: {self.config}")
40+
return f"{self.__class__.__name__}"
41+
42+
def fit(self, X_train, y_train):
43+
try:
44+
logger.log_message("info", "Fitting Logistic Regression Model....")
45+
self.lr_model.fit(X_train, y_train)
46+
47+
logger.log_message("info", "Fitted Logistic Regression Model successfully....")
48+
49+
except Exception as e:
50+
logger.log_message("error", f"Error in fitting Logistic Regression Model: {e}")
51+
my_exception = MyException(
52+
error_message="Error in fitting Logistic Regression Model",
53+
error_details=sys
54+
)
55+
print(my_exception)
56+
57+
def predict(self, X_test, y_test):
58+
try:
59+
logger.log_message("info", "Predicting Logistic Regression Model....")
60+
y_pred_gbc = self.lr_model.predict(X_test)
61+
print(classification_report(y_test, y_pred_gbc))
62+
63+
logger.log_message("info", "Predicted Logistic Regression Model successfully....")
64+
65+
return y_pred_gbc
66+
67+
except Exception as e:
68+
logger.log_message("error", f"Error in predicting Logistic Regression Model: {e}")
69+
my_exception = MyException(
70+
error_message="Error in predicting Logistic Regression Model",
71+
error_details=sys
72+
)
73+
print(my_exception)
74+
75+
76+
def save_model(self):
77+
try:
78+
logger.log_message("info", "Saving Logistic Regression Model....")
79+
save_pickle(self.lr_model, self.model_path)
80+
81+
logger.log_message("info", "Saved Logistic Regression Model successfully....")
82+
83+
except Exception as e:
84+
logger.log_message("error", f"Error in saving Logistic Regression Model: {e}")
85+
my_exception = MyException(
86+
error_message="Error in saving Logistic Regression Model",
87+
error_details=sys
88+
)
89+
print(my_exception)
90+
91+
def load_model(self):
92+
try:
93+
logger.log_message("info", "Loading Logistic Regression Model....")
94+
self.lr_model = load_pickle(self.model_path)
95+
96+
logger.log_message("info", "Loaded Logistic Regression Model successfully....")
97+
98+
except Exception as e:
99+
logger.log_message("error", f"Error in loading Logistic Regression Model: {e}")
100+
my_exception = MyException(
101+
error_message="Error in loading Logistic Regression Model",
102+
error_details=sys
103+
104+
)
105+
print(my_exception)
106+
107+

0 commit comments

Comments
 (0)