diff --git a/crates/lavina-core/src/player.rs b/crates/lavina-core/src/player.rs index 6dc65de..30635b0 100644 --- a/crates/lavina-core/src/player.rs +++ b/crates/lavina-core/src/player.rs @@ -55,7 +55,7 @@ pub struct ConnectionId(pub AnonKey); /// The connection is used to send commands to the player actor and to receive updates that might be sent to the client. pub struct PlayerConnection { pub connection_id: ConnectionId, - pub receiver: Receiver, + pub receiver: Receiver, player_handle: PlayerHandle, } impl PlayerConnection { @@ -160,7 +160,7 @@ impl PlayerHandle { enum ActorCommand { /// Establish a new connection. AddConnection { - sender: Sender, + sender: Sender, promise: Promise, }, /// Terminate an existing connection. @@ -276,11 +276,27 @@ impl PlayerRegistry { Ok(()) } + #[tracing::instrument(skip(self), name = "PlayerRegistry::get_player")] pub async fn get_player(&self, id: &PlayerId) -> Option { let inner = self.0.read().await; inner.players.get(id).map(|(handle, _)| handle.clone()) } + #[tracing::instrument(skip(self), name = "PlayerRegistry::stop_player")] + pub async fn stop_player(&self, id: &PlayerId) -> Result> { + let mut inner = self.0.write().await; + if let Some((handle, fiber)) = inner.players.remove(id) { + handle.send(ActorCommand::Stop).await; + drop(handle); + fiber.await?; + inner.metric_active_players.dec(); + Ok(Some(())) + } else { + Ok(None) + } + } + + #[tracing::instrument(skip(self), name = "PlayerRegistry::get_or_launch_player")] pub async fn get_or_launch_player(&mut self, id: &PlayerId) -> PlayerHandle { let inner = self.0.read().await; if let Some((handle, _)) = inner.players.get(id) { @@ -305,6 +321,7 @@ impl PlayerRegistry { } } + #[tracing::instrument(skip(self), name = "PlayerRegistry::connect_to_player")] pub async fn connect_to_player(&mut self, id: &PlayerId) -> PlayerConnection { let player_handle = self.get_or_launch_player(id).await; player_handle.subscribe().await @@ -337,7 +354,7 @@ struct PlayerRegistryInner { struct Player { player_id: PlayerId, storage_id: u32, - connections: AnonTable>, + connections: AnonTable>, my_rooms: HashMap, banned_from: HashSet, rx: Receiver<(ActorCommand, Span)>, @@ -438,7 +455,7 @@ impl Player { _ => {} } for (_, connection) in &self.connections { - let _ = connection.send(update.clone()).await; + let _ = connection.send(ConnectionMessage::Update(update.clone())).await; } } @@ -589,7 +606,18 @@ impl Player { if ConnectionId(a) == except { continue; } - let _ = b.send(update.clone()).await; + let _ = b.send(ConnectionMessage::Update(update.clone())).await; } } } + +pub enum ConnectionMessage { + Update(Updates), + Stop(StopReason), +} + +#[derive(Debug)] +pub enum StopReason { + ServerShutdown, + InternalError, +} diff --git a/crates/mgmt-api/src/lib.rs b/crates/mgmt-api/src/lib.rs index cfe5b69..c21ff85 100644 --- a/crates/mgmt-api/src/lib.rs +++ b/crates/mgmt-api/src/lib.rs @@ -11,6 +11,11 @@ pub struct CreatePlayerRequest<'a> { pub name: &'a str, } +#[derive(Serialize, Deserialize)] +pub struct StopPlayerRequest<'a> { + pub name: &'a str, +} + #[derive(Serialize, Deserialize)] pub struct ChangePasswordRequest<'a> { pub player_name: &'a str, @@ -19,6 +24,7 @@ pub struct ChangePasswordRequest<'a> { pub mod paths { pub const CREATE_PLAYER: &'static str = "/mgmt/create_player"; + pub const STOP_PLAYER: &'static str = "/mgmt/stop_player"; pub const SET_PASSWORD: &'static str = "/mgmt/set_password"; } diff --git a/crates/projection-irc/src/lib.rs b/crates/projection-irc/src/lib.rs index a320c5d..2a310a1 100644 --- a/crates/projection-irc/src/lib.rs +++ b/crates/projection-irc/src/lib.rs @@ -507,11 +507,18 @@ async fn handle_registered_socket<'a>( buffer.clear(); }, update = connection.receiver.recv() => { - if let Some(update) = update { - handle_update(&config, &user, &player_id, writer, &rooms, update).await?; - } else { - log::warn!("Player is terminated, must terminate the connection"); - break; + match update { + Some(ConnectionMessage::Update(update)) => { + handle_update(&config, &user, &player_id, writer, &rooms, update).await?; + } + Some(ConnectionMessage::Stop(_)) => { + tracing::debug!("Connection is being terminated"); + break; + } + None => { + log::warn!("Player is terminated, must terminate the connection"); + break; + } } } } diff --git a/crates/projection-xmpp/src/lib.rs b/crates/projection-xmpp/src/lib.rs index fe56481..eec6fc3 100644 --- a/crates/projection-xmpp/src/lib.rs +++ b/crates/projection-xmpp/src/lib.rs @@ -23,7 +23,7 @@ use tokio_rustls::rustls::{Certificate, PrivateKey}; use tokio_rustls::TlsAcceptor; use lavina_core::auth::{Authenticator, Verdict}; -use lavina_core::player::{PlayerConnection, PlayerId, PlayerRegistry}; +use lavina_core::player::{ConnectionMessage, PlayerConnection, PlayerId, PlayerRegistry, StopReason}; use lavina_core::prelude::*; use lavina_core::repo::Storage; use lavina_core::room::RoomRegistry; @@ -31,6 +31,7 @@ use lavina_core::terminator::Terminator; use lavina_core::LavinaCore; use proto_xmpp::bind::{Name, Resource}; use proto_xmpp::stream::*; +use proto_xmpp::streamerror::{StreamError, StreamErrorKind}; use proto_xmpp::xml::{Continuation, FromXml, Parser, ToXml}; use sasl::AuthBody; @@ -395,16 +396,41 @@ async fn socket_final( true }, update = conn.user_handle.receiver.recv() => { - if let Some(update) = update { - conn.handle_update(&mut events, update).await?; - for i in &events { - xml_writer.write_event_async(i).await?; + match update { + Some(ConnectionMessage::Update(update)) => { + conn.handle_update(&mut events, update).await?; + for i in &events { + xml_writer.write_event_async(i).await?; + } + events.clear(); + xml_writer.get_mut().flush().await?; + } + Some(ConnectionMessage::Stop(reason)) => { + tracing::debug!("Connection is being terminated: {reason:?}"); + let kind = match reason { + StopReason::ServerShutdown => StreamErrorKind::SystemShutdown, + StopReason::InternalError => StreamErrorKind::InternalServerError, + }; + StreamError { kind }.serialize(&mut events); + ServerStreamEnd.serialize(&mut events); + for i in &events { + xml_writer.write_event_async(i).await?; + } + events.clear(); + xml_writer.get_mut().flush().await?; + break; + } + None => { + log::error!("Player is terminated, must terminate the connection"); + StreamError { kind: StreamErrorKind::SystemShutdown }.serialize(&mut events); + ServerStreamEnd.serialize(&mut events); + for i in &events { + xml_writer.write_event_async(i).await?; + } + events.clear(); + xml_writer.get_mut().flush().await?; + break; } - events.clear(); - xml_writer.get_mut().flush().await?; - } else { - log::warn!("Player is terminated, must terminate the connection"); - break; } false } diff --git a/crates/proto-xmpp/src/lib.rs b/crates/proto-xmpp/src/lib.rs index 1e97a31..d3e25ba 100644 --- a/crates/proto-xmpp/src/lib.rs +++ b/crates/proto-xmpp/src/lib.rs @@ -10,6 +10,7 @@ pub mod sasl; pub mod session; pub mod stanzaerror; pub mod stream; +pub mod streamerror; pub mod tls; pub mod xml; diff --git a/crates/proto-xmpp/src/streamerror.rs b/crates/proto-xmpp/src/streamerror.rs new file mode 100644 index 0000000..0ba71cd --- /dev/null +++ b/crates/proto-xmpp/src/streamerror.rs @@ -0,0 +1,41 @@ +use crate::xml::ToXml; +use quick_xml::events::{BytesEnd, BytesStart, Event}; + +/// Stream error condition +/// +/// [Spec](https://xmpp.org/rfcs/rfc6120.html#streams-error-conditions). +pub enum StreamErrorKind { + /// The server has experienced a misconfiguration or other internal error that prevents it from servicing the stream. + InternalServerError, + /// The server is being shut down and all active streams are being closed. + SystemShutdown, +} +impl StreamErrorKind { + pub fn from_str(s: &str) -> Option { + match s { + "internal-server-error" => Some(Self::InternalServerError), + "system-shutdown" => Some(Self::SystemShutdown), + _ => None, + } + } + pub fn as_str(&self) -> &'static str { + match self { + Self::InternalServerError => "internal-server-error", + Self::SystemShutdown => "system-shutdown", + } + } +} + +pub struct StreamError { + pub kind: StreamErrorKind, +} +impl ToXml for StreamError { + fn serialize(&self, events: &mut Vec>) { + events.push(Event::Start(BytesStart::new("stream:error"))); + events.push(Event::Empty(BytesStart::new(format!( + r#"{} xmlns="urn:ietf:params:xml:ns:xmpp-streams""#, + self.kind.as_str() + )))); + events.push(Event::End(BytesEnd::new("stream:error"))); + } +} diff --git a/src/http.rs b/src/http.rs index b39a6f2..ae64676 100644 --- a/src/http.rs +++ b/src/http.rs @@ -13,6 +13,7 @@ use serde::{Deserialize, Serialize}; use tokio::net::TcpListener; use lavina_core::auth::{Authenticator, UpdatePasswordResult}; +use lavina_core::player::{PlayerId, PlayerRegistry}; use lavina_core::prelude::*; use lavina_core::repo::Storage; use lavina_core::room::RoomRegistry; @@ -85,8 +86,9 @@ async fn route( (&Method::GET, "/metrics") => endpoint_metrics(registry), (&Method::GET, "/rooms") => endpoint_rooms(core.rooms).await, (&Method::POST, paths::CREATE_PLAYER) => endpoint_create_player(request, storage).await.or5xx(), + (&Method::POST, paths::STOP_PLAYER) => endpoint_stop_player(request, core.players).await.or5xx(), (&Method::POST, paths::SET_PASSWORD) => endpoint_set_password(request, storage).await.or5xx(), - _ => not_found(), + _ => endpoint_not_found(), }; Ok(res) } @@ -98,6 +100,7 @@ fn endpoint_metrics(registry: MetricsRegistry) -> Response> { Response::new(Full::new(Bytes::from(buffer))) } +#[tracing::instrument(skip_all)] async fn endpoint_rooms(rooms: RoomRegistry) -> Response> { // TODO introduce management API types independent from core-domain types // TODO remove `Serialize` implementations from all core-domain types @@ -105,6 +108,7 @@ async fn endpoint_rooms(rooms: RoomRegistry) -> Response> { Response::new(room_list) } +#[tracing::instrument(skip_all)] async fn endpoint_create_player( request: Request, mut storage: Storage, @@ -120,6 +124,27 @@ async fn endpoint_create_player( Ok(response) } +#[tracing::instrument(skip_all)] +async fn endpoint_stop_player( + request: Request, + players: PlayerRegistry, +) -> Result>> { + let str = request.collect().await?.to_bytes(); + let Ok(res) = serde_json::from_slice::(&str[..]) else { + return Ok(malformed_request()); + }; + let Ok(player_id) = PlayerId::from(res.name) else { + return Ok(player_not_found()); + }; + let Some(()) = players.stop_player(&player_id).await? else { + return Ok(player_not_found()); + }; + let mut response = Response::new(Full::::default()); + *response.status_mut() = StatusCode::NO_CONTENT; + Ok(response) +} + +#[tracing::instrument(skip_all)] async fn endpoint_set_password( request: Request, storage: Storage, @@ -132,14 +157,7 @@ async fn endpoint_set_password( match verdict { UpdatePasswordResult::PasswordUpdated => {} UpdatePasswordResult::UserNotFound => { - let payload = ErrorResponse { - code: errors::PLAYER_NOT_FOUND, - message: "No such player exists", - } - .to_body(); - let mut response = Response::new(payload); - *response.status_mut() = StatusCode::UNPROCESSABLE_ENTITY; - return Ok(response); + return Ok(player_not_found()); } } let mut response = Response::new(Full::::default()); @@ -147,7 +165,7 @@ async fn endpoint_set_password( Ok(response) } -pub fn not_found() -> Response> { +fn endpoint_not_found() -> Response> { let payload = ErrorResponse { code: errors::INVALID_PATH, message: "The path does not exist", @@ -159,6 +177,17 @@ pub fn not_found() -> Response> { response } +fn player_not_found() -> Response> { + let payload = ErrorResponse { + code: errors::PLAYER_NOT_FOUND, + message: "No such player exists", + } + .to_body(); + let mut response = Response::new(payload); + *response.status_mut() = StatusCode::UNPROCESSABLE_ENTITY; + response +} + fn malformed_request() -> Response> { let payload = ErrorResponse { code: errors::MALFORMED_REQUEST, @@ -174,6 +203,7 @@ fn malformed_request() -> Response> { trait Or5xx { fn or5xx(self) -> Response>; } + impl Or5xx for Result>> { fn or5xx(self) -> Response> { self.unwrap_or_else(|e| { @@ -187,6 +217,7 @@ impl Or5xx for Result>> { trait ToBody { fn to_body(&self) -> Full; } + impl ToBody for T where T: Serialize,