From 4e8eb091847187f70d02cdbde38e1361bf94f839 Mon Sep 17 00:00:00 2001 From: Nikita Vilunov Date: Tue, 4 Jun 2024 21:54:57 +0000 Subject: [PATCH] reduce usage of unwraps (#70) Reviewed-on: https://git.vilunov.me/lavina/lavina/pulls/70 --- crates/lavina-core/src/clustering/room.rs | 4 +- crates/lavina-core/src/dialog.rs | 7 +- crates/lavina-core/src/lib.rs | 8 +- crates/lavina-core/src/player.rs | 162 ++++++++++++---------- crates/lavina-core/src/room.rs | 26 ++-- crates/projection-irc/src/lib.rs | 15 +- crates/projection-irc/tests/lib.rs | 22 +-- crates/projection-xmpp/src/iq.rs | 2 +- crates/projection-xmpp/src/lib.rs | 10 +- crates/projection-xmpp/src/presence.rs | 41 +++--- crates/projection-xmpp/src/testkit.rs | 3 +- crates/projection-xmpp/tests/lib.rs | 6 +- crates/proto-xmpp/src/bind.rs | 35 +++-- crates/proto-xmpp/src/client.rs | 8 +- crates/proto-xmpp/src/roster.rs | 7 +- crates/proto-xmpp/src/stream.rs | 15 +- crates/sasl/src/lib.rs | 15 +- src/http.rs | 20 ++- 18 files changed, 235 insertions(+), 171 deletions(-) diff --git a/crates/lavina-core/src/clustering/room.rs b/crates/lavina-core/src/clustering/room.rs index fc246e4..9fff363 100644 --- a/crates/lavina-core/src/clustering/room.rs +++ b/crates/lavina-core/src/clustering/room.rs @@ -79,10 +79,10 @@ impl LavinaCore { message: Str, created_at: chrono::DateTime, ) -> Result> { - let Some(room_handle) = self.services.rooms.get_room(&self.services, &room_id).await else { + let Some(room_handle) = self.services.rooms.get_room(&self.services, &room_id).await? else { return Ok(None); }; - room_handle.send_message(&self.services, &player_id, message, created_at).await; + room_handle.send_message(&self.services, &player_id, message, created_at).await?; Ok(Some(())) } } diff --git a/crates/lavina-core/src/dialog.rs b/crates/lavina-core/src/dialog.rs index 3813a57..763ac6d 100644 --- a/crates/lavina-core/src/dialog.rs +++ b/crates/lavina-core/src/dialog.rs @@ -137,14 +137,15 @@ mod tests { use super::*; #[test] - fn test_dialog_id_new() { - let a = PlayerId::from("a").unwrap(); - let b = PlayerId::from("b").unwrap(); + fn test_dialog_id_new() -> Result<()> { + let a = PlayerId::from("a")?; + let b = PlayerId::from("b")?; let id1 = DialogId::new(a.clone(), b.clone()); let id2 = DialogId::new(a.clone(), b.clone()); // Dialog ids are invariant with respect to the order of participants assert_eq!(id1, id2); assert_eq!(id1.as_inner(), (&a, &b)); assert_eq!(id2.as_inner(), (&a, &b)); + Ok(()) } } diff --git a/crates/lavina-core/src/lib.rs b/crates/lavina-core/src/lib.rs index b449055..5c1821d 100644 --- a/crates/lavina-core/src/lib.rs +++ b/crates/lavina-core/src/lib.rs @@ -8,7 +8,7 @@ use prometheus::Registry as MetricsRegistry; use crate::clustering::broadcast::Broadcasting; use crate::clustering::{ClusterConfig, ClusterMetadata, LavinaClient}; use crate::dialog::DialogRegistry; -use crate::player::{PlayerConnection, PlayerId, PlayerRegistry}; +use crate::player::{PlayerConnectionResult, PlayerId, PlayerRegistry}; use crate::repo::Storage; use crate::room::{RoomHandle, RoomId, RoomInfo, RoomRegistry}; @@ -37,11 +37,11 @@ impl Deref for LavinaCore { } impl LavinaCore { - pub async fn connect_to_player(&self, player_id: &PlayerId) -> PlayerConnection { + pub async fn connect_to_player(&self, player_id: &PlayerId) -> Result { self.services.players.connect_to_player(&self, player_id).await } - pub async fn get_room(&self, room_id: &RoomId) -> Option { + pub async fn get_room(&self, room_id: &RoomId) -> Result> { self.services.rooms.get_room(&self.services, room_id).await } @@ -97,7 +97,7 @@ impl LavinaCore { } pub async fn shutdown(self) -> Storage { - let _ = self.players.shutdown_all().await; + self.players.shutdown_all().await; let services = match Arc::try_unwrap(self.services) { Ok(e) => e, Err(_) => { diff --git a/crates/lavina-core/src/player.rs b/crates/lavina-core/src/player.rs index a5db4b7..da4c615 100644 --- a/crates/lavina-core/src/player.rs +++ b/crates/lavina-core/src/player.rs @@ -64,7 +64,7 @@ impl PlayerConnection { let (promise, deferred) = oneshot(); let cmd = ClientCommand::SendMessage { room_id, body, promise }; self.player_handle.send(ActorCommand::ClientCommand(cmd, self.connection_id.clone())).await; - Ok(deferred.await?) + deferred.await? } /// Handled in [Player::join_room]. @@ -73,7 +73,7 @@ impl PlayerConnection { let (promise, deferred) = oneshot(); let cmd = ClientCommand::JoinRoom { room_id, promise }; self.player_handle.send(ActorCommand::ClientCommand(cmd, self.connection_id.clone())).await; - Ok(deferred.await?) + deferred.await? } /// Handled in [Player::change_room_topic]. @@ -86,7 +86,7 @@ impl PlayerConnection { promise, }; self.player_handle.send(ActorCommand::ClientCommand(cmd, self.connection_id.clone())).await; - Ok(deferred.await?) + deferred.await? } /// Handled in [Player::leave_room]. @@ -95,7 +95,7 @@ impl PlayerConnection { let (promise, deferred) = oneshot(); let cmd = ClientCommand::LeaveRoom { room_id, promise }; self.player_handle.send(ActorCommand::ClientCommand(cmd, self.connection_id.clone())).await; - Ok(deferred.await?) + deferred.await? } pub async fn terminate(self) { @@ -108,7 +108,7 @@ impl PlayerConnection { let (promise, deferred) = oneshot(); let cmd = ClientCommand::GetRooms { promise }; self.player_handle.send(ActorCommand::ClientCommand(cmd, self.connection_id.clone())).await; - Ok(deferred.await?) + deferred.await? } #[tracing::instrument(skip(self), name = "PlayerConnection::get_room_message_history")] @@ -120,7 +120,7 @@ impl PlayerConnection { limit, }; self.player_handle.send(ActorCommand::ClientCommand(cmd, self.connection_id.clone())).await; - Ok(deferred.await?) + deferred.await? } /// Handler in [Player::send_dialog_message]. @@ -133,7 +133,7 @@ impl PlayerConnection { promise, }; self.player_handle.send(ActorCommand::ClientCommand(cmd, self.connection_id.clone())).await; - Ok(deferred.await?) + deferred.await? } /// Handler in [Player::check_user_existence]. @@ -142,7 +142,7 @@ impl PlayerConnection { let (promise, deferred) = oneshot(); let cmd = ClientCommand::GetInfo { recipient, promise }; self.player_handle.send(ActorCommand::ClientCommand(cmd, self.connection_id.clone())).await; - Ok(deferred.await?) + deferred.await? } } @@ -196,37 +196,37 @@ enum ActorCommand { pub enum ClientCommand { JoinRoom { room_id: RoomId, - promise: Promise, + promise: Promise>, }, LeaveRoom { room_id: RoomId, - promise: Promise<()>, + promise: Promise>, }, SendMessage { room_id: RoomId, body: Str, - promise: Promise, + promise: Promise>, }, ChangeTopic { room_id: RoomId, new_topic: Str, - promise: Promise<()>, + promise: Promise>, }, GetRooms { - promise: Promise>, + promise: Promise>>, }, SendDialogMessage { recipient: PlayerId, body: Str, - promise: Promise<()>, + promise: Promise>, }, GetInfo { recipient: PlayerId, - promise: Promise, + promise: Promise>, }, GetRoomHistory { room_id: RoomId, - promise: Promise>, + promise: Promise>>, limit: u32, }, } @@ -317,40 +317,45 @@ impl PlayerRegistry { } #[tracing::instrument(skip(self, core), name = "PlayerRegistry::get_or_launch_player")] - pub async fn get_or_launch_player(&self, core: &LavinaCore, id: &PlayerId) -> PlayerHandle { + async fn get_or_launch_player(&self, core: &LavinaCore, id: &PlayerId) -> Result> { let inner = self.0.read().await; if let Some((handle, _)) = inner.players.get(id) { - handle.clone() + Ok(Some(handle.clone())) } else { drop(inner); let mut inner = self.0.write().await; if let Some((handle, _)) = inner.players.get(id) { - handle.clone() + Ok(Some(handle.clone())) } else { - let (handle, fiber) = Player::launch(id.clone(), core.clone()).await; + let Some((handle, fiber)) = Player::launch(id.clone(), core.clone()).await? else { + return Ok(None); + }; inner.players.insert(id.clone(), (handle.clone(), fiber)); inner.metric_active_players.inc(); - handle + Ok(Some(handle)) } } } #[tracing::instrument(skip(self, core), name = "PlayerRegistry::connect_to_player")] - pub async fn connect_to_player(&self, core: &LavinaCore, id: &PlayerId) -> PlayerConnection { - let player_handle = self.get_or_launch_player(core, id).await; - player_handle.subscribe().await + pub async fn connect_to_player(&self, core: &LavinaCore, id: &PlayerId) -> Result { + let Some(player_handle) = self.get_or_launch_player(core, id).await? else { + return Ok(PlayerConnectionResult::PlayerNotFound); + }; + Ok(PlayerConnectionResult::Success(player_handle.subscribe().await)) } - pub async fn shutdown_all(&self) -> Result<()> { + pub async fn shutdown_all(&self) { let mut inner = self.0.write().await; - for (i, (k, j)) in inner.players.drain() { - k.send(ActorCommand::Stop).await; - drop(k); - j.await?; - log::debug!("Player stopped #{i:?}") + for (id, (handle, task)) in inner.players.drain() { + handle.send(ActorCommand::Stop).await; + drop(handle); + match task.await { + Ok(_) => log::debug!("Player stopped #{id:?}"), + Err(e) => log::error!("Player #{id:?} failed to stop: {e}"), + } } log::debug!("All players stopped"); - Ok(()) } } @@ -378,11 +383,13 @@ struct Player { services: LavinaCore, } impl Player { - async fn launch(player_id: PlayerId, core: LavinaCore) -> (PlayerHandle, JoinHandle) { + async fn launch(player_id: PlayerId, core: LavinaCore) -> Result)>> { let (tx, rx) = channel(32); let handle = PlayerHandle { tx }; let handle_clone = handle.clone(); - let storage_id = core.services.storage.retrieve_user_id_by_name(player_id.as_inner()).await.unwrap().unwrap(); + let Some(storage_id) = core.services.storage.retrieve_user_id_by_name(player_id.as_inner()).await? else { + return Ok(None); + }; let player = Player { player_id, storage_id, @@ -397,7 +404,7 @@ impl Player { services: core, }; let fiber = tokio::task::spawn(player.main_loop()); - (handle_clone, fiber) + Ok(Some((handle_clone, fiber))) } fn room_location(&self, room_id: &RoomId) -> Option { @@ -417,7 +424,7 @@ impl Player { self.my_rooms.insert(room_id.clone(), RoomRef::Remote { node_id: remote_node }); self.services.subscribe(self.player_id.clone(), room_id).await; } else { - let room = self.services.rooms.get_room(&self.services, &room_id).await; + let room = self.services.rooms.get_room(&self.services, &room_id).await.unwrap(); if let Some(room) = room { room.subscribe(&self.player_id, self.handle.clone()).await; self.my_rooms.insert(room_id, RoomRef::Local(room)); @@ -496,8 +503,8 @@ impl Player { let _ = promise.send(result); } ClientCommand::LeaveRoom { room_id, promise } => { - self.leave_room(connection_id, room_id).await; - let _ = promise.send(()); + let result = self.leave_room(connection_id, room_id).await; + let _ = promise.send(result); } ClientCommand::SendMessage { room_id, body, promise } => { let result = self.send_room_message(connection_id, room_id, body).await; @@ -508,20 +515,20 @@ impl Player { new_topic, promise, } => { - self.change_room_topic(connection_id, room_id, new_topic).await; - let _ = promise.send(()); + let result = self.change_room_topic(connection_id, room_id, new_topic).await; + let _ = promise.send(result); } ClientCommand::GetRooms { promise } => { let result = self.get_rooms().await; - let _ = promise.send(result); + let _ = promise.send(Ok(result)); } ClientCommand::SendDialogMessage { recipient, body, promise, } => { - self.send_dialog_message(connection_id, recipient, body).await; - let _ = promise.send(()); + let result = self.send_dialog_message(connection_id, recipient, body).await; + let _ = promise.send(result); } ClientCommand::GetInfo { recipient, promise } => { let result = self.check_user_existence(recipient).await; @@ -539,12 +546,12 @@ impl Player { } #[tracing::instrument(skip(self), name = "Player::join_room")] - async fn join_room(&mut self, connection_id: ConnectionId, room_id: RoomId) -> JoinResult { + async fn join_room(&mut self, connection_id: ConnectionId, room_id: RoomId) -> Result { if self.banned_from.contains(&room_id) { - return JoinResult::Banned; + return Ok(JoinResult::Banned); } if self.my_rooms.contains_key(&room_id) { - return JoinResult::AlreadyJoined; + return Ok(JoinResult::AlreadyJoined); } if let Some(remote_node) = self.room_location(&room_id) { @@ -552,16 +559,15 @@ impl Player { room_id: room_id.as_inner(), player_id: self.player_id.as_inner(), }; - self.services.client.join_room(remote_node, req).await.unwrap(); - let room_storage_id = - self.services.storage.create_or_retrieve_room_id_by_name(room_id.as_inner()).await.unwrap(); - self.services.storage.add_room_member(room_storage_id, self.storage_id).await.unwrap(); + self.services.client.join_room(remote_node, req).await?; + let room_storage_id = self.services.storage.create_or_retrieve_room_id_by_name(room_id.as_inner()).await?; + self.services.storage.add_room_member(room_storage_id, self.storage_id).await?; self.my_rooms.insert(room_id.clone(), RoomRef::Remote { node_id: remote_node }); - JoinResult::Success(RoomInfo { + Ok(JoinResult::Success(RoomInfo { id: room_id, topic: "unknown".into(), members: vec![], - }) + })) } else { let room = match self.services.rooms.get_or_create_room(&self.services, room_id.clone()).await { Ok(room) => room, @@ -579,12 +585,12 @@ impl Player { new_member_id: self.player_id.clone(), }; self.broadcast_update(update, connection_id).await; - JoinResult::Success(room_info) + Ok(JoinResult::Success(room_info)) } } #[tracing::instrument(skip(self), name = "Player::retrieve_room_history")] - async fn get_room_history(&mut self, room_id: RoomId, limit: u32) -> Vec { + async fn get_room_history(&mut self, room_id: RoomId, limit: u32) -> Result> { let room = self.my_rooms.get(&room_id); if let Some(room) = room { match room { @@ -601,7 +607,7 @@ impl Player { } #[tracing::instrument(skip(self), name = "Player::leave_room")] - async fn leave_room(&mut self, connection_id: ConnectionId, room_id: RoomId) { + async fn leave_room(&mut self, connection_id: ConnectionId, room_id: RoomId) -> Result<()> { let room = self.my_rooms.remove(&room_id); if let Some(room) = room { match room { @@ -614,10 +620,10 @@ impl Player { room_id: room_id.as_inner(), player_id: self.player_id.as_inner(), }; - self.services.client.leave_room(node_id, req).await.unwrap(); + self.services.client.leave_room(node_id, req).await?; let room_storage_id = - self.services.storage.create_or_retrieve_room_id_by_name(room_id.as_inner()).await.unwrap(); - self.services.storage.remove_room_member(room_storage_id, self.storage_id).await.unwrap(); + self.services.storage.create_or_retrieve_room_id_by_name(room_id.as_inner()).await?; + self.services.storage.remove_room_member(room_storage_id, self.storage_id).await?; } } } @@ -626,6 +632,7 @@ impl Player { former_member_id: self.player_id.clone(), }; self.broadcast_update(update, connection_id).await; + Ok(()) } #[tracing::instrument(skip(self, body), name = "Player::send_room_message")] @@ -634,15 +641,15 @@ impl Player { connection_id: ConnectionId, room_id: RoomId, body: Str, - ) -> SendMessageResult { + ) -> Result { let Some(room) = self.my_rooms.get(&room_id) else { tracing::info!("Room with ID {room_id:?} not found"); - return SendMessageResult::NoSuchRoom; + return Ok(SendMessageResult::NoSuchRoom); }; let created_at = Utc::now(); match room { RoomRef::Local(room) => { - room.send_message(&self.services, &self.player_id, body.clone(), created_at.clone()).await; + room.send_message(&self.services, &self.player_id, body.clone(), created_at.clone()).await?; } RoomRef::Remote { node_id } => { let req = SendMessageReq { @@ -651,7 +658,7 @@ impl Player { message: &*body, created_at: &*created_at.to_rfc3339(), }; - self.services.client.send_room_message(*node_id, req).await.unwrap(); + self.services.client.send_room_message(*node_id, req).await?; self.services .broadcast( room_id.clone(), @@ -669,18 +676,19 @@ impl Player { created_at, }; self.broadcast_update(update, connection_id).await; - SendMessageResult::Success(created_at) + Ok(SendMessageResult::Success(created_at)) } #[tracing::instrument(skip(self, new_topic), name = "Player::change_room_topic")] - async fn change_room_topic(&mut self, connection_id: ConnectionId, room_id: RoomId, new_topic: Str) { + async fn change_room_topic(&mut self, connection_id: ConnectionId, room_id: RoomId, new_topic: Str) -> Result<()> { let Some(room) = self.my_rooms.get(&room_id) else { tracing::info!("Room with ID {room_id:?} not found"); - return; + // TODO + return Ok(()); }; match room { RoomRef::Local(room) => { - room.set_topic(&self.services, &self.player_id, new_topic.clone()).await; + room.set_topic(&self.services, &self.player_id, new_topic.clone()).await?; } RoomRef::Remote { node_id } => { let req = SetRoomTopicReq { @@ -688,11 +696,12 @@ impl Player { player_id: self.player_id.as_inner(), topic: &*new_topic, }; - self.services.client.set_room_topic(*node_id, req).await.unwrap(); + self.services.client.set_room_topic(*node_id, req).await?; } } let update = Updates::RoomTopicChanged { room_id, new_topic }; self.broadcast_update(update, connection_id).await; + Ok(()) } #[tracing::instrument(skip(self), name = "Player::get_rooms")] @@ -717,12 +726,9 @@ impl Player { } #[tracing::instrument(skip(self, body), name = "Player::send_dialog_message")] - async fn send_dialog_message(&self, connection_id: ConnectionId, recipient: PlayerId, body: Str) { + async fn send_dialog_message(&self, connection_id: ConnectionId, recipient: PlayerId, body: Str) -> Result<()> { let created_at = Utc::now(); - self.services - .send_dialog_message(self.player_id.clone(), recipient.clone(), body.clone(), &created_at) - .await - .unwrap(); + self.services.send_dialog_message(self.player_id.clone(), recipient.clone(), body.clone(), &created_at).await?; let update = Updates::NewDialogMessage { sender: self.player_id.clone(), receiver: recipient.clone(), @@ -730,14 +736,15 @@ impl Player { created_at, }; self.broadcast_update(update, connection_id).await; + Ok(()) } #[tracing::instrument(skip(self), name = "Player::check_user_existence")] - async fn check_user_existence(&self, recipient: PlayerId) -> GetInfoResult { - if self.services.storage.check_user_existence(recipient.as_inner().as_ref()).await.unwrap() { - GetInfoResult::UserExists + async fn check_user_existence(&self, recipient: PlayerId) -> Result { + if self.services.storage.check_user_existence(recipient.as_inner().as_ref()).await? { + Ok(GetInfoResult::UserExists) } else { - GetInfoResult::UserDoesntExist + Ok(GetInfoResult::UserDoesntExist) } } @@ -766,3 +773,8 @@ pub enum StopReason { ServerShutdown, InternalError, } + +pub enum PlayerConnectionResult { + Success(PlayerConnection), + PlayerNotFound, +} diff --git a/crates/lavina-core/src/room.rs b/crates/lavina-core/src/room.rs index 334adf1..acb0fbf 100644 --- a/crates/lavina-core/src/room.rs +++ b/crates/lavina-core/src/room.rs @@ -79,9 +79,9 @@ impl RoomRegistry { } #[tracing::instrument(skip(self, services), name = "RoomRegistry::get_room")] - pub async fn get_room(&self, services: &Services, room_id: &RoomId) -> Option { + pub async fn get_room(&self, services: &Services, room_id: &RoomId) -> Result> { let mut inner = self.0.write().await; - inner.get_or_load_room(services, room_id).await.unwrap() + inner.get_or_load_room(services, room_id).await } #[tracing::instrument(skip(self), name = "RoomRegistry::get_all_rooms")] @@ -161,8 +161,8 @@ impl RoomHandle { lock.broadcast_update(update, player_id).await; } - pub async fn get_message_history(&self, services: &Services, limit: u32) -> Vec { - return services.storage.get_room_message_history(self.0.read().await.storage_id, limit).await.unwrap(); + pub async fn get_message_history(&self, services: &Services, limit: u32) -> Result> { + services.storage.get_room_message_history(self.0.read().await.storage_id, limit).await } #[tracing::instrument(skip(self), name = "RoomHandle::unsubscribe")] @@ -186,12 +186,19 @@ impl RoomHandle { } #[tracing::instrument(skip(self, services, body, created_at), name = "RoomHandle::send_message")] - pub async fn send_message(&self, services: &Services, player_id: &PlayerId, body: Str, created_at: DateTime) { + pub async fn send_message( + &self, + services: &Services, + player_id: &PlayerId, + body: Str, + created_at: DateTime, + ) -> Result<()> { let mut lock = self.0.write().await; let res = lock.send_message(services, player_id, body, created_at).await; - if let Err(err) = res { - log::warn!("Failed to send message: {err:?}"); + if let Err(err) = &res { + tracing::error!("Failed to send message: {err:?}"); } + res } #[tracing::instrument(skip(self), name = "RoomHandle::get_room_info")] @@ -205,16 +212,17 @@ impl RoomHandle { } #[tracing::instrument(skip(self, services, new_topic), name = "RoomHandle::set_topic")] - pub async fn set_topic(&self, services: &Services, changer_id: &PlayerId, new_topic: Str) { + pub async fn set_topic(&self, services: &Services, changer_id: &PlayerId, new_topic: Str) -> Result<()> { let mut lock = self.0.write().await; let storage_id = lock.storage_id; lock.topic = new_topic.clone(); - services.storage.set_room_topic(storage_id, &new_topic).await.unwrap(); + services.storage.set_room_topic(storage_id, &new_topic).await?; let update = Updates::RoomTopicChanged { room_id: lock.room_id.clone(), new_topic: new_topic.clone(), }; lock.broadcast_update(update, changer_id).await; + Ok(()) } } diff --git a/crates/projection-irc/src/lib.rs b/crates/projection-irc/src/lib.rs index 46ba1a5..1dcea2a 100644 --- a/crates/projection-irc/src/lib.rs +++ b/crates/projection-irc/src/lib.rs @@ -437,7 +437,14 @@ async fn handle_registered_socket<'a>( log::info!("Handling registered user: {user:?}"); let player_id = PlayerId::from(user.nickname.clone())?; - let mut connection = core.connect_to_player(&player_id).await; + let mut connection = match core.connect_to_player(&player_id).await? { + PlayerConnectionResult::Success(connection) => connection, + PlayerConnectionResult::PlayerNotFound => { + tracing::error!("Authorized user unexpectedly not found in the database"); + return Err(anyhow!("no such user")); + } + }; + let text: Str = format!("Welcome to {} Server", &config.server_name).into(); ServerMessage { @@ -577,7 +584,7 @@ async fn handle_update( match update { Updates::RoomJoined { new_member_id, room_id } => { if player_id == &new_member_id { - if let Some(room) = core.get_room(&room_id).await { + if let Some(room) = core.get_room(&room_id).await? { let room_info = room.get_room_info().await; let chan = Chan::Global(room_id.as_inner().clone()); produce_on_join_cmd_messages(&config, &user, &chan, &room_info, writer).await?; @@ -784,7 +791,7 @@ async fn handle_incoming_message( writer.flush().await?; } Recipient::Chan(Chan::Global(chan)) => { - let room = core.get_room(&RoomId::try_from(chan.clone())?).await; + let room = core.get_room(&RoomId::try_from(chan.clone())?).await?; if let Some(room) = room { let room_info = room.get_room_info().await; for member in room_info.members { @@ -870,7 +877,7 @@ async fn handle_incoming_message( // TODO Respond with an error when a local channel is requested Chan::Local(chan) => chan, }; - let room = core.get_room(&RoomId::try_from(channel_name.clone())?).await; + let room = core.get_room(&RoomId::try_from(channel_name.clone())?).await?; // TODO Handle non-existent room if let Some(room) = room { let room_id = &RoomId::try_from(channel_name.clone())?; diff --git a/crates/projection-irc/tests/lib.rs b/crates/projection-irc/tests/lib.rs index 089cb95..8ee8fdd 100644 --- a/crates/projection-irc/tests/lib.rs +++ b/crates/projection-irc/tests/lib.rs @@ -9,7 +9,7 @@ use tokio::net::tcp::{ReadHalf, WriteHalf}; use tokio::net::TcpStream; use lavina_core::clustering::{ClusterConfig, ClusterMetadata}; -use lavina_core::player::{JoinResult, PlayerId, SendMessageResult}; +use lavina_core::player::{JoinResult, PlayerConnectionResult, PlayerId, SendMessageResult}; use lavina_core::repo::{Storage, StorageConfig}; use lavina_core::room::RoomId; use lavina_core::LavinaCore; @@ -109,7 +109,7 @@ impl TestServer { async fn start() -> Result { let _ = tracing_subscriber::fmt::try_init(); let config = ServerConfig { - listen_on: "127.0.0.1:0".parse().unwrap(), + listen_on: "127.0.0.1:0".parse()?, server_name: "testserver".into(), }; let mut metrics = MetricsRegistry::new(); @@ -126,13 +126,13 @@ impl TestServer { }, }; let core = LavinaCore::new(&mut metrics, cluster_config, storage).await?; - let server = launch(config, core.clone(), metrics.clone()).await.unwrap(); + let server = launch(config, core.clone(), metrics.clone()).await?; Ok(TestServer { core, server }) } async fn reboot(self) -> Result { let config = ServerConfig { - listen_on: "127.0.0.1:0".parse().unwrap(), + listen_on: "127.0.0.1:0".parse()?, server_name: "testserver".into(), }; let cluster_config = ClusterConfig { @@ -148,7 +148,7 @@ impl TestServer { let storage = core.shutdown().await; let mut metrics = MetricsRegistry::new(); let core = LavinaCore::new(&mut metrics, cluster_config, storage).await?; - let server = launch(config, core.clone(), metrics.clone()).await.unwrap(); + let server = launch(config, core.clone(), metrics.clone()).await?; Ok(TestServer { core, server }) } @@ -766,16 +766,18 @@ async fn server_time_capability() -> Result<()> { s.expect(":testserver 366 tester #test :End of /NAMES list").await?; server.core.create_player(&PlayerId::from("some_guy")?).await?; - let mut conn = server.core.connect_to_player(&PlayerId::from("some_guy").unwrap()).await; - let res = conn.join_room(RoomId::try_from("test").unwrap()).await?; + let mut conn = match server.core.connect_to_player(&PlayerId::from("some_guy")?).await? { + PlayerConnectionResult::Success(conn) => conn, + PlayerConnectionResult::PlayerNotFound => panic!("user was created, but not returned"), + }; + let res = conn.join_room(RoomId::try_from("test")?).await?; let JoinResult::Success(_) = res else { panic!("Failed to join room"); }; s.expect(":some_guy JOIN #test").await?; - let SendMessageResult::Success(res) = conn.send_message(RoomId::try_from("test").unwrap(), "Hello".into()).await? - else { + let SendMessageResult::Success(res) = conn.send_message(RoomId::try_from("test")?, "Hello".into()).await? else { panic!("Failed to send message"); }; s.expect(&format!( @@ -786,7 +788,7 @@ async fn server_time_capability() -> Result<()> { // formatting check assert_eq!( - DateTime::parse_from_rfc3339(&"2024-01-01T10:00:32.123Z").unwrap().to_rfc3339_opts(SecondsFormat::Millis, true), + DateTime::parse_from_rfc3339(&"2024-01-01T10:00:32.123Z")?.to_rfc3339_opts(SecondsFormat::Millis, true), "2024-01-01T10:00:32.123Z" ); diff --git a/crates/projection-xmpp/src/iq.rs b/crates/projection-xmpp/src/iq.rs index 19b6dbb..24b3af7 100644 --- a/crates/projection-xmpp/src/iq.rs +++ b/crates/projection-xmpp/src/iq.rs @@ -167,7 +167,7 @@ impl<'a> XmppConnection<'a> { resource: None, }) if server.0 == self.hostname_rooms => { let room_id = RoomId::try_from(room_name.0.clone()).unwrap(); - let Some(_) = self.core.get_room(&room_id).await else { + let Some(_) = self.core.get_room(&room_id).await.unwrap() else { // TODO should return item-not-found // example: // diff --git a/crates/projection-xmpp/src/lib.rs b/crates/projection-xmpp/src/lib.rs index 54b63e7..88e71f3 100644 --- a/crates/projection-xmpp/src/lib.rs +++ b/crates/projection-xmpp/src/lib.rs @@ -23,7 +23,7 @@ use tokio_rustls::rustls::{Certificate, PrivateKey}; use tokio_rustls::TlsAcceptor; use lavina_core::auth::Verdict; -use lavina_core::player::{ConnectionMessage, PlayerConnection, PlayerId, StopReason}; +use lavina_core::player::{ConnectionMessage, PlayerConnection, PlayerConnectionResult, PlayerId, StopReason}; use lavina_core::prelude::*; use lavina_core::terminator::Terminator; use lavina_core::LavinaCore; @@ -202,7 +202,13 @@ async fn handle_socket( authenticated = socket_auth(&mut xml_reader, &mut xml_writer, &mut reader_buf, &core, &hostname) => { match authenticated { Ok(authenticated) => { - let mut connection = core.connect_to_player(&authenticated.player_id).await; + let mut connection = match core.connect_to_player(&authenticated.player_id).await? { + PlayerConnectionResult::Success(connection) => connection, + PlayerConnectionResult::PlayerNotFound => { + tracing::error!("Authorized user unexpectedly not found in the database"); + return Err(anyhow!("no such user")); + } + }; socket_final( &mut xml_reader, &mut xml_writer, diff --git a/crates/projection-xmpp/src/presence.rs b/crates/projection-xmpp/src/presence.rs index c2a1cb3..aa9f972 100644 --- a/crates/projection-xmpp/src/presence.rs +++ b/crates/projection-xmpp/src/presence.rs @@ -220,7 +220,7 @@ impl<'a> XmppConnection<'a> { mod tests { use anyhow::Result; - use lavina_core::player::PlayerId; + use lavina_core::player::{PlayerConnectionResult, PlayerId}; use proto_xmpp::bind::{Jid, Name, Resource, Server}; use proto_xmpp::client::Presence; use proto_xmpp::muc::{Affiliation, Role, XUser, XUserItem}; @@ -230,11 +230,11 @@ mod tests { #[tokio::test] async fn test_muc_joining() -> Result<()> { - let server = TestServer::start().await.unwrap(); + let server = TestServer::start().await?; server.core.create_player(&PlayerId::from("tester")?).await?; - let player_id = PlayerId::from("tester").unwrap(); + let player_id = PlayerId::from("tester")?; let user = Authenticated { player_id, xmpp_name: Name("tester".into()), @@ -242,10 +242,13 @@ mod tests { xmpp_muc_name: Resource("tester".into()), }; - let mut player_conn = server.core.connect_to_player(&user.player_id).await; - let mut conn = expect_user_authenticated(&server, &user, &mut player_conn).await.unwrap(); + let mut player_conn = match server.core.connect_to_player(&user.player_id).await? { + PlayerConnectionResult::Success(conn) => conn, + PlayerConnectionResult::PlayerNotFound => panic!("user was created, but not returned"), + }; + let mut conn = expect_user_authenticated(&server, &user, &mut player_conn).await?; - let muc_presence = conn.retrieve_muc_presence(&user.xmpp_name).await.unwrap(); + let muc_presence = conn.retrieve_muc_presence(&user.xmpp_name).await?; let expected = Presence { to: Some(Jid { name: Some(conn.user.xmpp_name.clone()), @@ -274,7 +277,7 @@ mod tests { }; assert_eq!(expected, muc_presence); - server.shutdown().await.unwrap(); + server.shutdown().await; Ok(()) } @@ -282,11 +285,11 @@ mod tests { // i.e. in-memory cache of memberships is cleaned, does not cause any issues. #[tokio::test] async fn test_muc_joining_twice() -> Result<()> { - let server = TestServer::start().await.unwrap(); + let server = TestServer::start().await?; server.core.create_player(&PlayerId::from("tester")?).await?; - let player_id = PlayerId::from("tester").unwrap(); + let player_id = PlayerId::from("tester")?; let user = Authenticated { player_id, xmpp_name: Name("tester".into()), @@ -294,10 +297,13 @@ mod tests { xmpp_muc_name: Resource("tester".into()), }; - let mut player_conn = server.core.connect_to_player(&user.player_id).await; - let mut conn = expect_user_authenticated(&server, &user, &mut player_conn).await.unwrap(); + let mut player_conn = match server.core.connect_to_player(&user.player_id).await? { + PlayerConnectionResult::Success(conn) => conn, + PlayerConnectionResult::PlayerNotFound => panic!("user was created, but not returned"), + }; + let mut conn = expect_user_authenticated(&server, &user, &mut player_conn).await?; - let response = conn.retrieve_muc_presence(&user.xmpp_name).await.unwrap(); + let response = conn.retrieve_muc_presence(&user.xmpp_name).await?; let expected = Presence { to: Some(Jid { name: Some(conn.user.xmpp_name.clone()), @@ -329,13 +335,16 @@ mod tests { drop(conn); let server = server.reboot().await.unwrap(); - let mut player_conn = server.core.connect_to_player(&user.player_id).await; - let mut conn = expect_user_authenticated(&server, &user, &mut player_conn).await.unwrap(); + let mut player_conn = match server.core.connect_to_player(&user.player_id).await? { + PlayerConnectionResult::Success(conn) => conn, + PlayerConnectionResult::PlayerNotFound => panic!("user was created, but not returned"), + }; + let mut conn = expect_user_authenticated(&server, &user, &mut player_conn).await?; - let response = conn.retrieve_muc_presence(&user.xmpp_name).await.unwrap(); + let response = conn.retrieve_muc_presence(&user.xmpp_name).await?; assert_eq!(expected, response); - server.shutdown().await.unwrap(); + server.shutdown().await; Ok(()) } } diff --git a/crates/projection-xmpp/src/testkit.rs b/crates/projection-xmpp/src/testkit.rs index 39a22db..4e571f2 100644 --- a/crates/projection-xmpp/src/testkit.rs +++ b/crates/projection-xmpp/src/testkit.rs @@ -48,10 +48,9 @@ impl TestServer { Ok(TestServer { core }) } - pub async fn shutdown(self) -> anyhow::Result<()> { + pub async fn shutdown(self) { let storage = self.core.shutdown().await; storage.close().await; - Ok(()) } } diff --git a/crates/projection-xmpp/tests/lib.rs b/crates/projection-xmpp/tests/lib.rs index 5b4d212..ad97947 100644 --- a/crates/projection-xmpp/tests/lib.rs +++ b/crates/projection-xmpp/tests/lib.rs @@ -149,9 +149,9 @@ impl TestServer { async fn start() -> Result { let _ = tracing_subscriber::fmt::try_init(); let config = ServerConfig { - listen_on: "127.0.0.1:0".parse().unwrap(), - cert: "tests/certs/xmpp.pem".parse().unwrap(), - key: "tests/certs/xmpp.key".parse().unwrap(), + listen_on: "127.0.0.1:0".parse()?, + cert: "tests/certs/xmpp.pem".parse()?, + key: "tests/certs/xmpp.key".parse()?, hostname: "localhost".into(), }; let mut metrics = MetricsRegistry::new(); diff --git a/crates/proto-xmpp/src/bind.rs b/crates/proto-xmpp/src/bind.rs index 68e00df..9a7d185 100644 --- a/crates/proto-xmpp/src/bind.rs +++ b/crates/proto-xmpp/src/bind.rs @@ -48,7 +48,8 @@ impl Jid { use lazy_static::lazy_static; use regex::Regex; lazy_static! { - static ref RE: Regex = Regex::new(r"^(([a-zA-Z0-9]+)@)?([^@/]+)(/([a-zA-Z0-9\-]+))?$").unwrap(); + static ref RE: Regex = + Regex::new(r"^(([a-zA-Z0-9]+)@)?([^@/]+)(/([a-zA-Z0-9\-]+))?$").expect("this is a correct regex"); } let m = RE.captures(i).ok_or(anyhow!("Incorrectly format jid: {i}"))?; @@ -152,70 +153,74 @@ mod tests { use super::*; #[tokio::test] - async fn parse_message() { + async fn parse_message() -> Result<()> { let input = r#"mobile"#; let mut reader = NsReader::from_reader(input.as_bytes()); let mut buf = vec![]; - let (ns, event) = reader.read_resolved_event_into_async(&mut buf).await.unwrap(); + let (ns, event) = reader.read_resolved_event_into_async(&mut buf).await?; let mut parser = BindRequest::parse().consume(ns, &event); let result = loop { match parser { Continuation::Final(res) => break res, Continuation::Continue(next) => { - let (ns, event) = reader.read_resolved_event_into_async(&mut buf).await.unwrap(); + let (ns, event) = reader.read_resolved_event_into_async(&mut buf).await?; parser = next.consume(ns, &event); } } - } - .unwrap(); - assert_eq!(result, BindRequest(Resource("mobile".into())),) + }?; + assert_eq!(result, BindRequest(Resource("mobile".into()))); + Ok(()) } #[test] - fn jid_parse_full() { + fn jid_parse_full() -> Result<()> { let input = "chelik@server.example/kek"; let expected = Jid { name: Some(Name("chelik".into())), server: Server("server.example".into()), resource: Some(Resource("kek".into())), }; - let res = Jid::from_string(input).unwrap(); + let res = Jid::from_string(input)?; assert_eq!(res, expected); + Ok(()) } #[test] - fn jid_parse_user() { + fn jid_parse_user() -> Result<()> { let input = "chelik@server.example"; let expected = Jid { name: Some(Name("chelik".into())), server: Server("server.example".into()), resource: None, }; - let res = Jid::from_string(input).unwrap(); + let res = Jid::from_string(input)?; assert_eq!(res, expected); + Ok(()) } #[test] - fn jid_parse_server() { + fn jid_parse_server() -> Result<()> { let input = "server.example"; let expected = Jid { name: None, server: Server("server.example".into()), resource: None, }; - let res = Jid::from_string(input).unwrap(); + let res = Jid::from_string(input)?; assert_eq!(res, expected); + Ok(()) } #[test] - fn jid_parse_server_resource() { + fn jid_parse_server_resource() -> Result<()> { let input = "server.example/kek"; let expected = Jid { name: None, server: Server("server.example".into()), resource: Some(Resource("kek".into())), }; - let res = Jid::from_string(input).unwrap(); + let res = Jid::from_string(input)?; assert_eq!(res, expected); + Ok(()) } } diff --git a/crates/proto-xmpp/src/client.rs b/crates/proto-xmpp/src/client.rs index 36d3b55..eae996b 100644 --- a/crates/proto-xmpp/src/client.rs +++ b/crates/proto-xmpp/src/client.rs @@ -709,7 +709,7 @@ mod tests { #[tokio::test] async fn parse_message() { - let input = r#"daabbb"#; + let input = r#"daabbb"#; let result: Message = crate::xml::parse(input).unwrap(); assert_eq!( result, @@ -718,7 +718,7 @@ mod tests { id: Some("aacea".to_string()), to: Some(Jid { name: Some(Name("chelik".into())), - server: Server("xmpp.ru".into()), + server: Server("example.com".into()), resource: None }), r#type: MessageType::Chat, @@ -732,7 +732,7 @@ mod tests { #[tokio::test] async fn parse_message_empty_custom() { - let input = r#"daabbb"#; + let input = r#"daabbb"#; let result: Message = crate::xml::parse(input).unwrap(); assert_eq!( result, @@ -741,7 +741,7 @@ mod tests { id: Some("aacea".to_string()), to: Some(Jid { name: Some(Name("chelik".into())), - server: Server("xmpp.ru".into()), + server: Server("example.com".into()), resource: None }), r#type: MessageType::Chat, diff --git a/crates/proto-xmpp/src/roster.rs b/crates/proto-xmpp/src/roster.rs index 7eb97d1..0733926 100644 --- a/crates/proto-xmpp/src/roster.rs +++ b/crates/proto-xmpp/src/roster.rs @@ -49,11 +49,11 @@ mod tests { use crate::client::{Iq, IqType}; #[test] - fn test_parse() { + fn test_parse() -> Result<()> { let input = r#""#; - let result: Iq = parse(input).unwrap(); + let result: Iq = parse(input)?; assert_eq!( result, Iq { @@ -67,7 +67,8 @@ mod tests { r#type: IqType::Get, body: RosterQuery, } - ) + ); + Ok(()) } #[test] diff --git a/crates/proto-xmpp/src/stream.rs b/crates/proto-xmpp/src/stream.rs index b12891e..035f8a5 100644 --- a/crates/proto-xmpp/src/stream.rs +++ b/crates/proto-xmpp/src/stream.rs @@ -169,30 +169,31 @@ mod test { use super::*; #[tokio::test] - async fn client_stream_start_correct_parse() { - let input = r###""###; + async fn client_stream_start_correct_parse() -> Result<()> { + let input = r###""###; let mut reader = NsReader::from_reader(input.as_bytes()); let mut buf = vec![]; - let res = ClientStreamStart::parse(&mut reader, &mut buf).await.unwrap(); + let res = ClientStreamStart::parse(&mut reader, &mut buf).await?; assert_eq!( res, ClientStreamStart { - to: "xmpp.ru".to_owned(), + to: "example.com".to_owned(), lang: Some("en".to_owned()), version: "1.0".to_owned() } - ) + ); + Ok(()) } #[tokio::test] async fn server_stream_start_write() { let input = ServerStreamStart { - from: "xmpp.ru".to_owned(), + from: "example.com".to_owned(), lang: "en".to_owned(), id: "stream_id".to_owned(), version: "1.0".to_owned(), }; - let expected = r###""###; + let expected = r###""###; let mut output: Vec = vec![]; let mut writer = Writer::new(&mut output); input.write_xml(&mut writer).await.unwrap(); diff --git a/crates/sasl/src/lib.rs b/crates/sasl/src/lib.rs index 75f69c5..9ed954b 100644 --- a/crates/sasl/src/lib.rs +++ b/crates/sasl/src/lib.rs @@ -37,42 +37,45 @@ impl AuthBody { mod test { use super::*; #[test] - fn test_returning_auth_body() { + fn test_returning_auth_body() -> Result<()> { let orig = b"\x00login\x00pass"; let encoded = general_purpose::STANDARD.encode(orig); let expected = AuthBody { login: "login".to_string(), password: "pass".to_string(), }; - let result = AuthBody::from_str(encoded.as_bytes()).unwrap(); + let result = AuthBody::from_str(encoded.as_bytes())?; assert_eq!(expected, result); + Ok(()) } #[test] - fn test_ignoring_first_segment() { + fn test_ignoring_first_segment() -> Result<()> { let orig = b"ignored\x00login\x00pass"; let encoded = general_purpose::STANDARD.encode(orig); let expected = AuthBody { login: "login".to_string(), password: "pass".to_string(), }; - let result = AuthBody::from_str(encoded.as_bytes()).unwrap(); + let result = AuthBody::from_str(encoded.as_bytes())?; assert_eq!(expected, result); + Ok(()) } #[test] - fn test_returning_auth_body_with_empty_strings() { + fn test_returning_auth_body_with_empty_strings() -> Result<()> { let orig = b"\x00\x00"; let encoded = general_purpose::STANDARD.encode(orig); let expected = AuthBody { login: "".to_string(), password: "".to_string(), }; - let result = AuthBody::from_str(encoded.as_bytes()).unwrap(); + let result = AuthBody::from_str(encoded.as_bytes())?; assert_eq!(expected, result); + Ok(()) } #[test] diff --git a/src/http.rs b/src/http.rs index af66758..9bc189d 100644 --- a/src/http.rs +++ b/src/http.rs @@ -13,7 +13,7 @@ use serde::{Deserialize, Serialize}; use tokio::net::TcpListener; use lavina_core::auth::UpdatePasswordResult; -use lavina_core::player::{PlayerId, SendMessageResult}; +use lavina_core::player::{PlayerConnectionResult, PlayerId, SendMessageResult}; use lavina_core::prelude::*; use lavina_core::room::RoomId; use lavina_core::terminator::Terminator; @@ -170,8 +170,13 @@ async fn endpoint_send_room_message( let Ok(player_id) = PlayerId::from(req.author_id) else { return Ok(player_not_found()); }; - let mut player = core.connect_to_player(&player_id).await; - let res = player.send_message(room_id, req.message.into()).await?; + let mut connection = match core.connect_to_player(&player_id).await? { + PlayerConnectionResult::Success(connection) => connection, + PlayerConnectionResult::PlayerNotFound => { + return Ok(player_not_found()); + } + }; + let res = connection.send_message(room_id, req.message.into()).await?; match res { SendMessageResult::NoSuchRoom => Ok(room_not_found()), SendMessageResult::Success(_) => Ok(empty_204_request()), @@ -193,8 +198,13 @@ async fn endpoint_set_room_topic( let Ok(player_id) = PlayerId::from(req.author_id) else { return Ok(player_not_found()); }; - let mut player = core.connect_to_player(&player_id).await; - player.change_topic(room_id, req.topic.into()).await?; + let mut connection = match core.connect_to_player(&player_id).await? { + PlayerConnectionResult::Success(connection) => connection, + PlayerConnectionResult::PlayerNotFound => { + return Ok(player_not_found()); + } + }; + connection.change_topic(room_id, req.topic.into()).await?; Ok(empty_204_request()) }