diff --git a/crates/lavina-core/src/player.rs b/crates/lavina-core/src/player.rs index eec22f8..0486808 100644 --- a/crates/lavina-core/src/player.rs +++ b/crates/lavina-core/src/player.rs @@ -7,16 +7,16 @@ //! //! A player actor is a serial handler of commands from a single player. It is preferable to run all per-player validations in the player actor, //! so that they don't overload the room actor. -use std::{ - collections::{HashMap, HashSet}, - sync::{Arc, RwLock}, -}; +use std::collections::{HashMap, HashSet}; +use std::sync::Arc; use prometheus::{IntGauge, Registry as MetricsRegistry}; use serde::Serialize; use tokio::sync::mpsc::{channel, Receiver, Sender}; +use tokio::sync::RwLock; use crate::prelude::*; +use crate::repo::Storage; use crate::room::{RoomHandle, RoomId, RoomInfo, RoomRegistry}; use crate::table::{AnonTable, Key as AnonKey}; @@ -208,36 +208,41 @@ pub enum Updates { #[derive(Clone)] pub struct PlayerRegistry(Arc>); impl PlayerRegistry { - pub fn empty(room_registry: RoomRegistry, metrics: &mut MetricsRegistry) -> Result { + pub fn empty( + room_registry: RoomRegistry, + storage: Storage, + metrics: &mut MetricsRegistry, + ) -> Result { let metric_active_players = IntGauge::new("chat_players_active", "Number of alive player actors")?; metrics.register(Box::new(metric_active_players.clone()))?; let inner = PlayerRegistryInner { room_registry, + storage, players: HashMap::new(), metric_active_players, }; Ok(PlayerRegistry(Arc::new(RwLock::new(inner)))) } - pub async fn get_or_create_player(&mut self, id: PlayerId) -> PlayerHandle { - let mut inner = self.0.write().unwrap(); - if let Some((handle, _)) = inner.players.get(&id) { + pub async fn get_or_launch_player(&mut self, id: &PlayerId) -> PlayerHandle { + let mut inner = self.0.write().await; + if let Some((handle, _)) = inner.players.get(id) { handle.clone() } else { - let (handle, fiber) = Player::launch(id.clone(), inner.room_registry.clone()); - inner.players.insert(id, (handle.clone(), fiber)); + let (handle, fiber) = Player::launch(id.clone(), inner.room_registry.clone(), inner.storage.clone()).await; + inner.players.insert(id.clone(), (handle.clone(), fiber)); inner.metric_active_players.inc(); handle } } - pub async fn connect_to_player(&mut self, id: PlayerId) -> PlayerConnection { - let player_handle = self.get_or_create_player(id).await; + pub async fn connect_to_player(&mut self, id: &PlayerId) -> PlayerConnection { + let player_handle = self.get_or_launch_player(id).await; player_handle.subscribe().await } pub async fn shutdown_all(&mut self) -> Result<()> { - let mut inner = self.0.write().unwrap(); + let mut inner = self.0.write().await; for (i, (k, j)) in inner.players.drain() { k.send(ActorCommand::Stop).await; drop(k); @@ -252,6 +257,8 @@ impl PlayerRegistry { /// The player registry state representation. struct PlayerRegistryInner { room_registry: RoomRegistry, + storage: Storage, + /// Active player actors. players: HashMap)>, metric_active_players: IntGauge, } @@ -259,32 +266,49 @@ struct PlayerRegistryInner { /// Player actor inner state representation. struct Player { player_id: PlayerId, + storage_id: u32, connections: AnonTable>, my_rooms: HashMap, banned_from: HashSet, rx: Receiver, handle: PlayerHandle, rooms: RoomRegistry, + storage: Storage, } impl Player { - fn launch(player_id: PlayerId, rooms: RoomRegistry) -> (PlayerHandle, JoinHandle) { + async fn launch(player_id: PlayerId, rooms: RoomRegistry, storage: Storage) -> (PlayerHandle, JoinHandle) { let (tx, rx) = channel(32); let handle = PlayerHandle { tx }; let handle_clone = handle.clone(); + let storage_id = storage.retrieve_user_id_by_name(player_id.as_inner()).await.unwrap().unwrap(); let player = Player { player_id, + storage_id, + // connections are empty when the actor is just started connections: AnonTable::new(), + // room handlers will be loaded later in the started task my_rooms: HashMap::new(), - banned_from: HashSet::from([RoomId::from("Empty").unwrap()]), + // TODO implement and load bans + banned_from: HashSet::new(), rx, handle, rooms, + storage, }; let fiber = tokio::task::spawn(player.main_loop()); (handle_clone, fiber) } async fn main_loop(mut self) -> Self { + let rooms = self.storage.get_rooms_of_a_user(self.storage_id).await.unwrap(); + for room_id in rooms { + let room = self.rooms.get_room(&room_id).await; + if let Some(room) = room { + self.my_rooms.insert(room_id, room); + } else { + tracing::error!("Room #{room_id:?} not found"); + } + } while let Some(cmd) = self.rx.recv().await { match cmd { ActorCommand::AddConnection { sender, promise } => { @@ -372,7 +396,8 @@ impl Player { todo!(); } }; - room.subscribe(self.player_id.clone(), self.handle.clone()).await; + room.add_member(&self.player_id, self.storage_id).await; + room.subscribe(&self.player_id, self.handle.clone()).await; self.my_rooms.insert(room_id.clone(), room.clone()); let room_info = room.get_room_info().await; let update = Updates::RoomJoined { @@ -387,6 +412,7 @@ impl Player { let room = self.my_rooms.remove(&room_id); if let Some(room) = room { room.unsubscribe(&self.player_id).await; + room.remove_member(&self.player_id, self.storage_id).await; } let update = Updates::RoomLeft { room_id, @@ -396,12 +422,11 @@ impl Player { } async fn send_message(&mut self, connection_id: ConnectionId, room_id: RoomId, body: Str) { - let room = self.rooms.get_room(&room_id).await; - if let Some(room) = room { - room.send_message(self.player_id.clone(), body.clone()).await; - } else { + let Some(room) = self.my_rooms.get(&room_id) else { tracing::info!("no room found"); - } + return; + }; + room.send_message(&self.player_id, body.clone()).await; let update = Updates::NewMessage { room_id, author_id: self.player_id.clone(), @@ -411,12 +436,11 @@ impl Player { } async fn change_topic(&mut self, connection_id: ConnectionId, room_id: RoomId, new_topic: Str) { - let room = self.rooms.get_room(&room_id).await; - if let Some(mut room) = room { - room.set_topic(self.player_id.clone(), new_topic.clone()).await; - } else { + let Some(room) = self.my_rooms.get(&room_id) else { tracing::info!("no room found"); - } + return; + }; + room.set_topic(&self.player_id, new_topic.clone()).await; let update = Updates::RoomTopicChanged { room_id, new_topic }; self.broadcast_update(update, connection_id).await; } diff --git a/crates/lavina-core/src/repo/mod.rs b/crates/lavina-core/src/repo/mod.rs index e9eee6c..e8e3854 100644 --- a/crates/lavina-core/src/repo/mod.rs +++ b/crates/lavina-core/src/repo/mod.rs @@ -11,6 +11,9 @@ use tokio::sync::Mutex; use crate::prelude::*; +mod room; +mod user; + #[derive(Deserialize, Debug, Clone)] pub struct StorageConfig { pub db_path: String, @@ -48,7 +51,7 @@ impl Storage { Ok(res) } - pub async fn retrieve_room_by_name(&mut self, name: &str) -> Result> { + pub async fn retrieve_room_by_name(&self, name: &str) -> Result> { let mut executor = self.conn.lock().await; let res = sqlx::query_as( "select id, name, topic, message_count diff --git a/crates/lavina-core/src/repo/room.rs b/crates/lavina-core/src/repo/room.rs new file mode 100644 index 0000000..b568b9d --- /dev/null +++ b/crates/lavina-core/src/repo/room.rs @@ -0,0 +1,33 @@ +use anyhow::Result; + +use crate::repo::Storage; + +impl Storage { + pub async fn add_room_member(&self, room_id: u32, player_id: u32) -> Result<()> { + let mut executor = self.conn.lock().await; + sqlx::query( + "insert into memberships(user_id, room_id, status) + values (?, ?, 1);", + ) + .bind(player_id) + .bind(room_id) + .execute(&mut *executor) + .await?; + + Ok(()) + } + + pub async fn remove_room_member(&self, room_id: u32, player_id: u32) -> Result<()> { + let mut executor = self.conn.lock().await; + sqlx::query( + "delete from memberships + where user_id = ? and room_id = ?;", + ) + .bind(player_id) + .bind(room_id) + .execute(&mut *executor) + .await?; + + Ok(()) + } +} diff --git a/crates/lavina-core/src/repo/user.rs b/crates/lavina-core/src/repo/user.rs new file mode 100644 index 0000000..d836b8f --- /dev/null +++ b/crates/lavina-core/src/repo/user.rs @@ -0,0 +1,30 @@ +use anyhow::Result; + +use crate::repo::Storage; +use crate::room::RoomId; + +impl Storage { + pub async fn retrieve_user_id_by_name(&self, name: &str) -> Result> { + let mut executor = self.conn.lock().await; + let res: Option<(u32,)> = sqlx::query_as("select u.id from users u where u.name = ?;") + .bind(name) + .fetch_optional(&mut *executor) + .await?; + + Ok(res.map(|(id,)| id)) + } + + pub async fn get_rooms_of_a_user(&self, user_id: u32) -> Result> { + let mut executor = self.conn.lock().await; + let res: Vec<(String,)> = sqlx::query_as( + "select r.name + from memberships m inner join rooms r on m.room_id = r.id + where m.user_id = ?;", + ) + .bind(user_id) + .fetch_all(&mut *executor) + .await?; + + res.into_iter().map(|(room_id,)| RoomId::from(room_id)).collect() + } +} diff --git a/crates/lavina-core/src/room.rs b/crates/lavina-core/src/room.rs index 04fdbb1..dbd14fd 100644 --- a/crates/lavina-core/src/room.rs +++ b/crates/lavina-core/src/room.rs @@ -1,4 +1,5 @@ //! Domain of rooms — chats with multiple participants. +use std::collections::HashSet; use std::{collections::HashMap, hash::Hash, sync::Arc}; use prometheus::{IntGauge, Registry as MetricRegistry}; @@ -48,27 +49,9 @@ impl RoomRegistry { pub async fn get_or_create_room(&mut self, room_id: RoomId) -> Result { let mut inner = self.0.write().await; - if let Some(room_handle) = inner.rooms.get(&room_id) { - // room was already loaded into memory - log::debug!("Room {} was loaded already", &room_id.0); + if let Some(room_handle) = inner.get_or_load_room(&room_id).await? { Ok(room_handle.clone()) - } else if let Some(stored_room) = inner.storage.retrieve_room_by_name(&*room_id.0).await? { - // room exists, but was not loaded - log::debug!("Loading room {}...", &room_id.0); - let room = Room { - storage_id: stored_room.id, - room_id: room_id.clone(), - subscriptions: HashMap::new(), // TODO figure out how to populate subscriptions - topic: stored_room.topic.into(), - message_count: stored_room.message_count, - storage: inner.storage.clone(), - }; - let room_handle = RoomHandle(Arc::new(AsyncRwLock::new(room))); - inner.rooms.insert(room_id, room_handle.clone()); - inner.metric_active_rooms.inc(); - Ok(room_handle) } else { - // room does not exist, create it and load log::debug!("Creating room {}...", &room_id.0); let topic = "New room"; let id = inner.storage.create_new_room(&*room_id.0, &*topic).await?; @@ -76,6 +59,7 @@ impl RoomRegistry { storage_id: id, room_id: room_id.clone(), subscriptions: HashMap::new(), + members: HashSet::new(), topic: topic.into(), message_count: 0, storage: inner.storage.clone(), @@ -88,9 +72,8 @@ impl RoomRegistry { } pub async fn get_room(&self, room_id: &RoomId) -> Option { - let inner = self.0.read().await; - let res = inner.rooms.get(room_id); - res.map(|r| r.clone()) + let mut inner = self.0.write().await; + inner.get_or_load_room(room_id).await.unwrap() } pub async fn get_all_rooms(&self) -> Vec { @@ -113,17 +96,66 @@ struct RoomRegistryInner { storage: Storage, } +impl RoomRegistryInner { + async fn get_or_load_room(&mut self, room_id: &RoomId) -> Result> { + if let Some(room_handle) = self.rooms.get(room_id) { + log::debug!("Room {} was loaded already", &room_id.0); + Ok(Some(room_handle.clone())) + } else if let Some(stored_room) = self.storage.retrieve_room_by_name(&*room_id.0).await? { + log::debug!("Loading room {}...", &room_id.0); + let room = Room { + storage_id: stored_room.id, + room_id: room_id.clone(), + subscriptions: HashMap::new(), + members: HashSet::new(), // TODO load members from storage + topic: stored_room.topic.into(), + message_count: stored_room.message_count, + storage: self.storage.clone(), + }; + let room_handle = RoomHandle(Arc::new(AsyncRwLock::new(room))); + self.rooms.insert(room_id.clone(), room_handle.clone()); + self.metric_active_rooms.inc(); + Ok(Some(room_handle)) + } else { + tracing::debug!("Room {} does not exist", &room_id.0); + Ok(None) + } + } +} + #[derive(Clone)] pub struct RoomHandle(Arc>); impl RoomHandle { - pub async fn subscribe(&self, player_id: PlayerId, player_handle: PlayerHandle) { + pub async fn subscribe(&self, player_id: &PlayerId, player_handle: PlayerHandle) { let mut lock = self.0.write().await; - lock.add_subscriber(player_id, player_handle).await; + tracing::info!("Adding a subscriber to a room"); + lock.subscriptions.insert(player_id.clone(), player_handle); + } + + pub async fn add_member(&self, player_id: &PlayerId, player_storage_id: u32) { + let mut lock = self.0.write().await; + tracing::info!("Adding a new member to a room"); + let room_storage_id = lock.storage_id; + lock.storage.add_room_member(room_storage_id, player_storage_id).await.unwrap(); + lock.members.insert(player_id.clone()); + let update = Updates::RoomJoined { + room_id: lock.room_id.clone(), + new_member_id: player_id.clone(), + }; + lock.broadcast_update(update, player_id).await; } pub async fn unsubscribe(&self, player_id: &PlayerId) { let mut lock = self.0.write().await; lock.subscriptions.remove(player_id); + } + + pub async fn remove_member(&self, player_id: &PlayerId, player_storage_id: u32) { + let mut lock = self.0.write().await; + tracing::info!("Removing a member from a room"); + let room_storage_id = lock.storage_id; + lock.storage.remove_room_member(room_storage_id, player_storage_id).await.unwrap(); + lock.members.remove(player_id); let update = Updates::RoomLeft { room_id: lock.room_id.clone(), former_member_id: player_id.clone(), @@ -131,7 +163,7 @@ impl RoomHandle { lock.broadcast_update(update, player_id).await; } - pub async fn send_message(&self, player_id: PlayerId, body: Str) { + pub async fn send_message(&self, player_id: &PlayerId, body: Str) { let mut lock = self.0.write().await; let res = lock.send_message(player_id, body).await; if let Err(err) = res { @@ -148,14 +180,14 @@ impl RoomHandle { } } - pub async fn set_topic(&mut self, changer_id: PlayerId, new_topic: Str) { + pub async fn set_topic(&self, changer_id: &PlayerId, new_topic: Str) { let mut lock = self.0.write().await; lock.topic = new_topic.clone(); let update = Updates::RoomTopicChanged { room_id: lock.room_id.clone(), new_topic: new_topic.clone(), }; - lock.broadcast_update(update, &changer_id).await; + lock.broadcast_update(update, changer_id).await; } } @@ -166,23 +198,15 @@ struct Room { room_id: RoomId, /// Player actors on the local node which are subscribed to this room's updates. subscriptions: HashMap, + /// Members of the room. + members: HashSet, /// The total number of messages. Used to calculate the id of the new message. message_count: u32, topic: Str, storage: Storage, } impl Room { - async fn add_subscriber(&mut self, player_id: PlayerId, player_handle: PlayerHandle) { - tracing::info!("Adding a subscriber to room"); - self.subscriptions.insert(player_id.clone(), player_handle); - let update = Updates::RoomJoined { - room_id: self.room_id.clone(), - new_member_id: player_id.clone(), - }; - self.broadcast_update(update, &player_id).await; - } - - async fn send_message(&mut self, author_id: PlayerId, body: Str) -> Result<()> { + async fn send_message(&mut self, author_id: &PlayerId, body: Str) -> Result<()> { tracing::info!("Adding a message to room"); self.storage.insert_message(self.storage_id, self.message_count, &body, &*author_id.as_inner()).await?; self.message_count += 1; @@ -191,7 +215,7 @@ impl Room { author_id: author_id.clone(), body, }; - self.broadcast_update(update, &author_id).await; + self.broadcast_update(update, author_id).await; Ok(()) } diff --git a/crates/projection-irc/src/lib.rs b/crates/projection-irc/src/lib.rs index ee47f9b..e52e92a 100644 --- a/crates/projection-irc/src/lib.rs +++ b/crates/projection-irc/src/lib.rs @@ -392,7 +392,7 @@ async fn handle_registered_socket<'a>( log::info!("Handling registered user: {user:?}"); let player_id = PlayerId::from(user.nickname.clone())?; - let mut connection = players.connect_to_player(player_id.clone()).await; + let mut connection = players.connect_to_player(&player_id).await; let text: Str = format!("Welcome to {} Server", &config.server_name).into(); ServerMessage { diff --git a/crates/projection-irc/tests/lib.rs b/crates/projection-irc/tests/lib.rs index 36b9040..3618467 100644 --- a/crates/projection-irc/tests/lib.rs +++ b/crates/projection-irc/tests/lib.rs @@ -111,7 +111,36 @@ impl TestServer { }) .await?; let rooms = RoomRegistry::new(&mut metrics, storage.clone()).unwrap(); - let players = PlayerRegistry::empty(rooms.clone(), &mut metrics).unwrap(); + let players = PlayerRegistry::empty(rooms.clone(), storage.clone(), &mut metrics).unwrap(); + let server = launch(config, players.clone(), rooms.clone(), metrics.clone(), storage.clone()).await.unwrap(); + Ok(TestServer { + metrics, + storage, + rooms, + players, + server, + }) + } + + async fn reboot(mut self) -> Result { + let config = ServerConfig { + listen_on: "127.0.0.1:0".parse().unwrap(), + server_name: "testserver".into(), + }; + let TestServer { + mut metrics, + mut storage, + rooms, + mut players, + server, + } = self; + server.terminate().await?; + players.shutdown_all().await.unwrap(); + drop(players); + drop(rooms); + let mut metrics = MetricsRegistry::new(); + let rooms = RoomRegistry::new(&mut metrics, storage.clone()).unwrap(); + let players = PlayerRegistry::empty(rooms.clone(), storage.clone(), &mut metrics).unwrap(); let server = launch(config, players.clone(), rooms.clone(), metrics.clone(), storage.clone()).await.unwrap(); Ok(TestServer { metrics, @@ -152,6 +181,76 @@ async fn scenario_basic() -> Result<()> { Ok(()) } +#[tokio::test] +async fn scenario_join_and_reboot() -> Result<()> { + let mut server = TestServer::start().await?; + + // test scenario + + server.storage.create_user("tester").await?; + server.storage.set_password("tester", "password").await?; + + let mut stream = TcpStream::connect(server.server.addr).await?; + let mut s = TestScope::new(&mut stream); + + // Open a connection and join a channel + + s.send("PASS password").await?; + s.send("NICK tester").await?; + s.send("USER UserName 0 * :Real Name").await?; + s.expect_server_introduction("tester").await?; + s.expect_nothing().await?; + s.send("JOIN #test").await?; + s.expect(":tester JOIN #test").await?; + s.expect(":testserver 332 tester #test :New room").await?; + s.expect(":testserver 353 tester = #test :tester").await?; + s.expect(":testserver 366 tester #test :End of /NAMES list").await?; + s.send("PRIVMSG #test :Hello").await?; + s.send("QUIT :Leaving").await?; + s.expect(":testserver ERROR :Leaving the server").await?; + s.expect_eof().await?; + stream.shutdown().await?; + + // Open a new connection and expect to be force-joined to the channel + + let mut stream = TcpStream::connect(server.server.addr).await?; + let mut s = TestScope::new(&mut stream); + + async fn test(s: &mut TestScope<'_>) -> Result<()> { + s.send("PASS password").await?; + s.send("NICK tester").await?; + s.send("USER UserName 0 * :Real Name").await?; + s.expect_server_introduction("tester").await?; + s.expect(":tester JOIN #test").await?; + s.expect(":testserver 332 tester #test :New room").await?; + s.expect(":testserver 353 tester = #test :tester").await?; + s.expect(":testserver 366 tester #test :End of /NAMES list").await?; + s.send("QUIT :Leaving").await?; + s.expect(":testserver ERROR :Leaving the server").await?; + s.expect_eof().await?; + Ok(()) + } + test(&mut s).await?; + stream.shutdown().await?; + + // Reboot the server + + let server = server.reboot().await?; + + // Open a new connection and expect to be force-joined to the channel + + let mut stream = TcpStream::connect(server.server.addr).await?; + let mut s = TestScope::new(&mut stream); + + test(&mut s).await?; + stream.shutdown().await?; + + // wrap up + + server.server.terminate().await?; + Ok(()) +} + #[tokio::test] async fn scenario_force_join_msg() -> Result<()> { let mut server = TestServer::start().await?; @@ -407,7 +506,6 @@ async fn scenario_cap_sasl_fail() -> Result<()> { #[tokio::test] async fn terminate_socket_scenario() -> Result<()> { let mut server = TestServer::start().await?; - let address: SocketAddr = ("127.0.0.1:0".parse().unwrap()); // test scenario diff --git a/crates/projection-xmpp/src/lib.rs b/crates/projection-xmpp/src/lib.rs index 0da3019..30e0a3c 100644 --- a/crates/projection-xmpp/src/lib.rs +++ b/crates/projection-xmpp/src/lib.rs @@ -207,7 +207,7 @@ async fn handle_socket( authenticated = socket_auth(&mut xml_reader, &mut xml_writer, &mut reader_buf, &mut storage, &hostname) => { match authenticated { Ok(authenticated) => { - let mut connection = players.connect_to_player(authenticated.player_id.clone()).await; + let mut connection = players.connect_to_player(&authenticated.player_id).await; socket_final( &mut xml_reader, &mut xml_writer, diff --git a/crates/projection-xmpp/tests/lib.rs b/crates/projection-xmpp/tests/lib.rs index cc7e645..29d0368 100644 --- a/crates/projection-xmpp/tests/lib.rs +++ b/crates/projection-xmpp/tests/lib.rs @@ -143,7 +143,7 @@ impl TestServer { }) .await?; let rooms = RoomRegistry::new(&mut metrics, storage.clone()).unwrap(); - let players = PlayerRegistry::empty(rooms.clone(), &mut metrics).unwrap(); + let players = PlayerRegistry::empty(rooms.clone(), storage.clone(), &mut metrics).unwrap(); let server = launch(config, players.clone(), rooms.clone(), metrics.clone(), storage.clone()).await.unwrap(); Ok(TestServer { metrics, diff --git a/src/main.rs b/src/main.rs index 8111074..0d03a89 100644 --- a/src/main.rs +++ b/src/main.rs @@ -52,7 +52,7 @@ async fn main() -> Result<()> { let mut metrics = MetricsRegistry::new(); let storage = Storage::open(storage_config).await?; let rooms = RoomRegistry::new(&mut metrics, storage.clone())?; - let mut players = PlayerRegistry::empty(rooms.clone(), &mut metrics)?; + let mut players = PlayerRegistry::empty(rooms.clone(), storage.clone(), &mut metrics)?; let telemetry_terminator = http::launch(telemetry_config, metrics.clone(), rooms.clone(), storage.clone()).await?; let irc = projection_irc::launch( irc_config,