diff --git a/src/alerts/alert_types.rs b/src/alerts/alert_types.rs new file mode 100644 index 000000000..74be90393 --- /dev/null +++ b/src/alerts/alert_types.rs @@ -0,0 +1,250 @@ +/* + * Parseable Server (C) 2022 - 2024 Parseable, Inc. + * + * This program is free software: you can redistribute it and/or modify + * it under the terms of the GNU Affero General Public License as + * published by the Free Software Foundation, either version 3 of the + * License, or (at your option) any later version. + * + * This program is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * GNU Affero General Public License for more details. + * + * You should have received a copy of the GNU Affero General Public License + * along with this program. If not, see . + * + */ + +use std::time::Duration; + +use chrono::{DateTime, Utc}; +use tonic::async_trait; +use ulid::Ulid; + +use crate::{ + alerts::{ + AlertConfig, AlertError, AlertState, AlertType, AlertVersion, EvalConfig, Severity, + ThresholdConfig, + alerts_utils::{evaluate_condition, execute_alert_query, extract_time_range}, + is_query_aggregate, + target::{self, TARGETS}, + traits::AlertTrait, + }, + handlers::http::query::create_streams_for_distributed, + option::Mode, + parseable::PARSEABLE, + query::resolve_stream_names, + rbac::map::SessionKey, + utils::user_auth_for_query, +}; + +/// Struct which defines the threshold type alerts +#[derive(Debug, serde::Serialize, serde::Deserialize, Clone)] +pub struct ThresholdAlert { + pub version: AlertVersion, + #[serde(default)] + pub id: Ulid, + pub severity: Severity, + pub title: String, + pub query: String, + pub alert_type: AlertType, + pub threshold_config: ThresholdConfig, + pub eval_config: EvalConfig, + pub targets: Vec, + // for new alerts, state should be resolved + #[serde(default)] + pub state: AlertState, + pub created: DateTime, + pub tags: Option>, + pub datasets: Vec, +} + +#[async_trait] +impl AlertTrait for ThresholdAlert { + async fn eval_alert(&self) -> Result<(bool, f64), AlertError> { + let time_range = extract_time_range(&self.eval_config)?; + let final_value = execute_alert_query(self.get_query(), &time_range).await?; + let result = evaluate_condition( + &self.threshold_config.operator, + final_value, + self.threshold_config.value, + ); + Ok((result, final_value)) + } + + async fn validate(&self, session_key: &SessionKey) -> Result<(), AlertError> { + // validate alert type + // Anomaly is only allowed in Prism + if self.alert_type.eq(&AlertType::Anomaly) && PARSEABLE.options.mode != Mode::Prism { + return Err(AlertError::CustomError( + "Anomaly alert is only allowed on Prism mode".into(), + )); + } + + // validate evalType + let eval_frequency = match &self.eval_config { + EvalConfig::RollingWindow(rolling_window) => { + if humantime::parse_duration(&rolling_window.eval_start).is_err() { + return Err(AlertError::Metadata( + "evalStart should be of type humantime", + )); + } + rolling_window.eval_frequency + } + }; + + // validate that target repeat notifs !> eval_frequency + for target_id in &self.targets { + let target = TARGETS.get_target_by_id(target_id).await?; + match &target.notification_config.times { + target::Retry::Infinite => {} + target::Retry::Finite(repeat) => { + let notif_duration = + Duration::from_secs(60 * target.notification_config.interval) + * *repeat as u32; + if (notif_duration.as_secs_f64()).gt(&((eval_frequency * 60) as f64)) { + return Err(AlertError::Metadata( + "evalFrequency should be greater than target repetition interval", + )); + } + } + } + } + + // validate that the query is valid + if self.query.is_empty() { + return Err(AlertError::InvalidAlertQuery); + } + + let tables = resolve_stream_names(&self.query)?; + if tables.is_empty() { + return Err(AlertError::InvalidAlertQuery); + } + create_streams_for_distributed(tables) + .await + .map_err(|_| AlertError::InvalidAlertQuery)?; + + // validate that the user has access to the tables mentioned in the query + user_auth_for_query(session_key, &self.query).await?; + + // validate that the alert query is valid and can be evaluated + if !is_query_aggregate(&self.query).await? { + return Err(AlertError::InvalidAlertQuery); + } + Ok(()) + } + + fn get_id(&self) -> &Ulid { + &self.id + } + + fn get_query(&self) -> &str { + &self.query + } + + fn get_severity(&self) -> &Severity { + &self.severity + } + + fn get_title(&self) -> &str { + &self.title + } + + fn get_alert_type(&self) -> &AlertType { + &self.alert_type + } + + fn get_threshold_config(&self) -> &ThresholdConfig { + &self.threshold_config + } + + fn get_eval_config(&self) -> &EvalConfig { + &self.eval_config + } + + fn get_targets(&self) -> &Vec { + &self.targets + } + + fn get_state(&self) -> &AlertState { + &self.state + } + + fn get_eval_frequency(&self) -> u64 { + match &self.eval_config { + EvalConfig::RollingWindow(rolling_window) => rolling_window.eval_frequency, + } + } + + fn get_eval_window(&self) -> String { + match &self.eval_config { + EvalConfig::RollingWindow(rolling_window) => rolling_window.eval_start.clone(), + } + } + + fn get_created(&self) -> String { + self.created.to_string() + } + + fn get_tags(&self) -> &Option> { + &self.tags + } + + fn get_datasets(&self) -> &Vec { + &self.datasets + } + + fn to_alert_config(&self) -> AlertConfig { + let clone = self.clone(); + clone.into() + } + + fn clone_box(&self) -> Box { + Box::new(self.clone()) + } + + fn set_state(&mut self, new_state: AlertState) { + self.state = new_state + } +} + +impl From for ThresholdAlert { + fn from(value: AlertConfig) -> Self { + Self { + version: value.version, + id: value.id, + severity: value.severity, + title: value.title, + query: value.query, + alert_type: value.alert_type, + threshold_config: value.threshold_config, + eval_config: value.eval_config, + targets: value.targets, + state: value.state, + created: value.created, + tags: value.tags, + datasets: value.datasets, + } + } +} + +impl From for AlertConfig { + fn from(val: ThresholdAlert) -> Self { + AlertConfig { + version: val.version, + id: val.id, + severity: val.severity, + title: val.title, + query: val.query, + alert_type: val.alert_type, + threshold_config: val.threshold_config, + eval_config: val.eval_config, + targets: val.targets, + state: val.state, + created: val.created, + tags: val.tags, + datasets: val.datasets, + } + } +} diff --git a/src/alerts/alerts_utils.rs b/src/alerts/alerts_utils.rs index d60b0fb09..f6eb0b1ce 100644 --- a/src/alerts/alerts_utils.rs +++ b/src/alerts/alerts_utils.rs @@ -24,10 +24,11 @@ use datafusion::{ logical_expr::Literal, prelude::{Expr, lit}, }; +use itertools::Itertools; use tracing::trace; use crate::{ - alerts::{Conditions, LogicalOperator, WhereConfigOperator}, + alerts::{AlertTrait, Conditions, LogicalOperator, WhereConfigOperator}, handlers::http::{ cluster::send_query_request, query::{Query, create_streams_for_distributed}, @@ -38,7 +39,7 @@ use crate::{ utils::time::TimeRange, }; -use super::{ALERTS, AlertConfig, AlertError, AlertOperator, AlertState}; +use super::{ALERTS, AlertError, AlertOperator, AlertState}; /// accept the alert /// @@ -51,22 +52,16 @@ use super::{ALERTS, AlertConfig, AlertError, AlertOperator, AlertState}; /// collect the results in the end /// /// check whether notification needs to be triggered or not -pub async fn evaluate_alert(alert: &AlertConfig) -> Result<(), AlertError> { +pub async fn evaluate_alert(alert: &dyn AlertTrait) -> Result<(), AlertError> { trace!("RUNNING EVAL TASK FOR- {alert:?}"); - let time_range = extract_time_range(&alert.eval_config)?; - let final_value = execute_alert_query(alert, &time_range).await?; - let result = evaluate_condition( - &alert.threshold_config.operator, - final_value, - alert.threshold_config.value, - ); + let (result, final_value) = alert.eval_alert().await?; update_alert_state(alert, result, final_value).await } /// Extract time range from alert evaluation configuration -fn extract_time_range(eval_config: &super::EvalConfig) -> Result { +pub fn extract_time_range(eval_config: &super::EvalConfig) -> Result { let (start_time, end_time) = match eval_config { super::EvalConfig::RollingWindow(rolling_window) => (&rolling_window.eval_start, "now"), }; @@ -76,13 +71,10 @@ fn extract_time_range(eval_config: &super::EvalConfig) -> Result Result { +pub async fn execute_alert_query(query: &str, time_range: &TimeRange) -> Result { match PARSEABLE.options.mode { - Mode::All | Mode::Query => execute_local_query(alert, time_range).await, - Mode::Prism => execute_remote_query(alert, time_range).await, + Mode::All | Mode::Query => execute_local_query(query, time_range).await, + Mode::Prism => execute_remote_query(query, time_range).await, _ => Err(AlertError::CustomError(format!( "Unsupported mode '{:?}' for alert evaluation", PARSEABLE.options.mode @@ -91,12 +83,8 @@ async fn execute_alert_query( } /// Execute alert query locally (Query/All mode) -async fn execute_local_query( - alert: &AlertConfig, - time_range: &TimeRange, -) -> Result { +async fn execute_local_query(query: &str, time_range: &TimeRange) -> Result { let session_state = QUERY_SESSION.state(); - let query = &alert.query; let tables = resolve_stream_names(query)?; create_streams_for_distributed(tables.clone()) @@ -127,12 +115,9 @@ async fn execute_local_query( } /// Execute alert query remotely (Prism mode) -async fn execute_remote_query( - alert: &AlertConfig, - time_range: &TimeRange, -) -> Result { +async fn execute_remote_query(query: &str, time_range: &TimeRange) -> Result { let query_request = Query { - query: alert.query.clone(), + query: query.to_string(), start_time: time_range.start.to_rfc3339(), end_time: time_range.end.to_rfc3339(), streaming: false, @@ -150,20 +135,21 @@ async fn execute_remote_query( /// Convert JSON result value to f64 fn convert_result_to_f64(result_value: serde_json::Value) -> Result { - if let Some(value) = result_value.as_f64() { - Ok(value) - } else if let Some(value) = result_value.as_i64() { - Ok(value as f64) - } else if let Some(value) = result_value.as_u64() { - Ok(value as f64) + // due to the previous validations, we can be sure that we get an array of objects with just one entry + // [{"countField": Number(1120.251)}] + if let Some(array_val) = result_value.as_array().filter(|arr| !arr.is_empty()) + && let Some(object) = array_val[0].as_object() + { + let values = object.values().map(|v| v.as_f64().unwrap()).collect_vec(); + Ok(values[0]) } else { Err(AlertError::CustomError( - "Query result is not a number".to_string(), + "Query result is not a number or response is empty".to_string(), )) } } -fn evaluate_condition(operator: &AlertOperator, actual: f64, expected: f64) -> bool { +pub fn evaluate_condition(operator: &AlertOperator, actual: f64, expected: f64) -> bool { match operator { AlertOperator::GreaterThan => actual > expected, AlertOperator::LessThan => actual < expected, @@ -175,31 +161,43 @@ fn evaluate_condition(operator: &AlertOperator, actual: f64, expected: f64) -> b } async fn update_alert_state( - alert: &AlertConfig, + alert: &dyn AlertTrait, final_res: bool, actual_value: f64, ) -> Result<(), AlertError> { + let guard = ALERTS.write().await; + let alerts = if let Some(alerts) = guard.as_ref() { + alerts + } else { + return Err(AlertError::CustomError("No AlertManager set".into())); + }; + if final_res { let message = format!( "Alert Triggered: {}\n\nThreshold: ({} {})\nCurrent Value: {}\nEvaluation Window: {} | Frequency: {}\n\nQuery:\n{}", - alert.id, - alert.threshold_config.operator, - alert.threshold_config.value, + alert.get_id(), + alert.get_threshold_config().operator, + alert.get_threshold_config().value, actual_value, alert.get_eval_window(), alert.get_eval_frequency(), - alert.query + alert.get_query() ); - ALERTS - .update_state(alert.id, AlertState::Triggered, Some(message)) + + alerts + .update_state(*alert.get_id(), AlertState::Triggered, Some(message)) .await - } else if ALERTS.get_state(alert.id).await?.eq(&AlertState::Triggered) { - ALERTS - .update_state(alert.id, AlertState::Resolved, Some("".into())) + } else if alerts + .get_state(*alert.get_id()) + .await? + .eq(&AlertState::Triggered) + { + alerts + .update_state(*alert.get_id(), AlertState::Resolved, Some("".into())) .await } else { - ALERTS - .update_state(alert.id, AlertState::Resolved, None) + alerts + .update_state(*alert.get_id(), AlertState::Resolved, None) .await } } diff --git a/src/alerts/mod.rs b/src/alerts/mod.rs index 6429f519b..6ce1ff02b 100644 --- a/src/alerts/mod.rs +++ b/src/alerts/mod.rs @@ -17,7 +17,7 @@ */ use actix_web::http::header::ContentType; -use arrow_schema::{DataType, Schema}; +use arrow_schema::{ArrowError, DataType, Schema}; use async_trait::async_trait; use chrono::{DateTime, Utc}; use datafusion::logical_expr::{LogicalPlan, Projection}; @@ -25,11 +25,12 @@ use datafusion::sql::sqlparser::parser::ParserError; use derive_more::FromStrError; use derive_more::derive::FromStr; use http::StatusCode; -use once_cell::sync::Lazy; +// use once_cell::sync::Lazy; use serde::Serialize; use serde_json::{Error as SerdeError, Value as JsonValue}; use std::collections::HashMap; -use std::fmt::{self, Display}; +use std::fmt::{self, Debug, Display}; +use std::sync::Arc; use std::thread; use std::time::Duration; use tokio::sync::oneshot::{Receiver, Sender}; @@ -38,12 +39,17 @@ use tokio::task::JoinHandle; use tracing::{error, trace, warn}; use ulid::Ulid; +pub mod alert_types; pub mod alerts_utils; pub mod target; +pub mod traits; +use crate::alerts::alert_types::ThresholdAlert; use crate::alerts::target::TARGETS; +use crate::alerts::traits::{AlertManagerTrait, AlertTrait}; use crate::handlers::http::fetch_schema; use crate::handlers::http::query::create_streams_for_distributed; +use crate::option::Mode; use crate::parseable::{PARSEABLE, StreamNotFound}; use crate::query::{QUERY_SESSION, resolve_stream_names}; use crate::rbac::map::SessionKey; @@ -64,38 +70,68 @@ pub type ScheduledTaskHandlers = (JoinHandle<()>, Receiver<()>, Sender<()>); pub const CURRENT_ALERTS_VERSION: &str = "v2"; -pub static ALERTS: Lazy = Lazy::new(|| { +pub static ALERTS: RwLock>> = RwLock::const_new(None); + +pub async fn get_alert_manager() -> Arc { + let guard = ALERTS.read().await; + if let Some(manager) = guard.as_ref() { + manager.clone() + } else { + drop(guard); + let mut write_guard = ALERTS.write().await; + if write_guard.is_none() { + *write_guard = Some(Arc::new(create_default_alerts_manager())); + } + write_guard.as_ref().unwrap().clone() + } +} + +pub async fn set_alert_manager(manager: Arc) { + *ALERTS.write().await = Some(manager); +} + +pub fn create_default_alerts_manager() -> Alerts { let (tx, rx) = mpsc::channel::(10); let alerts = Alerts { alerts: RwLock::new(HashMap::new()), sender: tx, }; - thread::spawn(|| alert_runtime(rx)); - alerts -}); +} + +// pub static ALERTS: Lazy = Lazy::new(|| { +// let (tx, rx) = mpsc::channel::(10); +// let alerts = Alerts { +// alerts: RwLock::new(HashMap::new()), +// sender: tx, +// }; + +// thread::spawn(|| alert_runtime(rx)); + +// alerts +// }); #[derive(Debug)] pub struct Alerts { - pub alerts: RwLock>, + pub alerts: RwLock>>, pub sender: mpsc::Sender, } pub enum AlertTask { - Create(Box), + Create(Box), Delete(Ulid), } #[derive(Default, Debug, serde::Serialize, serde::Deserialize, Clone)] #[serde(rename_all = "lowercase")] -pub enum AlertVerison { +pub enum AlertVersion { V1, #[default] V2, } -impl From<&str> for AlertVerison { +impl From<&str> for AlertVersion { fn from(value: &str) -> Self { match value { "v1" => Self::V1, @@ -191,16 +227,18 @@ impl DeploymentInfo { } } -#[derive(Debug, serde::Serialize, serde::Deserialize, Clone)] +#[derive(Debug, serde::Serialize, serde::Deserialize, Clone, PartialEq)] #[serde(rename_all = "camelCase")] pub enum AlertType { Threshold, + Anomaly, } impl Display for AlertType { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { match self { AlertType::Threshold => write!(f, "threshold"), + AlertType::Anomaly => write!(f, "anomaly"), } } } @@ -461,7 +499,11 @@ impl AlertState { return Err(AlertError::InvalidStateChange(msg)); } else { // update state on disk and in memory - ALERTS + let guard = ALERTS.read().await; + let alerts = guard.as_ref().ok_or_else(|| { + AlertError::CustomError("Alert manager not initialized".into()) + })?; + alerts .update_state(alert_id, new_state, Some("".into())) .await?; } @@ -470,7 +512,11 @@ impl AlertState { // from here, the user can only go to Resolved if new_state == AlertState::Resolved { // update state on disk and in memory - ALERTS + let guard = ALERTS.read().await; + let alerts = guard.as_ref().ok_or_else(|| { + AlertError::CustomError("Alert manager not initialized".into()) + })?; + alerts .update_state(alert_id, new_state, Some("".into())) .await?; } else { @@ -501,10 +547,10 @@ pub enum Severity { impl Display for Severity { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { match self { - Severity::Critical => write!(f, "Critical (P0)"), - Severity::High => write!(f, "High (P1)"), - Severity::Medium => write!(f, "Medium (P2)"), - Severity::Low => write!(f, "Low (P3)"), + Severity::Critical => write!(f, "Critical"), + Severity::High => write!(f, "High"), + Severity::Medium => write!(f, "Medium"), + Severity::Low => write!(f, "Low"), } } } @@ -547,7 +593,7 @@ impl AlertRequest { } let datasets = resolve_stream_names(&self.query)?; let config = AlertConfig { - version: AlertVerison::from(CURRENT_ALERTS_VERSION), + version: AlertVersion::from(CURRENT_ALERTS_VERSION), id: Ulid::new(), severity: self.severity, title: self.title, @@ -568,7 +614,7 @@ impl AlertRequest { #[derive(Debug, serde::Serialize, serde::Deserialize, Clone)] #[serde(rename_all = "camelCase")] pub struct AlertConfig { - pub version: AlertVerison, + pub version: AlertVersion, #[serde(default)] pub id: Ulid, pub severity: Severity, @@ -588,7 +634,7 @@ pub struct AlertConfig { impl AlertConfig { /// Migration function to convert v1 alerts to v2 structure - async fn migrate_from_v1( + pub async fn migrate_from_v1( alert_json: &JsonValue, store: &dyn crate::storage::ObjectStorage, ) -> Result { @@ -604,7 +650,7 @@ impl AlertConfig { // Create the migrated v2 alert let migrated_alert = AlertConfig { - version: AlertVerison::V2, + version: AlertVersion::V2, id: basic_fields.id, severity: basic_fields.severity, title: basic_fields.title, @@ -1067,6 +1113,14 @@ impl AlertConfig { /// Validations pub async fn validate(&self, session_key: SessionKey) -> Result<(), AlertError> { + // validate alert type + // Anomaly is only allowed in Prism + if self.alert_type.eq(&AlertType::Anomaly) && PARSEABLE.options.mode != Mode::Prism { + return Err(AlertError::CustomError( + "Anomaly alert is only allowed on Prism mode".into(), + )); + } + // validate evalType let eval_frequency = match &self.eval_config { EvalConfig::RollingWindow(rolling_window) => { @@ -1114,50 +1168,12 @@ impl AlertConfig { user_auth_for_query(&session_key, &self.query).await?; // validate that the alert query is valid and can be evaluated - if !Self::is_query_aggregate(&self.query).await? { + if !is_query_aggregate(&self.query).await? { return Err(AlertError::InvalidAlertQuery); } Ok(()) } - /// Check if a query is an aggregate query that returns a single value without executing it - async fn is_query_aggregate(query: &str) -> Result { - let session_state = QUERY_SESSION.state(); - - // Parse the query into a logical plan - let logical_plan = session_state - .create_logical_plan(query) - .await - .map_err(|err| AlertError::CustomError(format!("Failed to parse query: {err}")))?; - - // Check if the plan structure indicates an aggregate query - Ok(Self::is_logical_plan_aggregate(&logical_plan)) - } - - /// Analyze a logical plan to determine if it represents an aggregate query - fn is_logical_plan_aggregate(plan: &LogicalPlan) -> bool { - match plan { - // Direct aggregate: SELECT COUNT(*), AVG(col), etc. - LogicalPlan::Aggregate(_) => true, - - // Projection over aggregate: SELECT COUNT(*) as total, SELECT AVG(col) as average - LogicalPlan::Projection(Projection { input, expr, .. }) => { - // Check if input contains an aggregate and we have exactly one expression - let is_aggregate_input = Self::is_logical_plan_aggregate(input); - let single_expr = expr.len() == 1; - is_aggregate_input && single_expr - } - - // Recursively check wrapped plans (Filter, Limit, Sort, etc.) - _ => { - // Use inputs() method to get all input plans - plan.inputs() - .iter() - .any(|input| Self::is_logical_plan_aggregate(input)) - } - } - } - pub fn get_eval_frequency(&self) -> u64 { match &self.eval_config { EvalConfig::RollingWindow(rolling_window) => rolling_window.eval_frequency, @@ -1264,6 +1280,44 @@ impl AlertConfig { } } +/// Check if a query is an aggregate query that returns a single value without executing it +pub async fn is_query_aggregate(query: &str) -> Result { + let session_state = QUERY_SESSION.state(); + + // Parse the query into a logical plan + let logical_plan = session_state + .create_logical_plan(query) + .await + .map_err(|err| AlertError::CustomError(format!("Failed to parse query: {err}")))?; + + // Check if the plan structure indicates an aggregate query + Ok(is_logical_plan_aggregate(&logical_plan)) +} + +/// Analyze a logical plan to determine if it represents an aggregate query +pub fn is_logical_plan_aggregate(plan: &LogicalPlan) -> bool { + match plan { + // Direct aggregate: SELECT COUNT(*), AVG(col), etc. + LogicalPlan::Aggregate(_) => true, + + // Projection over aggregate: SELECT COUNT(*) as total, SELECT AVG(col) as average + LogicalPlan::Projection(Projection { input, expr, .. }) => { + // Check if input contains an aggregate and we have exactly one expression + let is_aggregate_input = is_logical_plan_aggregate(input); + let single_expr = expr.len() == 1; + is_aggregate_input && single_expr + } + + // Recursively check wrapped plans (Filter, Limit, Sort, etc.) + _ => { + // Use inputs() method to get all input plans + plan.inputs() + .iter() + .any(|input| is_logical_plan_aggregate(input)) + } + } +} + #[derive(Debug, thiserror::Error)] pub enum AlertError { #[error("Storage Error: {0}")] @@ -1302,6 +1356,8 @@ pub enum AlertError { InvalidAlertQuery, #[error("Invalid query parameter")] InvalidQueryParameter, + #[error("{0}")] + ArrowError(#[from] ArrowError), } impl actix_web::ResponseError for AlertError { @@ -1325,6 +1381,7 @@ impl actix_web::ResponseError for AlertError { Self::ParserError(_) => StatusCode::BAD_REQUEST, Self::InvalidAlertQuery => StatusCode::BAD_REQUEST, Self::InvalidQueryParameter => StatusCode::BAD_REQUEST, + Self::ArrowError(_) => StatusCode::INTERNAL_SERVER_ERROR, } } @@ -1335,9 +1392,10 @@ impl actix_web::ResponseError for AlertError { } } -impl Alerts { +#[async_trait] +impl AlertManagerTrait for Alerts { /// Loads alerts from disk, blocks - pub async fn load(&self) -> anyhow::Result<()> { + async fn load(&self) -> anyhow::Result<()> { let mut map = self.alerts.write().await; let store = PARSEABLE.storage.get_object_store(); @@ -1398,21 +1456,24 @@ impl Alerts { } }; + let alert: Box = match &alert.alert_type { + AlertType::Threshold => { + Box::new(ThresholdAlert::from(alert)) as Box + } + AlertType::Anomaly => { + return Err(anyhow::Error::msg( + "Get Parseable Enterprise for Anomaly alerts", + )); + } + }; + // Create alert task - match self - .sender - .send(AlertTask::Create(Box::new(alert.clone()))) - .await - { + match self.sender.send(AlertTask::Create(alert.clone_box())).await { Ok(_) => {} Err(e) => { warn!("Failed to create alert task: {e}\nRetrying..."); // Retry sending the task - match self - .sender - .send(AlertTask::Create(Box::new(alert.clone()))) - .await - { + match self.sender.send(AlertTask::Create(alert.clone_box())).await { Ok(_) => {} Err(e) => { error!("Failed to create alert task: {e}"); @@ -1422,14 +1483,14 @@ impl Alerts { } }; - map.insert(alert.id, alert); + map.insert(*alert.get_id(), alert); } Ok(()) } /// Returns a list of alerts that the user has access to (based on query auth) - pub async fn list_alerts_for_user( + async fn list_alerts_for_user( &self, session: SessionKey, tags: Vec, @@ -1437,8 +1498,11 @@ impl Alerts { let mut alerts: Vec = Vec::new(); for (_, alert) in self.alerts.read().await.iter() { // filter based on whether the user can execute this query or not - if user_auth_for_query(&session, &alert.query).await.is_ok() { - alerts.push(alert.to_owned()); + if user_auth_for_query(&session, alert.get_query()) + .await + .is_ok() + { + alerts.push(alert.to_alert_config()); } } if tags.is_empty() { @@ -1457,10 +1521,10 @@ impl Alerts { } /// Returns a sigle alert that the user has access to (based on query auth) - pub async fn get_alert_by_id(&self, id: Ulid) -> Result { + async fn get_alert_by_id(&self, id: Ulid) -> Result { let read_access = self.alerts.read().await; if let Some(alert) = read_access.get(&id) { - Ok(alert.clone()) + Ok(alert.to_alert_config()) } else { Err(AlertError::CustomError(format!( "No alert found for the given ID- {id}" @@ -1469,12 +1533,15 @@ impl Alerts { } /// Update the in-mem vector of alerts - pub async fn update(&self, alert: &AlertConfig) { - self.alerts.write().await.insert(alert.id, alert.clone()); + async fn update(&self, alert: &dyn AlertTrait) { + self.alerts + .write() + .await + .insert(*alert.get_id(), alert.clone_box()); } /// Update the state of alert - pub async fn update_state( + async fn update_state( &self, alert_id: Ulid, new_state: AlertState, @@ -1496,9 +1563,9 @@ impl Alerts { // modify in memory let mut writer = self.alerts.write().await; if let Some(alert) = writer.get_mut(&alert_id) { - trace!("in memory alert-\n{}", alert.state); - alert.state = new_state; - trace!("in memory updated alert-\n{}", alert.state); + trace!("in memory alert-\n{}", alert.get_state()); + alert.set_state(new_state); + trace!("in memory updated alert-\n{}", alert.get_state()); }; drop(writer); @@ -1511,7 +1578,7 @@ impl Alerts { } /// Remove alert and scheduled task from disk and memory - pub async fn delete(&self, alert_id: Ulid) -> Result<(), AlertError> { + async fn delete(&self, alert_id: Ulid) -> Result<(), AlertError> { if self.alerts.write().await.remove(&alert_id).is_some() { trace!("removed alert from memory"); } else { @@ -1521,11 +1588,11 @@ impl Alerts { } /// Get state of alert using alert_id - pub async fn get_state(&self, alert_id: Ulid) -> Result { + async fn get_state(&self, alert_id: Ulid) -> Result { let read_access = self.alerts.read().await; if let Some(alert) = read_access.get(&alert_id) { - Ok(alert.state) + Ok(*alert.get_state()) } else { let msg = format!("No alert present for ID- {alert_id}"); Err(AlertError::CustomError(msg)) @@ -1533,16 +1600,16 @@ impl Alerts { } /// Start a scheduled alert task - pub async fn start_task(&self, alert: AlertConfig) -> Result<(), AlertError> { + async fn start_task(&self, alert: Box) -> Result<(), AlertError> { self.sender - .send(AlertTask::Create(Box::new(alert))) + .send(AlertTask::Create(alert)) .await .map_err(|e| AlertError::CustomError(e.to_string()))?; Ok(()) } /// Remove a scheduled alert task - pub async fn delete_task(&self, alert_id: Ulid) -> Result<(), AlertError> { + async fn delete_task(&self, alert_id: Ulid) -> Result<(), AlertError> { self.sender .send(AlertTask::Delete(alert_id)) .await @@ -1553,17 +1620,22 @@ impl Alerts { /// List tags from all alerts /// This function returns a list of unique tags from all alerts - pub async fn list_tags(&self) -> Vec { + async fn list_tags(&self) -> Vec { let alerts = self.alerts.read().await; let mut tags = alerts .iter() - .filter_map(|(_, alert)| alert.tags.as_ref()) + .filter_map(|(_, alert)| alert.get_tags().as_ref()) .flat_map(|t| t.iter().cloned()) .collect::>(); tags.sort(); tags.dedup(); tags } + + async fn get_all_alerts(&self) -> HashMap> { + let alerts = self.alerts.read().await; + alerts.iter().map(|(k, v)| (*k, v.clone_box())).collect() + } } #[derive(Debug, Serialize)] @@ -1589,8 +1661,15 @@ pub struct AlertsInfo { // TODO: add RBAC pub async fn get_alerts_summary() -> Result { - let alerts = ALERTS.alerts.read().await; + let guard = ALERTS.read().await; + let alerts = if let Some(alerts) = guard.as_ref() { + alerts.get_all_alerts().await + } else { + return Err(AlertError::CustomError("No AlertManager registered".into())); + }; + let total = alerts.len() as u64; + let mut triggered = 0; let mut resolved = 0; let mut silenced = 0; @@ -1601,29 +1680,29 @@ pub async fn get_alerts_summary() -> Result { // find total alerts for each state // get title, id and state of each alert for that state for (_, alert) in alerts.iter() { - match alert.state { + match alert.get_state() { AlertState::Triggered => { triggered += 1; triggered_alerts.push(AlertsInfo { - title: alert.title.clone(), - id: alert.id, - severity: alert.severity.clone(), + title: alert.get_title().to_string(), + id: *alert.get_id(), + severity: alert.get_severity().clone(), }); } AlertState::Silenced => { silenced += 1; silenced_alerts.push(AlertsInfo { - title: alert.title.clone(), - id: alert.id, - severity: alert.severity.clone(), + title: alert.get_title().to_string(), + id: *alert.get_id(), + severity: alert.get_severity().clone(), }); } AlertState::Resolved => { resolved += 1; resolved_alerts.push(AlertsInfo { - title: alert.title.clone(), - id: alert.id, - severity: alert.severity.clone(), + title: alert.get_title().to_string(), + id: *alert.get_id(), + severity: alert.get_severity().clone(), }); } } diff --git a/src/alerts/target.rs b/src/alerts/target.rs index d13151ca0..12b7ded54 100644 --- a/src/alerts/target.rs +++ b/src/alerts/target.rs @@ -101,8 +101,15 @@ impl TargetConfigs { pub async fn delete(&self, target_id: &Ulid) -> Result { // ensure that the target is not being used by any alert - for (_, alert) in ALERTS.alerts.read().await.iter() { - if alert.targets.contains(target_id) { + let guard = ALERTS.read().await; + let alerts = if let Some(alerts) = guard.as_ref() { + alerts + } else { + return Err(AlertError::CustomError("No AlertManager set".into())); + }; + + for (_, alert) in alerts.get_all_alerts().await.iter() { + if alert.get_targets().contains(target_id) { return Err(AlertError::TargetInUse); } } @@ -282,9 +289,17 @@ impl Target { trace!("Spawning retry task"); tokio::spawn(async move { + let guard = ALERTS.read().await; + let alerts = if let Some(alerts) = guard.as_ref() { + alerts + } else { + error!("No AlertManager set for alert_id: {alert_id}, stopping timeout task"); + *state.lock().unwrap() = TimeoutState::default(); + return; + }; match retry { Retry::Infinite => loop { - let current_state = if let Ok(state) = ALERTS.get_state(alert_id).await { + let current_state = if let Ok(state) = alerts.get_state(alert_id).await { state } else { *state.lock().unwrap() = TimeoutState::default(); @@ -302,7 +317,7 @@ impl Target { }, Retry::Finite(times) => { for _ in 0..(times - 1) { - let current_state = if let Ok(state) = ALERTS.get_state(alert_id).await { + let current_state = if let Ok(state) = alerts.get_state(alert_id).await { state } else { *state.lock().unwrap() = TimeoutState::default(); diff --git a/src/alerts/traits.rs b/src/alerts/traits.rs new file mode 100644 index 000000000..81dcd2555 --- /dev/null +++ b/src/alerts/traits.rs @@ -0,0 +1,74 @@ +/* + * Parseable Server (C) 2022 - 2024 Parseable, Inc. + * + * This program is free software: you can redistribute it and/or modify + * it under the terms of the GNU Affero General Public License as + * published by the Free Software Foundation, either version 3 of the + * License, or (at your option) any later version. + * + * This program is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * GNU Affero General Public License for more details. + * + * You should have received a copy of the GNU Affero General Public License + * along with this program. If not, see . + * + */ + +use crate::{ + alerts::{ + AlertConfig, AlertError, AlertState, AlertType, EvalConfig, Severity, ThresholdConfig, + }, + rbac::map::SessionKey, +}; +use std::{collections::HashMap, fmt::Debug}; +use tonic::async_trait; +use ulid::Ulid; + +#[async_trait] +pub trait AlertTrait: Debug + Send + Sync { + async fn eval_alert(&self) -> Result<(bool, f64), AlertError>; + async fn validate(&self, session_key: &SessionKey) -> Result<(), AlertError>; + fn get_id(&self) -> &Ulid; + fn get_severity(&self) -> &Severity; + fn get_title(&self) -> &str; + fn get_query(&self) -> &str; + fn get_alert_type(&self) -> &AlertType; + fn get_threshold_config(&self) -> &ThresholdConfig; + fn get_eval_config(&self) -> &EvalConfig; + fn get_targets(&self) -> &Vec; + fn get_state(&self) -> &AlertState; + fn get_eval_window(&self) -> String; + fn get_eval_frequency(&self) -> u64; + fn get_created(&self) -> String; + fn get_tags(&self) -> &Option>; + fn get_datasets(&self) -> &Vec; + fn to_alert_config(&self) -> AlertConfig; + fn clone_box(&self) -> Box; + fn set_state(&mut self, new_state: AlertState); +} + +#[async_trait] +pub trait AlertManagerTrait: Send + Sync { + async fn load(&self) -> anyhow::Result<()>; + async fn list_alerts_for_user( + &self, + session: SessionKey, + tags: Vec, + ) -> Result, AlertError>; + async fn get_alert_by_id(&self, id: Ulid) -> Result; + async fn update(&self, alert: &dyn AlertTrait); + async fn update_state( + &self, + alert_id: Ulid, + new_state: AlertState, + trigger_notif: Option, + ) -> Result<(), AlertError>; + async fn delete(&self, alert_id: Ulid) -> Result<(), AlertError>; + async fn get_state(&self, alert_id: Ulid) -> Result; + async fn start_task(&self, alert: Box) -> Result<(), AlertError>; + async fn delete_task(&self, alert_id: Ulid) -> Result<(), AlertError>; + async fn list_tags(&self) -> Vec; + async fn get_all_alerts(&self) -> HashMap>; +} diff --git a/src/handlers/http/alerts.rs b/src/handlers/http/alerts.rs index 98ee18610..dcdb46ac0 100644 --- a/src/handlers/http/alerts.rs +++ b/src/handlers/http/alerts.rs @@ -19,9 +19,9 @@ use std::{collections::HashMap, str::FromStr}; use crate::{ + alerts::{AlertType, alert_types::ThresholdAlert, traits::AlertTrait}, parseable::PARSEABLE, storage::object_storage::alert_json_path, - // sync::schedule_alert_task, utils::{actix::extract_session_key_from_req, user_auth_for_query}, }; use actix_web::{ @@ -53,8 +53,14 @@ pub async fn list(req: HttpRequest) -> Result { } } } - - let alerts = ALERTS.list_alerts_for_user(session_key, tags_list).await?; + let guard = ALERTS.read().await; + let alerts = if let Some(alerts) = guard.as_ref() { + alerts + } else { + return Err(AlertError::CustomError("No AlertManager set".into())); + }; + + let alerts = alerts.list_alerts_for_user(session_key, tags_list).await?; let alerts_summary = alerts .iter() .map(|alert| alert.to_summary()) @@ -69,26 +75,46 @@ pub async fn post( ) -> Result { let alert: AlertConfig = alert.into().await?; + let threshold_alert; + let alert: &dyn AlertTrait = match &alert.alert_type { + AlertType::Threshold => { + threshold_alert = ThresholdAlert::from(alert); + &threshold_alert + } + AlertType::Anomaly => { + return Err(AlertError::CustomError( + "Get Parseable Enterprise for Anomaly alerts".into(), + )); + } + }; + + let guard = ALERTS.write().await; + let alerts = if let Some(alerts) = guard.as_ref() { + alerts + } else { + return Err(AlertError::CustomError("No AlertManager set".into())); + }; + // validate the incoming alert query // does the user have access to these tables or not? let session_key = extract_session_key_from_req(&req)?; - alert.validate(session_key).await?; + alert.validate(&session_key).await?; // now that we've validated that the user can run this query // move on to saving the alert in ObjectStore - ALERTS.update(&alert).await; + alerts.update(alert).await; - let path = alert_json_path(alert.id); + let path = alert_json_path(*alert.get_id()); let store = PARSEABLE.storage.get_object_store(); - let alert_bytes = serde_json::to_vec(&alert)?; + let alert_bytes = serde_json::to_vec(&alert.to_alert_config())?; store.put_object(&path, Bytes::from(alert_bytes)).await?; // start the task - ALERTS.start_task(alert.clone()).await?; + alerts.start_task(alert.clone_box()).await?; - Ok(web::Json(alert)) + Ok(web::Json(alert.to_alert_config())) } // GET /alerts/{alert_id} @@ -96,7 +122,14 @@ pub async fn get(req: HttpRequest, alert_id: Path) -> Result) -> Result) -> Result Result { - let tags = ALERTS.list_tags().await; + let guard = ALERTS.read().await; + let alerts = if let Some(alerts) = guard.as_ref() { + alerts + } else { + return Err(AlertError::CustomError("No AlertManager set".into())); + }; + let tags = alerts.list_tags().await; Ok(web::Json(tags)) } diff --git a/src/handlers/http/modal/mod.rs b/src/handlers/http/modal/mod.rs index b925c4ac0..0b57f8602 100644 --- a/src/handlers/http/modal/mod.rs +++ b/src/handlers/http/modal/mod.rs @@ -34,7 +34,7 @@ use tokio::sync::oneshot; use tracing::{error, info, warn}; use crate::{ - alerts::{ALERTS, target::TARGETS}, + alerts::{ALERTS, get_alert_manager, target::TARGETS}, cli::Options, correlation::CORRELATIONS, oidc::Claims, @@ -183,7 +183,16 @@ pub async fn load_on_init() -> anyhow::Result<()> { }, async { FILTERS.load().await.context("Failed to load filters") }, async { DASHBOARDS.load().await.context("Failed to load dashboards") }, - async { ALERTS.load().await.context("Failed to load alerts") }, + async { + get_alert_manager().await; + let guard = ALERTS.write().await; + let alerts = if let Some(alerts) = guard.as_ref() { + alerts + } else { + return Err(anyhow::Error::msg("No AlertManager set")); + }; + alerts.load().await + }, async { TARGETS.load().await.context("Failed to load targets") }, ) .await; diff --git a/src/handlers/http/modal/server.rs b/src/handlers/http/modal/server.rs index 88eb51a11..3e8170e41 100644 --- a/src/handlers/http/modal/server.rs +++ b/src/handlers/http/modal/server.rs @@ -255,6 +255,13 @@ impl Server { .route(web::get().to(alerts::list).authorize(Action::GetAlert)) .route(web::post().to(alerts::post).authorize(Action::PutAlert)), ) + .service( + web::resource("/list_tags").route( + web::get() + .to(alerts::list_tags) + .authorize(Action::ListDashboard), + ), + ) .service( web::resource("/{alert_id}") .route(web::get().to(alerts::get).authorize(Action::GetAlert)) @@ -269,13 +276,6 @@ impl Server { .authorize(Action::DeleteAlert), ), ) - .service( - web::resource("/list_tags").route( - web::get() - .to(alerts::list_tags) - .authorize(Action::ListDashboard), - ), - ) } pub fn get_targets_webscope() -> Scope { diff --git a/src/lib.rs b/src/lib.rs index ad82ae4e2..f34a62448 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -52,6 +52,7 @@ mod validator; use std::time::Duration; +pub use datafusion; pub use handlers::http::modal::{ ParseableServer, ingest_server::IngestServer, query_server::QueryServer, server::Server, }; diff --git a/src/prism/home/mod.rs b/src/prism/home/mod.rs index 81a2342f2..2d9801d56 100644 --- a/src/prism/home/mod.rs +++ b/src/prism/home/mod.rs @@ -340,7 +340,15 @@ async fn get_alert_titles( key: &SessionKey, query_value: &str, ) -> Result, PrismHomeError> { - let alerts = ALERTS + let guard = ALERTS.read().await; + let alerts = if let Some(alerts) = guard.as_ref() { + alerts + } else { + return Err(PrismHomeError::AlertError(AlertError::CustomError( + "No AlertManager set".into(), + ))); + }; + let alerts = alerts .list_alerts_for_user(key.clone(), vec![]) .await? .iter() diff --git a/src/sync.rs b/src/sync.rs index 93d643c4f..bfc6ba88d 100644 --- a/src/sync.rs +++ b/src/sync.rs @@ -295,18 +295,18 @@ pub async fn alert_runtime(mut rx: mpsc::Receiver) -> Result<(), anyh match task { AlertTask::Create(alert) => { // check if the alert already exists - if alert_tasks.contains_key(&alert.id) { - error!("Alert with id {} already exists", alert.id); + if alert_tasks.contains_key(alert.get_id()) { + error!("Alert with id {} already exists", alert.get_id()); continue; } - let alert = alert.clone(); - let id = alert.id; + let alert = alert.clone_box(); + let id = *alert.get_id(); let handle = tokio::spawn(async move { let mut retry_counter = 0; let mut sleep_duration = alert.get_eval_frequency(); loop { - match alerts_utils::evaluate_alert(&alert).await { + match alerts_utils::evaluate_alert(&*alert).await { Ok(_) => { retry_counter = 0; }