Skip to content

Commit 1bc5d11

Browse files
wojcik91Maciej Wójcik
and
Maciej Wójcik
authored
fix VPN client disconnect event (#1234)
* refactor how clients are disconnected * avoid marking inactive peers as connected --------- Co-authored-by: Maciej Wójcik <[email protected]>
1 parent 4a2975c commit 1bc5d11

File tree

2 files changed

+95
-51
lines changed

2 files changed

+95
-51
lines changed

crates/defguard_core/src/grpc/gateway/client_state.rs

Lines changed: 22 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -166,33 +166,30 @@ impl ClientMap {
166166
let mut disconnected_clients = Vec::new();
167167

168168
// get client state map for given location
169-
let location_map = match self.0.get_mut(&location_id) {
170-
Some(location_map) => location_map,
171-
None => {
172-
return Err(ClientMapError::LocationNotFound { location_id });
173-
}
169+
if let Some(location_map) = self.0.get_mut(&location_id) {
170+
let disconnect_threshold = TimeDelta::seconds(peer_disconnect_threshold_secs.into());
171+
172+
// remove clients which have been inactive longer than given location's `peer_disconnect_threshold`
173+
location_map.retain(|public_key, client_state| {
174+
let now = Utc::now().naive_utc();
175+
if (now - client_state.latest_handshake) > disconnect_threshold {
176+
debug!("VPN client's {public_key} ({}, ID {}) latest handshake ({}) was more than {peer_disconnect_threshold_secs} seconds ago. Marking VPN client as disconnected", client_state.device.name, client_state.device.id, client_state.latest_handshake);
177+
let disconnect_event_context = GrpcRequestContext::new(
178+
client_state.user_id,
179+
client_state.username.clone(),
180+
client_state.endpoint.ip(),
181+
client_state.device.id,
182+
client_state.device.name.clone(),
183+
);
184+
disconnected_clients
185+
.push((client_state.device.clone(), disconnect_event_context));
186+
187+
return false;
188+
};
189+
true
190+
});
174191
};
175192

176-
let disconnect_threshold = TimeDelta::seconds(peer_disconnect_threshold_secs.into());
177-
178-
// remove clients which have been inactive longer than given location's `peer_disconnect_threshold`
179-
location_map.retain(|_public_key, client_state| {
180-
let now = Utc::now().naive_utc();
181-
if (now - client_state.latest_update) > disconnect_threshold {
182-
let disconnect_event_context = GrpcRequestContext::new(
183-
client_state.user_id,
184-
client_state.username.clone(),
185-
client_state.endpoint.ip(),
186-
client_state.device.id,
187-
client_state.device.name.clone(),
188-
);
189-
disconnected_clients.push((client_state.device.clone(), disconnect_event_context));
190-
191-
return false;
192-
};
193-
true
194-
});
195-
196193
Ok(disconnected_clients)
197194
}
198195
}

crates/defguard_core/src/grpc/gateway/mod.rs

Lines changed: 73 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@ use std::{
66
task::{Context, Poll},
77
};
88

9-
use chrono::{DateTime, Utc};
9+
use chrono::{DateTime, TimeDelta, Utc};
1010
use client_state::ClientMap;
1111
use sqlx::{query, Error as SqlxError, PgExecutor, PgPool};
1212
use thiserror::Error;
@@ -16,6 +16,7 @@ use tokio::{
1616
mpsc::{self, error::SendError, Receiver, UnboundedSender},
1717
},
1818
task::JoinHandle,
19+
time::{interval, Duration},
1920
};
2021
use tokio_stream::Stream;
2122
use tonic::{metadata::MetadataMap, Code, Request, Response, Status};
@@ -34,6 +35,8 @@ use crate::{
3435
mail::Mail,
3536
};
3637

38+
const PEER_DISCONNECT_INTERVAL: u64 = 60;
39+
3740
/// Sends given `GatewayEvent` to be handled by gateway GRPC server
3841
///
3942
/// If you want to use it inside the API context, use [`crate::AppState::send_wireguard_event`] instead
@@ -192,6 +195,7 @@ impl GatewayServer {
192195
.client_state
193196
.lock()
194197
.map_err(|_| GatewayServerError::ClientStateMutexError)?;
198+
debug!("Current VPN client state map: {client_state:?}");
195199
Ok(client_state)
196200
}
197201

@@ -726,8 +730,46 @@ impl gateway_service_server::GatewayService for GatewayServer {
726730
let network_id = Self::get_network_id(request.metadata())?;
727731
let gateway_hostname = Self::get_gateway_hostname(request.metadata())?;
728732
let mut stream = request.into_inner();
733+
let mut disconnect_timer = interval(Duration::from_secs(PEER_DISCONNECT_INTERVAL));
734+
735+
loop {
736+
// wait for a message or update client map at least once a mninute if no messages are received
737+
let stats_update = tokio::select! {
738+
message = stream.message() => {
739+
match message? {
740+
Some(update) => update,
741+
None => break, // Stream ended
742+
}
743+
}
744+
_ = disconnect_timer.tick() => {
745+
debug!("No stats updates received in last {PEER_DISCONNECT_INTERVAL} seconds. Updating disconnected VPN clients");
746+
// fetch location to get current peer disconnect threshold
747+
let location = self.fetch_location_from_db(network_id).await?;
748+
749+
// perform client state operations in a dedicated block to drop mutex guard
750+
let disconnected_clients = {
751+
// acquire lock on client state map
752+
let mut client_map = self.get_client_state_guard()?;
753+
754+
// disconnect inactive clients
755+
client_map.disconnect_inactive_vpn_clients_for_location(
756+
network_id,
757+
location.peer_disconnect_threshold,
758+
)?
759+
};
760+
761+
// emit client disconnect events
762+
for (device, context) in disconnected_clients {
763+
self.emit_event(GrpcEvent::ClientDisconnected {
764+
context,
765+
location: location.clone(),
766+
device,
767+
})?;
768+
};
769+
continue;
770+
}
771+
};
729772

730-
while let Some(stats_update) = stream.message().await? {
731773
debug!("Received stats message: {stats_update:?}");
732774
let Some(stats_update::Payload::PeerStats(peer_stats)) = stats_update.payload else {
733775
debug!("Received stats message is empty, skipping.");
@@ -781,30 +823,35 @@ impl gateway_service_server::GatewayService for GatewayServer {
781823
);
782824
}
783825
None => {
784-
// mark new VPN client as connected
785-
client_map.connect_vpn_client(
786-
network_id,
787-
&gateway_hostname,
788-
&public_key,
789-
&device,
790-
&user,
791-
socket_addr,
792-
&stats,
793-
)?;
794-
795-
// emit connection event
796-
let context = GrpcRequestContext::new(
797-
user.id,
798-
user.username.clone(),
799-
socket_addr.ip(),
800-
device.id,
801-
device.name.clone(),
802-
);
803-
self.emit_event(GrpcEvent::ClientConnected {
804-
context,
805-
location: location.clone(),
806-
device: device.clone(),
807-
})?;
826+
// don't mark inactive peers as connected
827+
if (Utc::now().naive_utc() - stats.latest_handshake)
828+
< TimeDelta::seconds(location.peer_disconnect_threshold.into())
829+
{
830+
// mark new VPN client as connected
831+
client_map.connect_vpn_client(
832+
network_id,
833+
&gateway_hostname,
834+
&public_key,
835+
&device,
836+
&user,
837+
socket_addr,
838+
&stats,
839+
)?;
840+
841+
// emit connection event
842+
let context = GrpcRequestContext::new(
843+
user.id,
844+
user.username.clone(),
845+
socket_addr.ip(),
846+
device.id,
847+
device.name.clone(),
848+
);
849+
self.emit_event(GrpcEvent::ClientConnected {
850+
context,
851+
location: location.clone(),
852+
device: device.clone(),
853+
})?;
854+
}
808855
}
809856
};
810857

0 commit comments

Comments
 (0)