diff --git a/src/core/player.rs b/src/core/player.rs index 3457a88..0a2a8b7 100644 --- a/src/core/player.rs +++ b/src/core/player.rs @@ -227,8 +227,7 @@ pub enum Updates { } /// Handle to a player registry — a shared data structure containing information about players. -#[derive(Clone)] -pub struct PlayerRegistry(Arc>); +pub struct PlayerRegistry(RwLock); impl PlayerRegistry { pub fn empty( room_registry: RoomRegistry, @@ -242,10 +241,10 @@ impl PlayerRegistry { players: HashMap::new(), metric_active_players, }; - Ok(PlayerRegistry(Arc::new(RwLock::new(inner)))) + Ok(PlayerRegistry(RwLock::new(inner))) } - pub async fn get_or_create_player(&mut self, id: PlayerId) -> PlayerHandle { + pub async fn get_or_create_player(&self, id: PlayerId) -> PlayerHandle { let mut inner = self.0.write().unwrap(); if let Some((handle, _)) = inner.players.get(&id) { handle.clone() @@ -257,7 +256,7 @@ impl PlayerRegistry { } } - pub async fn connect_to_player(&mut self, id: PlayerId) -> PlayerConnection { + pub async fn connect_to_player(&self, id: PlayerId) -> PlayerConnection { let player_handle = self.get_or_create_player(id).await; player_handle.subscribe().await } diff --git a/src/core/repo/mod.rs b/src/core/repo/mod.rs index 10b601a..261c164 100644 --- a/src/core/repo/mod.rs +++ b/src/core/repo/mod.rs @@ -33,7 +33,7 @@ impl Storage { Ok(Storage { conn }) } - pub async fn retrieve_user_by_name(&mut self, name: &str) -> Result> { + pub async fn retrieve_user_by_name(&self, name: &str) -> Result> { let mut executor = self.conn.lock().await; let res = sqlx::query_as( "select u.id, u.name, c.password diff --git a/src/main.rs b/src/main.rs index dadb8ba..6b081ad 100644 --- a/src/main.rs +++ b/src/main.rs @@ -55,11 +55,10 @@ async fn main() -> Result<()> { let storage = Storage::open(storage_config).await?; let rooms = RoomRegistry::new(&mut metrics, storage.clone())?; let mut players = PlayerRegistry::empty(rooms.clone(), &mut metrics)?; - let telemetry_terminator = - util::telemetry::launch(telemetry_config, metrics.clone(), rooms.clone()).await?; // unsafe: outer future is never dropped, scope is joined on `scope.collect` let mut scope = unsafe { Scope::create() }; + let telemetry_terminator = util::telemetry::launch(telemetry_config, &metrics, &rooms, &mut scope).await?; let irc = projections::irc::launch(&irc_config, &players, &rooms, &metrics, &storage, &mut scope).await?; let xmpp = projections::xmpp::launch(xmpp_config, &players, &rooms, &metrics, &mut scope).await?; tracing::info!("Started"); @@ -69,10 +68,10 @@ async fn main() -> Result<()> { tracing::info!("Begin shutdown"); let _ = xmpp.send(()); let _ = irc.send(()); + let _ = telemetry_terminator.send(()); let _ = scope.collect().await; drop(scope); - telemetry_terminator.terminate().await?; players.shutdown_all().await?; drop(players); drop(rooms); diff --git a/src/projections/irc/mod.rs b/src/projections/irc/mod.rs index 2f74919..13477f9 100644 --- a/src/projections/irc/mod.rs +++ b/src/projections/irc/mod.rs @@ -2,7 +2,6 @@ use std::collections::HashMap; use std::net::SocketAddr; use futures_util::FutureExt; -use futures_util::future::join_all; use prometheus::{IntCounter, IntGauge, Registry as MetricsRegistry}; use serde::Deserialize; use tokio::io::{AsyncBufReadExt, AsyncWrite, AsyncWriteExt, BufReader, BufWriter}; @@ -17,7 +16,6 @@ use crate::prelude::*; use crate::protos::irc::client::{client_message, ClientMessage}; use crate::protos::irc::server::{AwayStatus, ServerMessage, ServerMessageBody}; use crate::protos::irc::{Chan, Recipient}; -use crate::util::Terminator; #[cfg(test)] mod test; @@ -44,10 +42,10 @@ async fn handle_socket( config: ServerConfig, mut stream: TcpStream, socket_addr: &SocketAddr, - players: PlayerRegistry, - rooms: RoomRegistry, + players: &PlayerRegistry, + rooms: &RoomRegistry, termination: Deferred<()>, // TODO use it to stop the connection gracefully - mut storage: Storage, + storage: &Storage, ) -> Result<()> { let (reader, writer) = stream.split(); let mut reader: BufReader = BufReader::new(reader); @@ -67,7 +65,7 @@ async fn handle_socket( writer.flush().await?; let registered_user: Result = - handle_registration(&mut reader, &mut writer, &mut storage).await; + handle_registration(&mut reader, &mut writer, &storage).await; match registered_user { Ok(user) => { @@ -84,7 +82,7 @@ async fn handle_socket( async fn handle_registration<'a>( reader: &mut BufReader>, writer: &mut BufWriter>, - storage: &mut Storage, + storage: &Storage, ) -> Result { let mut buffer = vec![]; @@ -177,8 +175,8 @@ async fn handle_registration<'a>( async fn handle_registered_socket<'a>( config: ServerConfig, - mut players: PlayerRegistry, - rooms: RoomRegistry, + players: &PlayerRegistry, + rooms: &RoomRegistry, reader: &mut BufReader>, writer: &mut BufWriter>, user: RegisteredUser, @@ -716,6 +714,7 @@ pub async fn launch<'a>( // TODO probably should separate logic for accepting new connection and storing them // into two tasks so that they don't block each other let mut actors = HashMap::new(); + let mut scope = unsafe { Scope::create() }; loop { select! { biased; @@ -738,23 +737,22 @@ pub async fn launch<'a>( continue; } - let terminator = Terminator::spawn(|termination| { - let players = players.clone(); - let rooms = rooms.clone(); - let current_connections_clone = current_connections.clone(); - let stopped_tx = stopped_tx.clone(); - let storage = storage.clone(); - async move { - match handle_socket(config, stream, &socket_addr, players, rooms, termination, storage).await { - Ok(_) => log::info!("Connection terminated"), - Err(err) => log::warn!("Connection failed: {err}"), - } - current_connections_clone.dec(); - stopped_tx.send(socket_addr).await?; - Ok(()) + + let current_connections_clone = current_connections.clone(); + let stopped_tx = stopped_tx.clone(); + + let (a, rx) = oneshot(); + let future = async move { + match handle_socket(config, stream, &socket_addr, players, rooms, rx, storage).await { + Ok(_) => log::info!("Connection terminated"), + Err(err) => log::warn!("Connection failed: {err}"), } - }); - actors.insert(socket_addr, terminator); + current_connections_clone.dec(); + stopped_tx.send(socket_addr).await?; + Ok(()) + }; + scope.spawn(future.map(|_: Result<()>| ())); + actors.insert(socket_addr, a); }, Err(err) => log::warn!("Failed to accept new connection: {err}"), } @@ -763,15 +761,14 @@ pub async fn launch<'a>( } log::info!("Stopping IRC projection"); - join_all(actors.into_iter().map(|(socket_addr, terminator)| async move { - log::debug!("Stopping IRC connection at {socket_addr}"); - match terminator.terminate().await { - Ok(_) => log::debug!("Stopped IRC connection at {socket_addr}"), - Err(err) => { - log::warn!("IRC connection to {socket_addr} finished with error: {err}") - } + for (socket_addr, terminator) in actors { + match terminator.send(()) { + Ok(_) => log::debug!("Stopping IRC connection at {socket_addr}"), + Err(_) => log::debug!("IRC connection at {socket_addr} already stopped") } - })).await; + } + let _ = scope.collect().await; + drop(scope); log::info!("Stopped IRC projection"); Ok(()) }; diff --git a/src/projections/xmpp/mod.rs b/src/projections/xmpp/mod.rs index 9260c04..5382a88 100644 --- a/src/projections/xmpp/mod.rs +++ b/src/projections/xmpp/mod.rs @@ -8,7 +8,6 @@ use std::path::PathBuf; use std::sync::Arc; use futures_util::FutureExt; -use futures_util::future::join_all; use prometheus::Registry as MetricsRegistry; use quick_xml::events::{BytesDecl, Event}; use quick_xml::{NsReader, Writer}; @@ -31,7 +30,6 @@ use crate::protos::xmpp::roster::RosterQuery; use crate::protos::xmpp::session::Session; use crate::protos::xmpp::stream::*; use crate::util::xml::{Continuation, FromXml, Parser, ToXml}; -use crate::util::Terminator; use self::proto::{ClientPacket, IqClientBody}; @@ -82,6 +80,7 @@ pub async fn launch<'a>( let future = async move { let (stopped_tx, mut stopped_rx) = channel(32); let mut actors = HashMap::new(); + let mut scope = unsafe { Scope::create() }; loop { select! { biased; @@ -99,22 +98,20 @@ pub async fn launch<'a>( // TODO kill the older connection and restart it continue; } - let players = players.clone(); - let rooms = rooms.clone(); - let terminator = Terminator::spawn(|termination| { - let stopped_tx = stopped_tx.clone(); - let loaded_config = loaded_config.clone(); - async move { - match handle_socket(loaded_config, stream, &socket_addr, players, rooms, termination).await { + let stopped_tx = stopped_tx.clone(); + let loaded_config = loaded_config.clone(); + let (a, rx) = oneshot(); + let future = async move { + match handle_socket(loaded_config, stream, &socket_addr, players, rooms, rx).await { Ok(_) => log::info!("Connection terminated"), Err(err) => log::warn!("Connection failed: {err}"), } stopped_tx.send(socket_addr).await?; Ok(()) - } - }); + }; + scope.spawn(future.map(|_: Result<()>| ())); - actors.insert(socket_addr, terminator); + actors.insert(socket_addr, a); }, Err(err) => log::warn!("Failed to accept new connection: {err}"), } @@ -122,22 +119,14 @@ pub async fn launch<'a>( } } log::info!("Stopping XMPP projection"); - join_all( - actors - .into_iter() - .map(|(socket_addr, terminator)| async move { - log::debug!("Stopping XMPP connection at {socket_addr}"); - match terminator.terminate().await { - Ok(_) => log::debug!("Stopped XMPP connection at {socket_addr}"), - Err(err) => { - log::warn!( - "XMPP connection to {socket_addr} finished with error: {err}" - ) - } - } - }), - ) - .await; + for (socket_addr, terminator) in actors { + match terminator.send(()) { + Ok(_) => log::debug!("Stopping XMPP connection at {socket_addr}"), + Err(_) => log::debug!("XMPP connection at {socket_addr} already stopped") + } + } + let _ = scope.collect().await; + drop(scope); log::info!("Stopped XMPP projection"); Ok(()) }; @@ -150,8 +139,8 @@ async fn handle_socket( config: Arc, mut stream: TcpStream, socket_addr: &SocketAddr, - mut players: PlayerRegistry, - rooms: RoomRegistry, + players: &PlayerRegistry, + rooms: &RoomRegistry, termination: Deferred<()>, // TODO use it to stop the connection gracefully ) -> Result<()> { log::debug!("Received an XMPP connection from {socket_addr}"); diff --git a/src/util/telemetry.rs b/src/util/telemetry.rs index 26387dd..1994112 100644 --- a/src/util/telemetry.rs +++ b/src/util/telemetry.rs @@ -15,7 +15,6 @@ use crate::core::room::RoomRegistry; use crate::prelude::*; use crate::util::http::*; -use crate::util::Terminator; type BoxBody = http_body_util::combinators::BoxBody; type HttpResult = std::result::Result; @@ -25,51 +24,48 @@ pub struct ServerConfig { pub listen_on: SocketAddr, } -pub async fn launch( +pub async fn launch<'a>( config: ServerConfig, - metrics: MetricsRegistry, - rooms: RoomRegistry, -) -> Result { + metrics: &'a MetricsRegistry, + rooms: &'a RoomRegistry, + scope: &mut Scope<'a>, +) -> Result> { log::info!("Starting the telemetry service"); let listener = TcpListener::bind(config.listen_on).await?; log::debug!("Listener started"); - let terminator = Terminator::spawn(|rx| main_loop(listener, metrics, rooms, rx.map(|_| ()))); - Ok(terminator) -} -async fn main_loop( - listener: TcpListener, - metrics: MetricsRegistry, - rooms: RoomRegistry, - termination: impl Future, -) -> Result<()> { - pin!(termination); - loop { - select! { - biased; - _ = &mut termination => break, - result = listener.accept() => { - let (stream, _) = result?; - let metrics = metrics.clone(); - let rooms = rooms.clone(); - tokio::task::spawn(async move { - let registry = metrics.clone(); - let rooms = rooms.clone(); - let server = http1::Builder::new().serve_connection(stream, service_fn(move |r| route(registry.clone(), rooms.clone(), r))); - if let Err(err) = server.await { - tracing::error!("Error serving connection: {:?}", err); - } - }); - }, + let (signal, mut rx) = oneshot(); + + let future = async move { + let mut scope = unsafe { Scope::create() }; + loop { + select! { + biased; + _ = &mut rx => break, + result = listener.accept() => { + let (stream, _) = result?; + scope.spawn(async move { + let server = http1::Builder::new().serve_connection(stream, service_fn(move |r| route(metrics, rooms, r))); + if let Err(err) = server.await { + tracing::error!("Error serving connection: {:?}", err); + } + }); + }, + } } - } - log::info!("Terminating the telemetry service"); - Ok(()) + let _ = scope.collect().await; + drop(scope); + log::info!("Terminating the telemetry service"); + Ok(()) + }; + scope.spawn(future.map(|_: Result<()>| ())); + + Ok(signal) } async fn route( - registry: MetricsRegistry, - rooms: RoomRegistry, + registry: &MetricsRegistry, + rooms: &RoomRegistry, request: Request, ) -> std::result::Result, Infallible> { match (request.method(), request.uri().path()) { @@ -79,7 +75,7 @@ async fn route( } } -fn endpoint_metrics(registry: MetricsRegistry) -> HttpResult>> { +fn endpoint_metrics(registry: &MetricsRegistry) -> HttpResult>> { let mf = registry.gather(); let mut buffer = vec![]; TextEncoder @@ -88,7 +84,7 @@ fn endpoint_metrics(registry: MetricsRegistry) -> HttpResult HttpResult>> { +async fn endpoint_rooms(rooms: &RoomRegistry) -> HttpResult>> { let room_list = rooms.get_all_rooms().await; let mut buffer = vec![]; serde_json::to_writer(&mut buffer, &room_list).expect("unexpected fail when writing to vec");