reduce usage of unwraps (#70)

Reviewed-on: lavina/lavina#70
This commit is contained in:
Nikita Vilunov 2024-06-04 21:54:57 +00:00
parent d0420ec834
commit 4e8eb09184
18 changed files with 235 additions and 171 deletions

View File

@ -79,10 +79,10 @@ impl LavinaCore {
message: Str,
created_at: chrono::DateTime<chrono::Utc>,
) -> Result<Option<()>> {
let Some(room_handle) = self.services.rooms.get_room(&self.services, &room_id).await else {
let Some(room_handle) = self.services.rooms.get_room(&self.services, &room_id).await? else {
return Ok(None);
};
room_handle.send_message(&self.services, &player_id, message, created_at).await;
room_handle.send_message(&self.services, &player_id, message, created_at).await?;
Ok(Some(()))
}
}

View File

@ -137,14 +137,15 @@ mod tests {
use super::*;
#[test]
fn test_dialog_id_new() {
let a = PlayerId::from("a").unwrap();
let b = PlayerId::from("b").unwrap();
fn test_dialog_id_new() -> Result<()> {
let a = PlayerId::from("a")?;
let b = PlayerId::from("b")?;
let id1 = DialogId::new(a.clone(), b.clone());
let id2 = DialogId::new(a.clone(), b.clone());
// Dialog ids are invariant with respect to the order of participants
assert_eq!(id1, id2);
assert_eq!(id1.as_inner(), (&a, &b));
assert_eq!(id2.as_inner(), (&a, &b));
Ok(())
}
}

View File

@ -8,7 +8,7 @@ use prometheus::Registry as MetricsRegistry;
use crate::clustering::broadcast::Broadcasting;
use crate::clustering::{ClusterConfig, ClusterMetadata, LavinaClient};
use crate::dialog::DialogRegistry;
use crate::player::{PlayerConnection, PlayerId, PlayerRegistry};
use crate::player::{PlayerConnectionResult, PlayerId, PlayerRegistry};
use crate::repo::Storage;
use crate::room::{RoomHandle, RoomId, RoomInfo, RoomRegistry};
@ -37,11 +37,11 @@ impl Deref for LavinaCore {
}
impl LavinaCore {
pub async fn connect_to_player(&self, player_id: &PlayerId) -> PlayerConnection {
pub async fn connect_to_player(&self, player_id: &PlayerId) -> Result<PlayerConnectionResult> {
self.services.players.connect_to_player(&self, player_id).await
}
pub async fn get_room(&self, room_id: &RoomId) -> Option<RoomHandle> {
pub async fn get_room(&self, room_id: &RoomId) -> Result<Option<RoomHandle>> {
self.services.rooms.get_room(&self.services, room_id).await
}
@ -97,7 +97,7 @@ impl LavinaCore {
}
pub async fn shutdown(self) -> Storage {
let _ = self.players.shutdown_all().await;
self.players.shutdown_all().await;
let services = match Arc::try_unwrap(self.services) {
Ok(e) => e,
Err(_) => {

View File

@ -64,7 +64,7 @@ impl PlayerConnection {
let (promise, deferred) = oneshot();
let cmd = ClientCommand::SendMessage { room_id, body, promise };
self.player_handle.send(ActorCommand::ClientCommand(cmd, self.connection_id.clone())).await;
Ok(deferred.await?)
deferred.await?
}
/// Handled in [Player::join_room].
@ -73,7 +73,7 @@ impl PlayerConnection {
let (promise, deferred) = oneshot();
let cmd = ClientCommand::JoinRoom { room_id, promise };
self.player_handle.send(ActorCommand::ClientCommand(cmd, self.connection_id.clone())).await;
Ok(deferred.await?)
deferred.await?
}
/// Handled in [Player::change_room_topic].
@ -86,7 +86,7 @@ impl PlayerConnection {
promise,
};
self.player_handle.send(ActorCommand::ClientCommand(cmd, self.connection_id.clone())).await;
Ok(deferred.await?)
deferred.await?
}
/// Handled in [Player::leave_room].
@ -95,7 +95,7 @@ impl PlayerConnection {
let (promise, deferred) = oneshot();
let cmd = ClientCommand::LeaveRoom { room_id, promise };
self.player_handle.send(ActorCommand::ClientCommand(cmd, self.connection_id.clone())).await;
Ok(deferred.await?)
deferred.await?
}
pub async fn terminate(self) {
@ -108,7 +108,7 @@ impl PlayerConnection {
let (promise, deferred) = oneshot();
let cmd = ClientCommand::GetRooms { promise };
self.player_handle.send(ActorCommand::ClientCommand(cmd, self.connection_id.clone())).await;
Ok(deferred.await?)
deferred.await?
}
#[tracing::instrument(skip(self), name = "PlayerConnection::get_room_message_history")]
@ -120,7 +120,7 @@ impl PlayerConnection {
limit,
};
self.player_handle.send(ActorCommand::ClientCommand(cmd, self.connection_id.clone())).await;
Ok(deferred.await?)
deferred.await?
}
/// Handler in [Player::send_dialog_message].
@ -133,7 +133,7 @@ impl PlayerConnection {
promise,
};
self.player_handle.send(ActorCommand::ClientCommand(cmd, self.connection_id.clone())).await;
Ok(deferred.await?)
deferred.await?
}
/// Handler in [Player::check_user_existence].
@ -142,7 +142,7 @@ impl PlayerConnection {
let (promise, deferred) = oneshot();
let cmd = ClientCommand::GetInfo { recipient, promise };
self.player_handle.send(ActorCommand::ClientCommand(cmd, self.connection_id.clone())).await;
Ok(deferred.await?)
deferred.await?
}
}
@ -196,37 +196,37 @@ enum ActorCommand {
pub enum ClientCommand {
JoinRoom {
room_id: RoomId,
promise: Promise<JoinResult>,
promise: Promise<Result<JoinResult>>,
},
LeaveRoom {
room_id: RoomId,
promise: Promise<()>,
promise: Promise<Result<()>>,
},
SendMessage {
room_id: RoomId,
body: Str,
promise: Promise<SendMessageResult>,
promise: Promise<Result<SendMessageResult>>,
},
ChangeTopic {
room_id: RoomId,
new_topic: Str,
promise: Promise<()>,
promise: Promise<Result<()>>,
},
GetRooms {
promise: Promise<Vec<RoomInfo>>,
promise: Promise<Result<Vec<RoomInfo>>>,
},
SendDialogMessage {
recipient: PlayerId,
body: Str,
promise: Promise<()>,
promise: Promise<Result<()>>,
},
GetInfo {
recipient: PlayerId,
promise: Promise<GetInfoResult>,
promise: Promise<Result<GetInfoResult>>,
},
GetRoomHistory {
room_id: RoomId,
promise: Promise<Vec<StoredMessage>>,
promise: Promise<Result<Vec<StoredMessage>>>,
limit: u32,
},
}
@ -317,40 +317,45 @@ impl PlayerRegistry {
}
#[tracing::instrument(skip(self, core), name = "PlayerRegistry::get_or_launch_player")]
pub async fn get_or_launch_player(&self, core: &LavinaCore, id: &PlayerId) -> PlayerHandle {
async fn get_or_launch_player(&self, core: &LavinaCore, id: &PlayerId) -> Result<Option<PlayerHandle>> {
let inner = self.0.read().await;
if let Some((handle, _)) = inner.players.get(id) {
handle.clone()
Ok(Some(handle.clone()))
} else {
drop(inner);
let mut inner = self.0.write().await;
if let Some((handle, _)) = inner.players.get(id) {
handle.clone()
Ok(Some(handle.clone()))
} else {
let (handle, fiber) = Player::launch(id.clone(), core.clone()).await;
let Some((handle, fiber)) = Player::launch(id.clone(), core.clone()).await? else {
return Ok(None);
};
inner.players.insert(id.clone(), (handle.clone(), fiber));
inner.metric_active_players.inc();
handle
Ok(Some(handle))
}
}
}
#[tracing::instrument(skip(self, core), name = "PlayerRegistry::connect_to_player")]
pub async fn connect_to_player(&self, core: &LavinaCore, id: &PlayerId) -> PlayerConnection {
let player_handle = self.get_or_launch_player(core, id).await;
player_handle.subscribe().await
pub async fn connect_to_player(&self, core: &LavinaCore, id: &PlayerId) -> Result<PlayerConnectionResult> {
let Some(player_handle) = self.get_or_launch_player(core, id).await? else {
return Ok(PlayerConnectionResult::PlayerNotFound);
};
Ok(PlayerConnectionResult::Success(player_handle.subscribe().await))
}
pub async fn shutdown_all(&self) -> Result<()> {
pub async fn shutdown_all(&self) {
let mut inner = self.0.write().await;
for (i, (k, j)) in inner.players.drain() {
k.send(ActorCommand::Stop).await;
drop(k);
j.await?;
log::debug!("Player stopped #{i:?}")
for (id, (handle, task)) in inner.players.drain() {
handle.send(ActorCommand::Stop).await;
drop(handle);
match task.await {
Ok(_) => log::debug!("Player stopped #{id:?}"),
Err(e) => log::error!("Player #{id:?} failed to stop: {e}"),
}
}
log::debug!("All players stopped");
Ok(())
}
}
@ -378,11 +383,13 @@ struct Player {
services: LavinaCore,
}
impl Player {
async fn launch(player_id: PlayerId, core: LavinaCore) -> (PlayerHandle, JoinHandle<Player>) {
async fn launch(player_id: PlayerId, core: LavinaCore) -> Result<Option<(PlayerHandle, JoinHandle<Player>)>> {
let (tx, rx) = channel(32);
let handle = PlayerHandle { tx };
let handle_clone = handle.clone();
let storage_id = core.services.storage.retrieve_user_id_by_name(player_id.as_inner()).await.unwrap().unwrap();
let Some(storage_id) = core.services.storage.retrieve_user_id_by_name(player_id.as_inner()).await? else {
return Ok(None);
};
let player = Player {
player_id,
storage_id,
@ -397,7 +404,7 @@ impl Player {
services: core,
};
let fiber = tokio::task::spawn(player.main_loop());
(handle_clone, fiber)
Ok(Some((handle_clone, fiber)))
}
fn room_location(&self, room_id: &RoomId) -> Option<u32> {
@ -417,7 +424,7 @@ impl Player {
self.my_rooms.insert(room_id.clone(), RoomRef::Remote { node_id: remote_node });
self.services.subscribe(self.player_id.clone(), room_id).await;
} else {
let room = self.services.rooms.get_room(&self.services, &room_id).await;
let room = self.services.rooms.get_room(&self.services, &room_id).await.unwrap();
if let Some(room) = room {
room.subscribe(&self.player_id, self.handle.clone()).await;
self.my_rooms.insert(room_id, RoomRef::Local(room));
@ -496,8 +503,8 @@ impl Player {
let _ = promise.send(result);
}
ClientCommand::LeaveRoom { room_id, promise } => {
self.leave_room(connection_id, room_id).await;
let _ = promise.send(());
let result = self.leave_room(connection_id, room_id).await;
let _ = promise.send(result);
}
ClientCommand::SendMessage { room_id, body, promise } => {
let result = self.send_room_message(connection_id, room_id, body).await;
@ -508,20 +515,20 @@ impl Player {
new_topic,
promise,
} => {
self.change_room_topic(connection_id, room_id, new_topic).await;
let _ = promise.send(());
let result = self.change_room_topic(connection_id, room_id, new_topic).await;
let _ = promise.send(result);
}
ClientCommand::GetRooms { promise } => {
let result = self.get_rooms().await;
let _ = promise.send(result);
let _ = promise.send(Ok(result));
}
ClientCommand::SendDialogMessage {
recipient,
body,
promise,
} => {
self.send_dialog_message(connection_id, recipient, body).await;
let _ = promise.send(());
let result = self.send_dialog_message(connection_id, recipient, body).await;
let _ = promise.send(result);
}
ClientCommand::GetInfo { recipient, promise } => {
let result = self.check_user_existence(recipient).await;
@ -539,12 +546,12 @@ impl Player {
}
#[tracing::instrument(skip(self), name = "Player::join_room")]
async fn join_room(&mut self, connection_id: ConnectionId, room_id: RoomId) -> JoinResult {
async fn join_room(&mut self, connection_id: ConnectionId, room_id: RoomId) -> Result<JoinResult> {
if self.banned_from.contains(&room_id) {
return JoinResult::Banned;
return Ok(JoinResult::Banned);
}
if self.my_rooms.contains_key(&room_id) {
return JoinResult::AlreadyJoined;
return Ok(JoinResult::AlreadyJoined);
}
if let Some(remote_node) = self.room_location(&room_id) {
@ -552,16 +559,15 @@ impl Player {
room_id: room_id.as_inner(),
player_id: self.player_id.as_inner(),
};
self.services.client.join_room(remote_node, req).await.unwrap();
let room_storage_id =
self.services.storage.create_or_retrieve_room_id_by_name(room_id.as_inner()).await.unwrap();
self.services.storage.add_room_member(room_storage_id, self.storage_id).await.unwrap();
self.services.client.join_room(remote_node, req).await?;
let room_storage_id = self.services.storage.create_or_retrieve_room_id_by_name(room_id.as_inner()).await?;
self.services.storage.add_room_member(room_storage_id, self.storage_id).await?;
self.my_rooms.insert(room_id.clone(), RoomRef::Remote { node_id: remote_node });
JoinResult::Success(RoomInfo {
Ok(JoinResult::Success(RoomInfo {
id: room_id,
topic: "unknown".into(),
members: vec![],
})
}))
} else {
let room = match self.services.rooms.get_or_create_room(&self.services, room_id.clone()).await {
Ok(room) => room,
@ -579,12 +585,12 @@ impl Player {
new_member_id: self.player_id.clone(),
};
self.broadcast_update(update, connection_id).await;
JoinResult::Success(room_info)
Ok(JoinResult::Success(room_info))
}
}
#[tracing::instrument(skip(self), name = "Player::retrieve_room_history")]
async fn get_room_history(&mut self, room_id: RoomId, limit: u32) -> Vec<StoredMessage> {
async fn get_room_history(&mut self, room_id: RoomId, limit: u32) -> Result<Vec<StoredMessage>> {
let room = self.my_rooms.get(&room_id);
if let Some(room) = room {
match room {
@ -601,7 +607,7 @@ impl Player {
}
#[tracing::instrument(skip(self), name = "Player::leave_room")]
async fn leave_room(&mut self, connection_id: ConnectionId, room_id: RoomId) {
async fn leave_room(&mut self, connection_id: ConnectionId, room_id: RoomId) -> Result<()> {
let room = self.my_rooms.remove(&room_id);
if let Some(room) = room {
match room {
@ -614,10 +620,10 @@ impl Player {
room_id: room_id.as_inner(),
player_id: self.player_id.as_inner(),
};
self.services.client.leave_room(node_id, req).await.unwrap();
self.services.client.leave_room(node_id, req).await?;
let room_storage_id =
self.services.storage.create_or_retrieve_room_id_by_name(room_id.as_inner()).await.unwrap();
self.services.storage.remove_room_member(room_storage_id, self.storage_id).await.unwrap();
self.services.storage.create_or_retrieve_room_id_by_name(room_id.as_inner()).await?;
self.services.storage.remove_room_member(room_storage_id, self.storage_id).await?;
}
}
}
@ -626,6 +632,7 @@ impl Player {
former_member_id: self.player_id.clone(),
};
self.broadcast_update(update, connection_id).await;
Ok(())
}
#[tracing::instrument(skip(self, body), name = "Player::send_room_message")]
@ -634,15 +641,15 @@ impl Player {
connection_id: ConnectionId,
room_id: RoomId,
body: Str,
) -> SendMessageResult {
) -> Result<SendMessageResult> {
let Some(room) = self.my_rooms.get(&room_id) else {
tracing::info!("Room with ID {room_id:?} not found");
return SendMessageResult::NoSuchRoom;
return Ok(SendMessageResult::NoSuchRoom);
};
let created_at = Utc::now();
match room {
RoomRef::Local(room) => {
room.send_message(&self.services, &self.player_id, body.clone(), created_at.clone()).await;
room.send_message(&self.services, &self.player_id, body.clone(), created_at.clone()).await?;
}
RoomRef::Remote { node_id } => {
let req = SendMessageReq {
@ -651,7 +658,7 @@ impl Player {
message: &*body,
created_at: &*created_at.to_rfc3339(),
};
self.services.client.send_room_message(*node_id, req).await.unwrap();
self.services.client.send_room_message(*node_id, req).await?;
self.services
.broadcast(
room_id.clone(),
@ -669,18 +676,19 @@ impl Player {
created_at,
};
self.broadcast_update(update, connection_id).await;
SendMessageResult::Success(created_at)
Ok(SendMessageResult::Success(created_at))
}
#[tracing::instrument(skip(self, new_topic), name = "Player::change_room_topic")]
async fn change_room_topic(&mut self, connection_id: ConnectionId, room_id: RoomId, new_topic: Str) {
async fn change_room_topic(&mut self, connection_id: ConnectionId, room_id: RoomId, new_topic: Str) -> Result<()> {
let Some(room) = self.my_rooms.get(&room_id) else {
tracing::info!("Room with ID {room_id:?} not found");
return;
// TODO
return Ok(());
};
match room {
RoomRef::Local(room) => {
room.set_topic(&self.services, &self.player_id, new_topic.clone()).await;
room.set_topic(&self.services, &self.player_id, new_topic.clone()).await?;
}
RoomRef::Remote { node_id } => {
let req = SetRoomTopicReq {
@ -688,11 +696,12 @@ impl Player {
player_id: self.player_id.as_inner(),
topic: &*new_topic,
};
self.services.client.set_room_topic(*node_id, req).await.unwrap();
self.services.client.set_room_topic(*node_id, req).await?;
}
}
let update = Updates::RoomTopicChanged { room_id, new_topic };
self.broadcast_update(update, connection_id).await;
Ok(())
}
#[tracing::instrument(skip(self), name = "Player::get_rooms")]
@ -717,12 +726,9 @@ impl Player {
}
#[tracing::instrument(skip(self, body), name = "Player::send_dialog_message")]
async fn send_dialog_message(&self, connection_id: ConnectionId, recipient: PlayerId, body: Str) {
async fn send_dialog_message(&self, connection_id: ConnectionId, recipient: PlayerId, body: Str) -> Result<()> {
let created_at = Utc::now();
self.services
.send_dialog_message(self.player_id.clone(), recipient.clone(), body.clone(), &created_at)
.await
.unwrap();
self.services.send_dialog_message(self.player_id.clone(), recipient.clone(), body.clone(), &created_at).await?;
let update = Updates::NewDialogMessage {
sender: self.player_id.clone(),
receiver: recipient.clone(),
@ -730,14 +736,15 @@ impl Player {
created_at,
};
self.broadcast_update(update, connection_id).await;
Ok(())
}
#[tracing::instrument(skip(self), name = "Player::check_user_existence")]
async fn check_user_existence(&self, recipient: PlayerId) -> GetInfoResult {
if self.services.storage.check_user_existence(recipient.as_inner().as_ref()).await.unwrap() {
GetInfoResult::UserExists
async fn check_user_existence(&self, recipient: PlayerId) -> Result<GetInfoResult> {
if self.services.storage.check_user_existence(recipient.as_inner().as_ref()).await? {
Ok(GetInfoResult::UserExists)
} else {
GetInfoResult::UserDoesntExist
Ok(GetInfoResult::UserDoesntExist)
}
}
@ -766,3 +773,8 @@ pub enum StopReason {
ServerShutdown,
InternalError,
}
pub enum PlayerConnectionResult {
Success(PlayerConnection),
PlayerNotFound,
}

View File

@ -79,9 +79,9 @@ impl RoomRegistry {
}
#[tracing::instrument(skip(self, services), name = "RoomRegistry::get_room")]
pub async fn get_room(&self, services: &Services, room_id: &RoomId) -> Option<RoomHandle> {
pub async fn get_room(&self, services: &Services, room_id: &RoomId) -> Result<Option<RoomHandle>> {
let mut inner = self.0.write().await;
inner.get_or_load_room(services, room_id).await.unwrap()
inner.get_or_load_room(services, room_id).await
}
#[tracing::instrument(skip(self), name = "RoomRegistry::get_all_rooms")]
@ -161,8 +161,8 @@ impl RoomHandle {
lock.broadcast_update(update, player_id).await;
}
pub async fn get_message_history(&self, services: &Services, limit: u32) -> Vec<StoredMessage> {
return services.storage.get_room_message_history(self.0.read().await.storage_id, limit).await.unwrap();
pub async fn get_message_history(&self, services: &Services, limit: u32) -> Result<Vec<StoredMessage>> {
services.storage.get_room_message_history(self.0.read().await.storage_id, limit).await
}
#[tracing::instrument(skip(self), name = "RoomHandle::unsubscribe")]
@ -186,12 +186,19 @@ impl RoomHandle {
}
#[tracing::instrument(skip(self, services, body, created_at), name = "RoomHandle::send_message")]
pub async fn send_message(&self, services: &Services, player_id: &PlayerId, body: Str, created_at: DateTime<Utc>) {
pub async fn send_message(
&self,
services: &Services,
player_id: &PlayerId,
body: Str,
created_at: DateTime<Utc>,
) -> Result<()> {
let mut lock = self.0.write().await;
let res = lock.send_message(services, player_id, body, created_at).await;
if let Err(err) = res {
log::warn!("Failed to send message: {err:?}");
if let Err(err) = &res {
tracing::error!("Failed to send message: {err:?}");
}
res
}
#[tracing::instrument(skip(self), name = "RoomHandle::get_room_info")]
@ -205,16 +212,17 @@ impl RoomHandle {
}
#[tracing::instrument(skip(self, services, new_topic), name = "RoomHandle::set_topic")]
pub async fn set_topic(&self, services: &Services, changer_id: &PlayerId, new_topic: Str) {
pub async fn set_topic(&self, services: &Services, changer_id: &PlayerId, new_topic: Str) -> Result<()> {
let mut lock = self.0.write().await;
let storage_id = lock.storage_id;
lock.topic = new_topic.clone();
services.storage.set_room_topic(storage_id, &new_topic).await.unwrap();
services.storage.set_room_topic(storage_id, &new_topic).await?;
let update = Updates::RoomTopicChanged {
room_id: lock.room_id.clone(),
new_topic: new_topic.clone(),
};
lock.broadcast_update(update, changer_id).await;
Ok(())
}
}

View File

@ -437,7 +437,14 @@ async fn handle_registered_socket<'a>(
log::info!("Handling registered user: {user:?}");
let player_id = PlayerId::from(user.nickname.clone())?;
let mut connection = core.connect_to_player(&player_id).await;
let mut connection = match core.connect_to_player(&player_id).await? {
PlayerConnectionResult::Success(connection) => connection,
PlayerConnectionResult::PlayerNotFound => {
tracing::error!("Authorized user unexpectedly not found in the database");
return Err(anyhow!("no such user"));
}
};
let text: Str = format!("Welcome to {} Server", &config.server_name).into();
ServerMessage {
@ -577,7 +584,7 @@ async fn handle_update(
match update {
Updates::RoomJoined { new_member_id, room_id } => {
if player_id == &new_member_id {
if let Some(room) = core.get_room(&room_id).await {
if let Some(room) = core.get_room(&room_id).await? {
let room_info = room.get_room_info().await;
let chan = Chan::Global(room_id.as_inner().clone());
produce_on_join_cmd_messages(&config, &user, &chan, &room_info, writer).await?;
@ -784,7 +791,7 @@ async fn handle_incoming_message(
writer.flush().await?;
}
Recipient::Chan(Chan::Global(chan)) => {
let room = core.get_room(&RoomId::try_from(chan.clone())?).await;
let room = core.get_room(&RoomId::try_from(chan.clone())?).await?;
if let Some(room) = room {
let room_info = room.get_room_info().await;
for member in room_info.members {
@ -870,7 +877,7 @@ async fn handle_incoming_message(
// TODO Respond with an error when a local channel is requested
Chan::Local(chan) => chan,
};
let room = core.get_room(&RoomId::try_from(channel_name.clone())?).await;
let room = core.get_room(&RoomId::try_from(channel_name.clone())?).await?;
// TODO Handle non-existent room
if let Some(room) = room {
let room_id = &RoomId::try_from(channel_name.clone())?;

View File

@ -9,7 +9,7 @@ use tokio::net::tcp::{ReadHalf, WriteHalf};
use tokio::net::TcpStream;
use lavina_core::clustering::{ClusterConfig, ClusterMetadata};
use lavina_core::player::{JoinResult, PlayerId, SendMessageResult};
use lavina_core::player::{JoinResult, PlayerConnectionResult, PlayerId, SendMessageResult};
use lavina_core::repo::{Storage, StorageConfig};
use lavina_core::room::RoomId;
use lavina_core::LavinaCore;
@ -109,7 +109,7 @@ impl TestServer {
async fn start() -> Result<TestServer> {
let _ = tracing_subscriber::fmt::try_init();
let config = ServerConfig {
listen_on: "127.0.0.1:0".parse().unwrap(),
listen_on: "127.0.0.1:0".parse()?,
server_name: "testserver".into(),
};
let mut metrics = MetricsRegistry::new();
@ -126,13 +126,13 @@ impl TestServer {
},
};
let core = LavinaCore::new(&mut metrics, cluster_config, storage).await?;
let server = launch(config, core.clone(), metrics.clone()).await.unwrap();
let server = launch(config, core.clone(), metrics.clone()).await?;
Ok(TestServer { core, server })
}
async fn reboot(self) -> Result<TestServer> {
let config = ServerConfig {
listen_on: "127.0.0.1:0".parse().unwrap(),
listen_on: "127.0.0.1:0".parse()?,
server_name: "testserver".into(),
};
let cluster_config = ClusterConfig {
@ -148,7 +148,7 @@ impl TestServer {
let storage = core.shutdown().await;
let mut metrics = MetricsRegistry::new();
let core = LavinaCore::new(&mut metrics, cluster_config, storage).await?;
let server = launch(config, core.clone(), metrics.clone()).await.unwrap();
let server = launch(config, core.clone(), metrics.clone()).await?;
Ok(TestServer { core, server })
}
@ -766,16 +766,18 @@ async fn server_time_capability() -> Result<()> {
s.expect(":testserver 366 tester #test :End of /NAMES list").await?;
server.core.create_player(&PlayerId::from("some_guy")?).await?;
let mut conn = server.core.connect_to_player(&PlayerId::from("some_guy").unwrap()).await;
let res = conn.join_room(RoomId::try_from("test").unwrap()).await?;
let mut conn = match server.core.connect_to_player(&PlayerId::from("some_guy")?).await? {
PlayerConnectionResult::Success(conn) => conn,
PlayerConnectionResult::PlayerNotFound => panic!("user was created, but not returned"),
};
let res = conn.join_room(RoomId::try_from("test")?).await?;
let JoinResult::Success(_) = res else {
panic!("Failed to join room");
};
s.expect(":some_guy JOIN #test").await?;
let SendMessageResult::Success(res) = conn.send_message(RoomId::try_from("test").unwrap(), "Hello".into()).await?
else {
let SendMessageResult::Success(res) = conn.send_message(RoomId::try_from("test")?, "Hello".into()).await? else {
panic!("Failed to send message");
};
s.expect(&format!(
@ -786,7 +788,7 @@ async fn server_time_capability() -> Result<()> {
// formatting check
assert_eq!(
DateTime::parse_from_rfc3339(&"2024-01-01T10:00:32.123Z").unwrap().to_rfc3339_opts(SecondsFormat::Millis, true),
DateTime::parse_from_rfc3339(&"2024-01-01T10:00:32.123Z")?.to_rfc3339_opts(SecondsFormat::Millis, true),
"2024-01-01T10:00:32.123Z"
);

View File

@ -167,7 +167,7 @@ impl<'a> XmppConnection<'a> {
resource: None,
}) if server.0 == self.hostname_rooms => {
let room_id = RoomId::try_from(room_name.0.clone()).unwrap();
let Some(_) = self.core.get_room(&room_id).await else {
let Some(_) = self.core.get_room(&room_id).await.unwrap() else {
// TODO should return item-not-found
// example:
// <error type="cancel">

View File

@ -23,7 +23,7 @@ use tokio_rustls::rustls::{Certificate, PrivateKey};
use tokio_rustls::TlsAcceptor;
use lavina_core::auth::Verdict;
use lavina_core::player::{ConnectionMessage, PlayerConnection, PlayerId, StopReason};
use lavina_core::player::{ConnectionMessage, PlayerConnection, PlayerConnectionResult, PlayerId, StopReason};
use lavina_core::prelude::*;
use lavina_core::terminator::Terminator;
use lavina_core::LavinaCore;
@ -202,7 +202,13 @@ async fn handle_socket(
authenticated = socket_auth(&mut xml_reader, &mut xml_writer, &mut reader_buf, &core, &hostname) => {
match authenticated {
Ok(authenticated) => {
let mut connection = core.connect_to_player(&authenticated.player_id).await;
let mut connection = match core.connect_to_player(&authenticated.player_id).await? {
PlayerConnectionResult::Success(connection) => connection,
PlayerConnectionResult::PlayerNotFound => {
tracing::error!("Authorized user unexpectedly not found in the database");
return Err(anyhow!("no such user"));
}
};
socket_final(
&mut xml_reader,
&mut xml_writer,

View File

@ -220,7 +220,7 @@ impl<'a> XmppConnection<'a> {
mod tests {
use anyhow::Result;
use lavina_core::player::PlayerId;
use lavina_core::player::{PlayerConnectionResult, PlayerId};
use proto_xmpp::bind::{Jid, Name, Resource, Server};
use proto_xmpp::client::Presence;
use proto_xmpp::muc::{Affiliation, Role, XUser, XUserItem};
@ -230,11 +230,11 @@ mod tests {
#[tokio::test]
async fn test_muc_joining() -> Result<()> {
let server = TestServer::start().await.unwrap();
let server = TestServer::start().await?;
server.core.create_player(&PlayerId::from("tester")?).await?;
let player_id = PlayerId::from("tester").unwrap();
let player_id = PlayerId::from("tester")?;
let user = Authenticated {
player_id,
xmpp_name: Name("tester".into()),
@ -242,10 +242,13 @@ mod tests {
xmpp_muc_name: Resource("tester".into()),
};
let mut player_conn = server.core.connect_to_player(&user.player_id).await;
let mut conn = expect_user_authenticated(&server, &user, &mut player_conn).await.unwrap();
let mut player_conn = match server.core.connect_to_player(&user.player_id).await? {
PlayerConnectionResult::Success(conn) => conn,
PlayerConnectionResult::PlayerNotFound => panic!("user was created, but not returned"),
};
let mut conn = expect_user_authenticated(&server, &user, &mut player_conn).await?;
let muc_presence = conn.retrieve_muc_presence(&user.xmpp_name).await.unwrap();
let muc_presence = conn.retrieve_muc_presence(&user.xmpp_name).await?;
let expected = Presence {
to: Some(Jid {
name: Some(conn.user.xmpp_name.clone()),
@ -274,7 +277,7 @@ mod tests {
};
assert_eq!(expected, muc_presence);
server.shutdown().await.unwrap();
server.shutdown().await;
Ok(())
}
@ -282,11 +285,11 @@ mod tests {
// i.e. in-memory cache of memberships is cleaned, does not cause any issues.
#[tokio::test]
async fn test_muc_joining_twice() -> Result<()> {
let server = TestServer::start().await.unwrap();
let server = TestServer::start().await?;
server.core.create_player(&PlayerId::from("tester")?).await?;
let player_id = PlayerId::from("tester").unwrap();
let player_id = PlayerId::from("tester")?;
let user = Authenticated {
player_id,
xmpp_name: Name("tester".into()),
@ -294,10 +297,13 @@ mod tests {
xmpp_muc_name: Resource("tester".into()),
};
let mut player_conn = server.core.connect_to_player(&user.player_id).await;
let mut conn = expect_user_authenticated(&server, &user, &mut player_conn).await.unwrap();
let mut player_conn = match server.core.connect_to_player(&user.player_id).await? {
PlayerConnectionResult::Success(conn) => conn,
PlayerConnectionResult::PlayerNotFound => panic!("user was created, but not returned"),
};
let mut conn = expect_user_authenticated(&server, &user, &mut player_conn).await?;
let response = conn.retrieve_muc_presence(&user.xmpp_name).await.unwrap();
let response = conn.retrieve_muc_presence(&user.xmpp_name).await?;
let expected = Presence {
to: Some(Jid {
name: Some(conn.user.xmpp_name.clone()),
@ -329,13 +335,16 @@ mod tests {
drop(conn);
let server = server.reboot().await.unwrap();
let mut player_conn = server.core.connect_to_player(&user.player_id).await;
let mut conn = expect_user_authenticated(&server, &user, &mut player_conn).await.unwrap();
let mut player_conn = match server.core.connect_to_player(&user.player_id).await? {
PlayerConnectionResult::Success(conn) => conn,
PlayerConnectionResult::PlayerNotFound => panic!("user was created, but not returned"),
};
let mut conn = expect_user_authenticated(&server, &user, &mut player_conn).await?;
let response = conn.retrieve_muc_presence(&user.xmpp_name).await.unwrap();
let response = conn.retrieve_muc_presence(&user.xmpp_name).await?;
assert_eq!(expected, response);
server.shutdown().await.unwrap();
server.shutdown().await;
Ok(())
}
}

View File

@ -48,10 +48,9 @@ impl TestServer {
Ok(TestServer { core })
}
pub async fn shutdown(self) -> anyhow::Result<()> {
pub async fn shutdown(self) {
let storage = self.core.shutdown().await;
storage.close().await;
Ok(())
}
}

View File

@ -149,9 +149,9 @@ impl TestServer {
async fn start() -> Result<TestServer> {
let _ = tracing_subscriber::fmt::try_init();
let config = ServerConfig {
listen_on: "127.0.0.1:0".parse().unwrap(),
cert: "tests/certs/xmpp.pem".parse().unwrap(),
key: "tests/certs/xmpp.key".parse().unwrap(),
listen_on: "127.0.0.1:0".parse()?,
cert: "tests/certs/xmpp.pem".parse()?,
key: "tests/certs/xmpp.key".parse()?,
hostname: "localhost".into(),
};
let mut metrics = MetricsRegistry::new();

View File

@ -48,7 +48,8 @@ impl Jid {
use lazy_static::lazy_static;
use regex::Regex;
lazy_static! {
static ref RE: Regex = Regex::new(r"^(([a-zA-Z0-9]+)@)?([^@/]+)(/([a-zA-Z0-9\-]+))?$").unwrap();
static ref RE: Regex =
Regex::new(r"^(([a-zA-Z0-9]+)@)?([^@/]+)(/([a-zA-Z0-9\-]+))?$").expect("this is a correct regex");
}
let m = RE.captures(i).ok_or(anyhow!("Incorrectly format jid: {i}"))?;
@ -152,70 +153,74 @@ mod tests {
use super::*;
#[tokio::test]
async fn parse_message() {
async fn parse_message() -> Result<()> {
let input = r#"<bind xmlns="urn:ietf:params:xml:ns:xmpp-bind"><resource>mobile</resource></bind>"#;
let mut reader = NsReader::from_reader(input.as_bytes());
let mut buf = vec![];
let (ns, event) = reader.read_resolved_event_into_async(&mut buf).await.unwrap();
let (ns, event) = reader.read_resolved_event_into_async(&mut buf).await?;
let mut parser = BindRequest::parse().consume(ns, &event);
let result = loop {
match parser {
Continuation::Final(res) => break res,
Continuation::Continue(next) => {
let (ns, event) = reader.read_resolved_event_into_async(&mut buf).await.unwrap();
let (ns, event) = reader.read_resolved_event_into_async(&mut buf).await?;
parser = next.consume(ns, &event);
}
}
}
.unwrap();
assert_eq!(result, BindRequest(Resource("mobile".into())),)
}?;
assert_eq!(result, BindRequest(Resource("mobile".into())));
Ok(())
}
#[test]
fn jid_parse_full() {
fn jid_parse_full() -> Result<()> {
let input = "chelik@server.example/kek";
let expected = Jid {
name: Some(Name("chelik".into())),
server: Server("server.example".into()),
resource: Some(Resource("kek".into())),
};
let res = Jid::from_string(input).unwrap();
let res = Jid::from_string(input)?;
assert_eq!(res, expected);
Ok(())
}
#[test]
fn jid_parse_user() {
fn jid_parse_user() -> Result<()> {
let input = "chelik@server.example";
let expected = Jid {
name: Some(Name("chelik".into())),
server: Server("server.example".into()),
resource: None,
};
let res = Jid::from_string(input).unwrap();
let res = Jid::from_string(input)?;
assert_eq!(res, expected);
Ok(())
}
#[test]
fn jid_parse_server() {
fn jid_parse_server() -> Result<()> {
let input = "server.example";
let expected = Jid {
name: None,
server: Server("server.example".into()),
resource: None,
};
let res = Jid::from_string(input).unwrap();
let res = Jid::from_string(input)?;
assert_eq!(res, expected);
Ok(())
}
#[test]
fn jid_parse_server_resource() {
fn jid_parse_server_resource() -> Result<()> {
let input = "server.example/kek";
let expected = Jid {
name: None,
server: Server("server.example".into()),
resource: Some(Resource("kek".into())),
};
let res = Jid::from_string(input).unwrap();
let res = Jid::from_string(input)?;
assert_eq!(res, expected);
Ok(())
}
}

View File

@ -709,7 +709,7 @@ mod tests {
#[tokio::test]
async fn parse_message() {
let input = r#"<message id="aacea" type="chat" to="chelik@xmpp.ru"><subject>daa</subject><body>bbb</body><unknown-stuff></unknown-stuff></message>"#;
let input = r#"<message id="aacea" type="chat" to="chelik@example.com"><subject>daa</subject><body>bbb</body><unknown-stuff></unknown-stuff></message>"#;
let result: Message<Ignore> = crate::xml::parse(input).unwrap();
assert_eq!(
result,
@ -718,7 +718,7 @@ mod tests {
id: Some("aacea".to_string()),
to: Some(Jid {
name: Some(Name("chelik".into())),
server: Server("xmpp.ru".into()),
server: Server("example.com".into()),
resource: None
}),
r#type: MessageType::Chat,
@ -732,7 +732,7 @@ mod tests {
#[tokio::test]
async fn parse_message_empty_custom() {
let input = r#"<message id="aacea" type="chat" to="chelik@xmpp.ru"><subject>daa</subject><body>bbb</body><unknown-stuff/></message>"#;
let input = r#"<message id="aacea" type="chat" to="chelik@example.com"><subject>daa</subject><body>bbb</body><unknown-stuff/></message>"#;
let result: Message<Ignore> = crate::xml::parse(input).unwrap();
assert_eq!(
result,
@ -741,7 +741,7 @@ mod tests {
id: Some("aacea".to_string()),
to: Some(Jid {
name: Some(Name("chelik".into())),
server: Server("xmpp.ru".into()),
server: Server("example.com".into()),
resource: None
}),
r#type: MessageType::Chat,

View File

@ -49,11 +49,11 @@ mod tests {
use crate::client::{Iq, IqType};
#[test]
fn test_parse() {
fn test_parse() -> Result<()> {
let input =
r#"<iq from='juliet@example.com/balcony' id='bv1bs71f' type='get'><query xmlns='jabber:iq:roster'/></iq>"#;
let result: Iq<RosterQuery> = parse(input).unwrap();
let result: Iq<RosterQuery> = parse(input)?;
assert_eq!(
result,
Iq {
@ -67,7 +67,8 @@ mod tests {
r#type: IqType::Get,
body: RosterQuery,
}
)
);
Ok(())
}
#[test]

View File

@ -169,30 +169,31 @@ mod test {
use super::*;
#[tokio::test]
async fn client_stream_start_correct_parse() {
let input = r###"<stream:stream xmlns:stream="http://etherx.jabber.org/streams" to="xmpp.ru" version="1.0" xmlns="jabber:client" xml:lang="en" xmlns:xml="http://www.w3.org/XML/1998/namespace">"###;
async fn client_stream_start_correct_parse() -> Result<()> {
let input = r###"<stream:stream xmlns:stream="http://etherx.jabber.org/streams" to="example.com" version="1.0" xmlns="jabber:client" xml:lang="en" xmlns:xml="http://www.w3.org/XML/1998/namespace">"###;
let mut reader = NsReader::from_reader(input.as_bytes());
let mut buf = vec![];
let res = ClientStreamStart::parse(&mut reader, &mut buf).await.unwrap();
let res = ClientStreamStart::parse(&mut reader, &mut buf).await?;
assert_eq!(
res,
ClientStreamStart {
to: "xmpp.ru".to_owned(),
to: "example.com".to_owned(),
lang: Some("en".to_owned()),
version: "1.0".to_owned()
}
)
);
Ok(())
}
#[tokio::test]
async fn server_stream_start_write() {
let input = ServerStreamStart {
from: "xmpp.ru".to_owned(),
from: "example.com".to_owned(),
lang: "en".to_owned(),
id: "stream_id".to_owned(),
version: "1.0".to_owned(),
};
let expected = r###"<stream:stream from="xmpp.ru" version="1.0" xmlns="jabber:client" xmlns:stream="http://etherx.jabber.org/streams" xml:lang="en" id="stream_id">"###;
let expected = r###"<stream:stream from="example.com" version="1.0" xmlns="jabber:client" xmlns:stream="http://etherx.jabber.org/streams" xml:lang="en" id="stream_id">"###;
let mut output: Vec<u8> = vec![];
let mut writer = Writer::new(&mut output);
input.write_xml(&mut writer).await.unwrap();

View File

@ -37,42 +37,45 @@ impl AuthBody {
mod test {
use super::*;
#[test]
fn test_returning_auth_body() {
fn test_returning_auth_body() -> Result<()> {
let orig = b"\x00login\x00pass";
let encoded = general_purpose::STANDARD.encode(orig);
let expected = AuthBody {
login: "login".to_string(),
password: "pass".to_string(),
};
let result = AuthBody::from_str(encoded.as_bytes()).unwrap();
let result = AuthBody::from_str(encoded.as_bytes())?;
assert_eq!(expected, result);
Ok(())
}
#[test]
fn test_ignoring_first_segment() {
fn test_ignoring_first_segment() -> Result<()> {
let orig = b"ignored\x00login\x00pass";
let encoded = general_purpose::STANDARD.encode(orig);
let expected = AuthBody {
login: "login".to_string(),
password: "pass".to_string(),
};
let result = AuthBody::from_str(encoded.as_bytes()).unwrap();
let result = AuthBody::from_str(encoded.as_bytes())?;
assert_eq!(expected, result);
Ok(())
}
#[test]
fn test_returning_auth_body_with_empty_strings() {
fn test_returning_auth_body_with_empty_strings() -> Result<()> {
let orig = b"\x00\x00";
let encoded = general_purpose::STANDARD.encode(orig);
let expected = AuthBody {
login: "".to_string(),
password: "".to_string(),
};
let result = AuthBody::from_str(encoded.as_bytes()).unwrap();
let result = AuthBody::from_str(encoded.as_bytes())?;
assert_eq!(expected, result);
Ok(())
}
#[test]

View File

@ -13,7 +13,7 @@ use serde::{Deserialize, Serialize};
use tokio::net::TcpListener;
use lavina_core::auth::UpdatePasswordResult;
use lavina_core::player::{PlayerId, SendMessageResult};
use lavina_core::player::{PlayerConnectionResult, PlayerId, SendMessageResult};
use lavina_core::prelude::*;
use lavina_core::room::RoomId;
use lavina_core::terminator::Terminator;
@ -170,8 +170,13 @@ async fn endpoint_send_room_message(
let Ok(player_id) = PlayerId::from(req.author_id) else {
return Ok(player_not_found());
};
let mut player = core.connect_to_player(&player_id).await;
let res = player.send_message(room_id, req.message.into()).await?;
let mut connection = match core.connect_to_player(&player_id).await? {
PlayerConnectionResult::Success(connection) => connection,
PlayerConnectionResult::PlayerNotFound => {
return Ok(player_not_found());
}
};
let res = connection.send_message(room_id, req.message.into()).await?;
match res {
SendMessageResult::NoSuchRoom => Ok(room_not_found()),
SendMessageResult::Success(_) => Ok(empty_204_request()),
@ -193,8 +198,13 @@ async fn endpoint_set_room_topic(
let Ok(player_id) = PlayerId::from(req.author_id) else {
return Ok(player_not_found());
};
let mut player = core.connect_to_player(&player_id).await;
player.change_topic(room_id, req.topic.into()).await?;
let mut connection = match core.connect_to_player(&player_id).await? {
PlayerConnectionResult::Success(connection) => connection,
PlayerConnectionResult::PlayerNotFound => {
return Ok(player_not_found());
}
};
connection.change_topic(room_id, req.topic.into()).await?;
Ok(empty_204_request())
}