player shutdown API (#58)

Reviewed-on: lavina/lavina#58
This commit is contained in:
Nikita Vilunov 2024-04-29 17:24:43 +00:00
parent 31f9da9b05
commit 25605322a0
7 changed files with 170 additions and 30 deletions

View File

@ -55,7 +55,7 @@ pub struct ConnectionId(pub AnonKey);
/// The connection is used to send commands to the player actor and to receive updates that might be sent to the client. /// The connection is used to send commands to the player actor and to receive updates that might be sent to the client.
pub struct PlayerConnection { pub struct PlayerConnection {
pub connection_id: ConnectionId, pub connection_id: ConnectionId,
pub receiver: Receiver<Updates>, pub receiver: Receiver<ConnectionMessage>,
player_handle: PlayerHandle, player_handle: PlayerHandle,
} }
impl PlayerConnection { impl PlayerConnection {
@ -160,7 +160,7 @@ impl PlayerHandle {
enum ActorCommand { enum ActorCommand {
/// Establish a new connection. /// Establish a new connection.
AddConnection { AddConnection {
sender: Sender<Updates>, sender: Sender<ConnectionMessage>,
promise: Promise<ConnectionId>, promise: Promise<ConnectionId>,
}, },
/// Terminate an existing connection. /// Terminate an existing connection.
@ -276,11 +276,27 @@ impl PlayerRegistry {
Ok(()) Ok(())
} }
#[tracing::instrument(skip(self), name = "PlayerRegistry::get_player")]
pub async fn get_player(&self, id: &PlayerId) -> Option<PlayerHandle> { pub async fn get_player(&self, id: &PlayerId) -> Option<PlayerHandle> {
let inner = self.0.read().await; let inner = self.0.read().await;
inner.players.get(id).map(|(handle, _)| handle.clone()) inner.players.get(id).map(|(handle, _)| handle.clone())
} }
#[tracing::instrument(skip(self), name = "PlayerRegistry::stop_player")]
pub async fn stop_player(&self, id: &PlayerId) -> Result<Option<()>> {
let mut inner = self.0.write().await;
if let Some((handle, fiber)) = inner.players.remove(id) {
handle.send(ActorCommand::Stop).await;
drop(handle);
fiber.await?;
inner.metric_active_players.dec();
Ok(Some(()))
} else {
Ok(None)
}
}
#[tracing::instrument(skip(self), name = "PlayerRegistry::get_or_launch_player")]
pub async fn get_or_launch_player(&mut self, id: &PlayerId) -> PlayerHandle { pub async fn get_or_launch_player(&mut self, id: &PlayerId) -> PlayerHandle {
let inner = self.0.read().await; let inner = self.0.read().await;
if let Some((handle, _)) = inner.players.get(id) { if let Some((handle, _)) = inner.players.get(id) {
@ -305,6 +321,7 @@ impl PlayerRegistry {
} }
} }
#[tracing::instrument(skip(self), name = "PlayerRegistry::connect_to_player")]
pub async fn connect_to_player(&mut self, id: &PlayerId) -> PlayerConnection { pub async fn connect_to_player(&mut self, id: &PlayerId) -> PlayerConnection {
let player_handle = self.get_or_launch_player(id).await; let player_handle = self.get_or_launch_player(id).await;
player_handle.subscribe().await player_handle.subscribe().await
@ -337,7 +354,7 @@ struct PlayerRegistryInner {
struct Player { struct Player {
player_id: PlayerId, player_id: PlayerId,
storage_id: u32, storage_id: u32,
connections: AnonTable<Sender<Updates>>, connections: AnonTable<Sender<ConnectionMessage>>,
my_rooms: HashMap<RoomId, RoomHandle>, my_rooms: HashMap<RoomId, RoomHandle>,
banned_from: HashSet<RoomId>, banned_from: HashSet<RoomId>,
rx: Receiver<(ActorCommand, Span)>, rx: Receiver<(ActorCommand, Span)>,
@ -438,7 +455,7 @@ impl Player {
_ => {} _ => {}
} }
for (_, connection) in &self.connections { for (_, connection) in &self.connections {
let _ = connection.send(update.clone()).await; let _ = connection.send(ConnectionMessage::Update(update.clone())).await;
} }
} }
@ -589,7 +606,18 @@ impl Player {
if ConnectionId(a) == except { if ConnectionId(a) == except {
continue; continue;
} }
let _ = b.send(update.clone()).await; let _ = b.send(ConnectionMessage::Update(update.clone())).await;
} }
} }
} }
pub enum ConnectionMessage {
Update(Updates),
Stop(StopReason),
}
#[derive(Debug)]
pub enum StopReason {
ServerShutdown,
InternalError,
}

View File

@ -11,6 +11,11 @@ pub struct CreatePlayerRequest<'a> {
pub name: &'a str, pub name: &'a str,
} }
#[derive(Serialize, Deserialize)]
pub struct StopPlayerRequest<'a> {
pub name: &'a str,
}
#[derive(Serialize, Deserialize)] #[derive(Serialize, Deserialize)]
pub struct ChangePasswordRequest<'a> { pub struct ChangePasswordRequest<'a> {
pub player_name: &'a str, pub player_name: &'a str,
@ -19,6 +24,7 @@ pub struct ChangePasswordRequest<'a> {
pub mod paths { pub mod paths {
pub const CREATE_PLAYER: &'static str = "/mgmt/create_player"; pub const CREATE_PLAYER: &'static str = "/mgmt/create_player";
pub const STOP_PLAYER: &'static str = "/mgmt/stop_player";
pub const SET_PASSWORD: &'static str = "/mgmt/set_password"; pub const SET_PASSWORD: &'static str = "/mgmt/set_password";
} }

View File

@ -507,11 +507,18 @@ async fn handle_registered_socket<'a>(
buffer.clear(); buffer.clear();
}, },
update = connection.receiver.recv() => { update = connection.receiver.recv() => {
if let Some(update) = update { match update {
handle_update(&config, &user, &player_id, writer, &rooms, update).await?; Some(ConnectionMessage::Update(update)) => {
} else { handle_update(&config, &user, &player_id, writer, &rooms, update).await?;
log::warn!("Player is terminated, must terminate the connection"); }
break; Some(ConnectionMessage::Stop(_)) => {
tracing::debug!("Connection is being terminated");
break;
}
None => {
log::warn!("Player is terminated, must terminate the connection");
break;
}
} }
} }
} }

View File

@ -23,7 +23,7 @@ use tokio_rustls::rustls::{Certificate, PrivateKey};
use tokio_rustls::TlsAcceptor; use tokio_rustls::TlsAcceptor;
use lavina_core::auth::{Authenticator, Verdict}; use lavina_core::auth::{Authenticator, Verdict};
use lavina_core::player::{PlayerConnection, PlayerId, PlayerRegistry}; use lavina_core::player::{ConnectionMessage, PlayerConnection, PlayerId, PlayerRegistry, StopReason};
use lavina_core::prelude::*; use lavina_core::prelude::*;
use lavina_core::repo::Storage; use lavina_core::repo::Storage;
use lavina_core::room::RoomRegistry; use lavina_core::room::RoomRegistry;
@ -31,6 +31,7 @@ use lavina_core::terminator::Terminator;
use lavina_core::LavinaCore; use lavina_core::LavinaCore;
use proto_xmpp::bind::{Name, Resource}; use proto_xmpp::bind::{Name, Resource};
use proto_xmpp::stream::*; use proto_xmpp::stream::*;
use proto_xmpp::streamerror::{StreamError, StreamErrorKind};
use proto_xmpp::xml::{Continuation, FromXml, Parser, ToXml}; use proto_xmpp::xml::{Continuation, FromXml, Parser, ToXml};
use sasl::AuthBody; use sasl::AuthBody;
@ -395,16 +396,41 @@ async fn socket_final(
true true
}, },
update = conn.user_handle.receiver.recv() => { update = conn.user_handle.receiver.recv() => {
if let Some(update) = update { match update {
conn.handle_update(&mut events, update).await?; Some(ConnectionMessage::Update(update)) => {
for i in &events { conn.handle_update(&mut events, update).await?;
xml_writer.write_event_async(i).await?; for i in &events {
xml_writer.write_event_async(i).await?;
}
events.clear();
xml_writer.get_mut().flush().await?;
}
Some(ConnectionMessage::Stop(reason)) => {
tracing::debug!("Connection is being terminated: {reason:?}");
let kind = match reason {
StopReason::ServerShutdown => StreamErrorKind::SystemShutdown,
StopReason::InternalError => StreamErrorKind::InternalServerError,
};
StreamError { kind }.serialize(&mut events);
ServerStreamEnd.serialize(&mut events);
for i in &events {
xml_writer.write_event_async(i).await?;
}
events.clear();
xml_writer.get_mut().flush().await?;
break;
}
None => {
log::error!("Player is terminated, must terminate the connection");
StreamError { kind: StreamErrorKind::SystemShutdown }.serialize(&mut events);
ServerStreamEnd.serialize(&mut events);
for i in &events {
xml_writer.write_event_async(i).await?;
}
events.clear();
xml_writer.get_mut().flush().await?;
break;
} }
events.clear();
xml_writer.get_mut().flush().await?;
} else {
log::warn!("Player is terminated, must terminate the connection");
break;
} }
false false
} }

View File

@ -10,6 +10,7 @@ pub mod sasl;
pub mod session; pub mod session;
pub mod stanzaerror; pub mod stanzaerror;
pub mod stream; pub mod stream;
pub mod streamerror;
pub mod tls; pub mod tls;
pub mod xml; pub mod xml;

View File

@ -0,0 +1,41 @@
use crate::xml::ToXml;
use quick_xml::events::{BytesEnd, BytesStart, Event};
/// Stream error condition
///
/// [Spec](https://xmpp.org/rfcs/rfc6120.html#streams-error-conditions).
pub enum StreamErrorKind {
/// The server has experienced a misconfiguration or other internal error that prevents it from servicing the stream.
InternalServerError,
/// The server is being shut down and all active streams are being closed.
SystemShutdown,
}
impl StreamErrorKind {
pub fn from_str(s: &str) -> Option<Self> {
match s {
"internal-server-error" => Some(Self::InternalServerError),
"system-shutdown" => Some(Self::SystemShutdown),
_ => None,
}
}
pub fn as_str(&self) -> &'static str {
match self {
Self::InternalServerError => "internal-server-error",
Self::SystemShutdown => "system-shutdown",
}
}
}
pub struct StreamError {
pub kind: StreamErrorKind,
}
impl ToXml for StreamError {
fn serialize(&self, events: &mut Vec<Event<'static>>) {
events.push(Event::Start(BytesStart::new("stream:error")));
events.push(Event::Empty(BytesStart::new(format!(
r#"{} xmlns="urn:ietf:params:xml:ns:xmpp-streams""#,
self.kind.as_str()
))));
events.push(Event::End(BytesEnd::new("stream:error")));
}
}

View File

@ -13,6 +13,7 @@ use serde::{Deserialize, Serialize};
use tokio::net::TcpListener; use tokio::net::TcpListener;
use lavina_core::auth::{Authenticator, UpdatePasswordResult}; use lavina_core::auth::{Authenticator, UpdatePasswordResult};
use lavina_core::player::{PlayerId, PlayerRegistry};
use lavina_core::prelude::*; use lavina_core::prelude::*;
use lavina_core::repo::Storage; use lavina_core::repo::Storage;
use lavina_core::room::RoomRegistry; use lavina_core::room::RoomRegistry;
@ -85,8 +86,9 @@ async fn route(
(&Method::GET, "/metrics") => endpoint_metrics(registry), (&Method::GET, "/metrics") => endpoint_metrics(registry),
(&Method::GET, "/rooms") => endpoint_rooms(core.rooms).await, (&Method::GET, "/rooms") => endpoint_rooms(core.rooms).await,
(&Method::POST, paths::CREATE_PLAYER) => endpoint_create_player(request, storage).await.or5xx(), (&Method::POST, paths::CREATE_PLAYER) => endpoint_create_player(request, storage).await.or5xx(),
(&Method::POST, paths::STOP_PLAYER) => endpoint_stop_player(request, core.players).await.or5xx(),
(&Method::POST, paths::SET_PASSWORD) => endpoint_set_password(request, storage).await.or5xx(), (&Method::POST, paths::SET_PASSWORD) => endpoint_set_password(request, storage).await.or5xx(),
_ => not_found(), _ => endpoint_not_found(),
}; };
Ok(res) Ok(res)
} }
@ -98,6 +100,7 @@ fn endpoint_metrics(registry: MetricsRegistry) -> Response<Full<Bytes>> {
Response::new(Full::new(Bytes::from(buffer))) Response::new(Full::new(Bytes::from(buffer)))
} }
#[tracing::instrument(skip_all)]
async fn endpoint_rooms(rooms: RoomRegistry) -> Response<Full<Bytes>> { async fn endpoint_rooms(rooms: RoomRegistry) -> Response<Full<Bytes>> {
// TODO introduce management API types independent from core-domain types // TODO introduce management API types independent from core-domain types
// TODO remove `Serialize` implementations from all core-domain types // TODO remove `Serialize` implementations from all core-domain types
@ -105,6 +108,7 @@ async fn endpoint_rooms(rooms: RoomRegistry) -> Response<Full<Bytes>> {
Response::new(room_list) Response::new(room_list)
} }
#[tracing::instrument(skip_all)]
async fn endpoint_create_player( async fn endpoint_create_player(
request: Request<hyper::body::Incoming>, request: Request<hyper::body::Incoming>,
mut storage: Storage, mut storage: Storage,
@ -120,6 +124,27 @@ async fn endpoint_create_player(
Ok(response) Ok(response)
} }
#[tracing::instrument(skip_all)]
async fn endpoint_stop_player(
request: Request<hyper::body::Incoming>,
players: PlayerRegistry,
) -> Result<Response<Full<Bytes>>> {
let str = request.collect().await?.to_bytes();
let Ok(res) = serde_json::from_slice::<StopPlayerRequest>(&str[..]) else {
return Ok(malformed_request());
};
let Ok(player_id) = PlayerId::from(res.name) else {
return Ok(player_not_found());
};
let Some(()) = players.stop_player(&player_id).await? else {
return Ok(player_not_found());
};
let mut response = Response::new(Full::<Bytes>::default());
*response.status_mut() = StatusCode::NO_CONTENT;
Ok(response)
}
#[tracing::instrument(skip_all)]
async fn endpoint_set_password( async fn endpoint_set_password(
request: Request<hyper::body::Incoming>, request: Request<hyper::body::Incoming>,
storage: Storage, storage: Storage,
@ -132,14 +157,7 @@ async fn endpoint_set_password(
match verdict { match verdict {
UpdatePasswordResult::PasswordUpdated => {} UpdatePasswordResult::PasswordUpdated => {}
UpdatePasswordResult::UserNotFound => { UpdatePasswordResult::UserNotFound => {
let payload = ErrorResponse { return Ok(player_not_found());
code: errors::PLAYER_NOT_FOUND,
message: "No such player exists",
}
.to_body();
let mut response = Response::new(payload);
*response.status_mut() = StatusCode::UNPROCESSABLE_ENTITY;
return Ok(response);
} }
} }
let mut response = Response::new(Full::<Bytes>::default()); let mut response = Response::new(Full::<Bytes>::default());
@ -147,7 +165,7 @@ async fn endpoint_set_password(
Ok(response) Ok(response)
} }
pub fn not_found() -> Response<Full<Bytes>> { fn endpoint_not_found() -> Response<Full<Bytes>> {
let payload = ErrorResponse { let payload = ErrorResponse {
code: errors::INVALID_PATH, code: errors::INVALID_PATH,
message: "The path does not exist", message: "The path does not exist",
@ -159,6 +177,17 @@ pub fn not_found() -> Response<Full<Bytes>> {
response response
} }
fn player_not_found() -> Response<Full<Bytes>> {
let payload = ErrorResponse {
code: errors::PLAYER_NOT_FOUND,
message: "No such player exists",
}
.to_body();
let mut response = Response::new(payload);
*response.status_mut() = StatusCode::UNPROCESSABLE_ENTITY;
response
}
fn malformed_request() -> Response<Full<Bytes>> { fn malformed_request() -> Response<Full<Bytes>> {
let payload = ErrorResponse { let payload = ErrorResponse {
code: errors::MALFORMED_REQUEST, code: errors::MALFORMED_REQUEST,
@ -174,6 +203,7 @@ fn malformed_request() -> Response<Full<Bytes>> {
trait Or5xx { trait Or5xx {
fn or5xx(self) -> Response<Full<Bytes>>; fn or5xx(self) -> Response<Full<Bytes>>;
} }
impl Or5xx for Result<Response<Full<Bytes>>> { impl Or5xx for Result<Response<Full<Bytes>>> {
fn or5xx(self) -> Response<Full<Bytes>> { fn or5xx(self) -> Response<Full<Bytes>> {
self.unwrap_or_else(|e| { self.unwrap_or_else(|e| {
@ -187,6 +217,7 @@ impl Or5xx for Result<Response<Full<Bytes>>> {
trait ToBody { trait ToBody {
fn to_body(&self) -> Full<Bytes>; fn to_body(&self) -> Full<Bytes>;
} }
impl<T> ToBody for T impl<T> ToBody for T
where where
T: Serialize, T: Serialize,