From 041c1b5fe2a12fc45c3123552cd0632289a24a00 Mon Sep 17 00:00:00 2001 From: Nikita Vilunov Date: Sat, 13 Apr 2024 02:27:35 +0200 Subject: [PATCH] implement persistent memberships --- crates/lavina-core/src/player.rs | 56 ++++++++------ crates/lavina-core/src/repo/mod.rs | 3 + crates/lavina-core/src/repo/room.rs | 19 +++++ crates/lavina-core/src/repo/user.rs | 15 ++++ crates/lavina-core/src/room.rs | 39 +++++++--- crates/projection-irc/tests/lib.rs | 102 ++++++++++++++++++++++++- crates/projection-xmpp/tests/lib.rs | 112 +++++++++++++--------------- src/main.rs | 2 +- 8 files changed, 252 insertions(+), 96 deletions(-) create mode 100644 crates/lavina-core/src/repo/room.rs create mode 100644 crates/lavina-core/src/repo/user.rs diff --git a/crates/lavina-core/src/player.rs b/crates/lavina-core/src/player.rs index eec22f8..4beef24 100644 --- a/crates/lavina-core/src/player.rs +++ b/crates/lavina-core/src/player.rs @@ -7,16 +7,16 @@ //! //! A player actor is a serial handler of commands from a single player. It is preferable to run all per-player validations in the player actor, //! so that they don't overload the room actor. -use std::{ - collections::{HashMap, HashSet}, - sync::{Arc, RwLock}, -}; +use std::collections::{HashMap, HashSet}; +use std::sync::Arc; use prometheus::{IntGauge, Registry as MetricsRegistry}; use serde::Serialize; use tokio::sync::mpsc::{channel, Receiver, Sender}; +use tokio::sync::RwLock; use crate::prelude::*; +use crate::repo::Storage; use crate::room::{RoomHandle, RoomId, RoomInfo, RoomRegistry}; use crate::table::{AnonTable, Key as AnonKey}; @@ -208,23 +208,28 @@ pub enum Updates { #[derive(Clone)] pub struct PlayerRegistry(Arc>); impl PlayerRegistry { - pub fn empty(room_registry: RoomRegistry, metrics: &mut MetricsRegistry) -> Result { + pub fn empty( + room_registry: RoomRegistry, + storage: Storage, + metrics: &mut MetricsRegistry, + ) -> Result { let metric_active_players = IntGauge::new("chat_players_active", "Number of alive player actors")?; metrics.register(Box::new(metric_active_players.clone()))?; let inner = PlayerRegistryInner { room_registry, + storage, players: HashMap::new(), metric_active_players, }; Ok(PlayerRegistry(Arc::new(RwLock::new(inner)))) } - pub async fn get_or_create_player(&mut self, id: PlayerId) -> PlayerHandle { - let mut inner = self.0.write().unwrap(); + pub async fn get_or_launch_player(&mut self, id: PlayerId) -> PlayerHandle { + let mut inner = self.0.write().await; if let Some((handle, _)) = inner.players.get(&id) { handle.clone() } else { - let (handle, fiber) = Player::launch(id.clone(), inner.room_registry.clone()); + let (handle, fiber) = Player::launch(id.clone(), inner.room_registry.clone(), inner.storage.clone()).await; inner.players.insert(id, (handle.clone(), fiber)); inner.metric_active_players.inc(); handle @@ -232,12 +237,12 @@ impl PlayerRegistry { } pub async fn connect_to_player(&mut self, id: PlayerId) -> PlayerConnection { - let player_handle = self.get_or_create_player(id).await; + let player_handle = self.get_or_launch_player(id).await; player_handle.subscribe().await } pub async fn shutdown_all(&mut self) -> Result<()> { - let mut inner = self.0.write().unwrap(); + let mut inner = self.0.write().await; for (i, (k, j)) in inner.players.drain() { k.send(ActorCommand::Stop).await; drop(k); @@ -252,6 +257,8 @@ impl PlayerRegistry { /// The player registry state representation. struct PlayerRegistryInner { room_registry: RoomRegistry, + storage: Storage, + /// Active player actors. players: HashMap)>, metric_active_players: IntGauge, } @@ -259,26 +266,31 @@ struct PlayerRegistryInner { /// Player actor inner state representation. struct Player { player_id: PlayerId, + storage_id: u32, connections: AnonTable>, my_rooms: HashMap, banned_from: HashSet, rx: Receiver, handle: PlayerHandle, rooms: RoomRegistry, + storage: Storage, } impl Player { - fn launch(player_id: PlayerId, rooms: RoomRegistry) -> (PlayerHandle, JoinHandle) { + async fn launch(player_id: PlayerId, rooms: RoomRegistry, storage: Storage) -> (PlayerHandle, JoinHandle) { let (tx, rx) = channel(32); let handle = PlayerHandle { tx }; let handle_clone = handle.clone(); + let storage_id = storage.retrieve_user_id_by_name(player_id.as_inner()).await.unwrap().unwrap(); let player = Player { player_id, + storage_id, connections: AnonTable::new(), my_rooms: HashMap::new(), - banned_from: HashSet::from([RoomId::from("Empty").unwrap()]), + banned_from: HashSet::new(), rx, handle, rooms, + storage, }; let fiber = tokio::task::spawn(player.main_loop()); (handle_clone, fiber) @@ -372,6 +384,7 @@ impl Player { todo!(); } }; + room.add_member(self.player_id.clone(), self.storage_id).await; room.subscribe(self.player_id.clone(), self.handle.clone()).await; self.my_rooms.insert(room_id.clone(), room.clone()); let room_info = room.get_room_info().await; @@ -387,6 +400,7 @@ impl Player { let room = self.my_rooms.remove(&room_id); if let Some(room) = room { room.unsubscribe(&self.player_id).await; + room.remove_member(&self.player_id).await; } let update = Updates::RoomLeft { room_id, @@ -396,12 +410,11 @@ impl Player { } async fn send_message(&mut self, connection_id: ConnectionId, room_id: RoomId, body: Str) { - let room = self.rooms.get_room(&room_id).await; - if let Some(room) = room { - room.send_message(self.player_id.clone(), body.clone()).await; - } else { + let Some(room) = self.my_rooms.get(&room_id) else { tracing::info!("no room found"); - } + return; + }; + room.send_message(self.player_id.clone(), body.clone()).await; let update = Updates::NewMessage { room_id, author_id: self.player_id.clone(), @@ -411,12 +424,11 @@ impl Player { } async fn change_topic(&mut self, connection_id: ConnectionId, room_id: RoomId, new_topic: Str) { - let room = self.rooms.get_room(&room_id).await; - if let Some(mut room) = room { - room.set_topic(self.player_id.clone(), new_topic.clone()).await; - } else { + let Some(room) = self.my_rooms.get(&room_id) else { tracing::info!("no room found"); - } + return; + }; + room.set_topic(self.player_id.clone(), new_topic.clone()).await; let update = Updates::RoomTopicChanged { room_id, new_topic }; self.broadcast_update(update, connection_id).await; } diff --git a/crates/lavina-core/src/repo/mod.rs b/crates/lavina-core/src/repo/mod.rs index e9eee6c..cb94d43 100644 --- a/crates/lavina-core/src/repo/mod.rs +++ b/crates/lavina-core/src/repo/mod.rs @@ -11,6 +11,9 @@ use tokio::sync::Mutex; use crate::prelude::*; +mod user; +mod room; + #[derive(Deserialize, Debug, Clone)] pub struct StorageConfig { pub db_path: String, diff --git a/crates/lavina-core/src/repo/room.rs b/crates/lavina-core/src/repo/room.rs new file mode 100644 index 0000000..7b68a3d --- /dev/null +++ b/crates/lavina-core/src/repo/room.rs @@ -0,0 +1,19 @@ +use anyhow::Result; + +use crate::repo::Storage; + +impl Storage { + pub async fn add_room_member(&mut self, room_id: &u32, player_id: &u32) -> Result<()> { + let mut executor = self.conn.lock().await; + sqlx::query( + "insert into memberships(user_id, room_id, status) + values (?, ?, 1);", + ) + .bind(player_id) + .bind(room_id) + .execute(&mut *executor) + .await?; + + Ok(()) + } +} diff --git a/crates/lavina-core/src/repo/user.rs b/crates/lavina-core/src/repo/user.rs new file mode 100644 index 0000000..bfa34e7 --- /dev/null +++ b/crates/lavina-core/src/repo/user.rs @@ -0,0 +1,15 @@ +use anyhow::Result; + +use crate::repo::Storage; + +impl Storage { + pub async fn retrieve_user_id_by_name(&self, name: &str) -> Result> { + let mut executor = self.conn.lock().await; + let res: Option<(u32,)> = sqlx::query_as("select u.id from users u where u.name = ?;") + .bind(name) + .fetch_optional(&mut *executor) + .await?; + + Ok(res.map(|(id,)| id)) + } +} diff --git a/crates/lavina-core/src/room.rs b/crates/lavina-core/src/room.rs index 04fdbb1..58f5140 100644 --- a/crates/lavina-core/src/room.rs +++ b/crates/lavina-core/src/room.rs @@ -1,4 +1,5 @@ //! Domain of rooms — chats with multiple participants. +use std::collections::HashSet; use std::{collections::HashMap, hash::Hash, sync::Arc}; use prometheus::{IntGauge, Registry as MetricRegistry}; @@ -59,6 +60,7 @@ impl RoomRegistry { storage_id: stored_room.id, room_id: room_id.clone(), subscriptions: HashMap::new(), // TODO figure out how to populate subscriptions + members: HashSet::new(), // TODO load members from storage topic: stored_room.topic.into(), message_count: stored_room.message_count, storage: inner.storage.clone(), @@ -76,6 +78,7 @@ impl RoomRegistry { storage_id: id, room_id: room_id.clone(), subscriptions: HashMap::new(), + members: HashSet::new(), topic: topic.into(), message_count: 0, storage: inner.storage.clone(), @@ -118,12 +121,32 @@ pub struct RoomHandle(Arc>); impl RoomHandle { pub async fn subscribe(&self, player_id: PlayerId, player_handle: PlayerHandle) { let mut lock = self.0.write().await; - lock.add_subscriber(player_id, player_handle).await; + tracing::info!("Adding a subscriber to a room"); + lock.subscriptions.insert(player_id.clone(), player_handle); + } + + pub async fn add_member(&self, player_id: PlayerId, player_storage_id: u32) { + let mut lock = self.0.write().await; + tracing::info!("Adding a new member to a room"); + let storage_id = lock.storage_id; + lock.members.insert(player_id.clone()); + lock.storage.add_room_member(&storage_id, &player_storage_id).await.unwrap(); + let update = Updates::RoomJoined { + room_id: lock.room_id.clone(), + new_member_id: player_id.clone(), + }; + lock.broadcast_update(update, &player_id).await; } pub async fn unsubscribe(&self, player_id: &PlayerId) { let mut lock = self.0.write().await; lock.subscriptions.remove(player_id); + } + + pub async fn remove_member(&self, player_id: &PlayerId) { + let mut lock = self.0.write().await; + tracing::info!("Removing a member from a room"); + lock.members.remove(player_id); let update = Updates::RoomLeft { room_id: lock.room_id.clone(), former_member_id: player_id.clone(), @@ -148,7 +171,7 @@ impl RoomHandle { } } - pub async fn set_topic(&mut self, changer_id: PlayerId, new_topic: Str) { + pub async fn set_topic(&self, changer_id: PlayerId, new_topic: Str) { let mut lock = self.0.write().await; lock.topic = new_topic.clone(); let update = Updates::RoomTopicChanged { @@ -166,22 +189,14 @@ struct Room { room_id: RoomId, /// Player actors on the local node which are subscribed to this room's updates. subscriptions: HashMap, + /// Members of the room. + members: HashSet, /// The total number of messages. Used to calculate the id of the new message. message_count: u32, topic: Str, storage: Storage, } impl Room { - async fn add_subscriber(&mut self, player_id: PlayerId, player_handle: PlayerHandle) { - tracing::info!("Adding a subscriber to room"); - self.subscriptions.insert(player_id.clone(), player_handle); - let update = Updates::RoomJoined { - room_id: self.room_id.clone(), - new_member_id: player_id.clone(), - }; - self.broadcast_update(update, &player_id).await; - } - async fn send_message(&mut self, author_id: PlayerId, body: Str) -> Result<()> { tracing::info!("Adding a message to room"); self.storage.insert_message(self.storage_id, self.message_count, &body, &*author_id.as_inner()).await?; diff --git a/crates/projection-irc/tests/lib.rs b/crates/projection-irc/tests/lib.rs index 36b9040..3618467 100644 --- a/crates/projection-irc/tests/lib.rs +++ b/crates/projection-irc/tests/lib.rs @@ -111,7 +111,36 @@ impl TestServer { }) .await?; let rooms = RoomRegistry::new(&mut metrics, storage.clone()).unwrap(); - let players = PlayerRegistry::empty(rooms.clone(), &mut metrics).unwrap(); + let players = PlayerRegistry::empty(rooms.clone(), storage.clone(), &mut metrics).unwrap(); + let server = launch(config, players.clone(), rooms.clone(), metrics.clone(), storage.clone()).await.unwrap(); + Ok(TestServer { + metrics, + storage, + rooms, + players, + server, + }) + } + + async fn reboot(mut self) -> Result { + let config = ServerConfig { + listen_on: "127.0.0.1:0".parse().unwrap(), + server_name: "testserver".into(), + }; + let TestServer { + mut metrics, + mut storage, + rooms, + mut players, + server, + } = self; + server.terminate().await?; + players.shutdown_all().await.unwrap(); + drop(players); + drop(rooms); + let mut metrics = MetricsRegistry::new(); + let rooms = RoomRegistry::new(&mut metrics, storage.clone()).unwrap(); + let players = PlayerRegistry::empty(rooms.clone(), storage.clone(), &mut metrics).unwrap(); let server = launch(config, players.clone(), rooms.clone(), metrics.clone(), storage.clone()).await.unwrap(); Ok(TestServer { metrics, @@ -152,6 +181,76 @@ async fn scenario_basic() -> Result<()> { Ok(()) } +#[tokio::test] +async fn scenario_join_and_reboot() -> Result<()> { + let mut server = TestServer::start().await?; + + // test scenario + + server.storage.create_user("tester").await?; + server.storage.set_password("tester", "password").await?; + + let mut stream = TcpStream::connect(server.server.addr).await?; + let mut s = TestScope::new(&mut stream); + + // Open a connection and join a channel + + s.send("PASS password").await?; + s.send("NICK tester").await?; + s.send("USER UserName 0 * :Real Name").await?; + s.expect_server_introduction("tester").await?; + s.expect_nothing().await?; + s.send("JOIN #test").await?; + s.expect(":tester JOIN #test").await?; + s.expect(":testserver 332 tester #test :New room").await?; + s.expect(":testserver 353 tester = #test :tester").await?; + s.expect(":testserver 366 tester #test :End of /NAMES list").await?; + s.send("PRIVMSG #test :Hello").await?; + s.send("QUIT :Leaving").await?; + s.expect(":testserver ERROR :Leaving the server").await?; + s.expect_eof().await?; + stream.shutdown().await?; + + // Open a new connection and expect to be force-joined to the channel + + let mut stream = TcpStream::connect(server.server.addr).await?; + let mut s = TestScope::new(&mut stream); + + async fn test(s: &mut TestScope<'_>) -> Result<()> { + s.send("PASS password").await?; + s.send("NICK tester").await?; + s.send("USER UserName 0 * :Real Name").await?; + s.expect_server_introduction("tester").await?; + s.expect(":tester JOIN #test").await?; + s.expect(":testserver 332 tester #test :New room").await?; + s.expect(":testserver 353 tester = #test :tester").await?; + s.expect(":testserver 366 tester #test :End of /NAMES list").await?; + s.send("QUIT :Leaving").await?; + s.expect(":testserver ERROR :Leaving the server").await?; + s.expect_eof().await?; + Ok(()) + } + test(&mut s).await?; + stream.shutdown().await?; + + // Reboot the server + + let server = server.reboot().await?; + + // Open a new connection and expect to be force-joined to the channel + + let mut stream = TcpStream::connect(server.server.addr).await?; + let mut s = TestScope::new(&mut stream); + + test(&mut s).await?; + stream.shutdown().await?; + + // wrap up + + server.server.terminate().await?; + Ok(()) +} + #[tokio::test] async fn scenario_force_join_msg() -> Result<()> { let mut server = TestServer::start().await?; @@ -407,7 +506,6 @@ async fn scenario_cap_sasl_fail() -> Result<()> { #[tokio::test] async fn terminate_socket_scenario() -> Result<()> { let mut server = TestServer::start().await?; - let address: SocketAddr = ("127.0.0.1:0".parse().unwrap()); // test scenario diff --git a/crates/projection-xmpp/tests/lib.rs b/crates/projection-xmpp/tests/lib.rs index 9dfae4c..99c33b3 100644 --- a/crates/projection-xmpp/tests/lib.rs +++ b/crates/projection-xmpp/tests/lib.rs @@ -20,7 +20,7 @@ use tokio_rustls::TlsConnector; use lavina_core::player::PlayerRegistry; use lavina_core::repo::{Storage, StorageConfig}; use lavina_core::room::RoomRegistry; -use projection_xmpp::{launch, ServerConfig}; +use projection_xmpp::{launch, RunningServer, ServerConfig}; use proto_xmpp::xml::{Continuation, FromXml, Parser}; pub async fn read_irc_message(reader: &mut BufReader>, buf: &mut Vec) -> Result { @@ -122,29 +122,49 @@ impl ServerCertVerifier for IgnoreCertVerification { } } +struct TestServer { + metrics: MetricsRegistry, + storage: Storage, + rooms: RoomRegistry, + players: PlayerRegistry, + server: RunningServer, +} +impl TestServer { + async fn start() -> Result { + let _ = tracing_subscriber::fmt::try_init(); + let config = ServerConfig { + listen_on: "127.0.0.1:0".parse().unwrap(), + cert: "tests/certs/xmpp.pem".parse().unwrap(), + key: "tests/certs/xmpp.key".parse().unwrap(), + }; + let mut metrics = MetricsRegistry::new(); + let mut storage = Storage::open(StorageConfig { + db_path: ":memory:".into(), + }) + .await?; + let rooms = RoomRegistry::new(&mut metrics, storage.clone()).unwrap(); + let players = PlayerRegistry::empty(rooms.clone(), storage.clone(), &mut metrics).unwrap(); + let server = launch(config, players.clone(), rooms.clone(), metrics.clone(), storage.clone()).await.unwrap(); + Ok(TestServer { + metrics, + storage, + rooms, + players, + server, + }) + } +} + #[tokio::test] async fn scenario_basic() -> Result<()> { - tracing_subscriber::fmt::try_init(); - let config = ServerConfig { - listen_on: "127.0.0.1:0".parse().unwrap(), - cert: "tests/certs/xmpp.pem".parse().unwrap(), - key: "tests/certs/xmpp.key".parse().unwrap(), - }; - let mut metrics = MetricsRegistry::new(); - let mut storage = Storage::open(StorageConfig { - db_path: ":memory:".into(), - }) - .await?; - let rooms = RoomRegistry::new(&mut metrics, storage.clone()).unwrap(); - let players = PlayerRegistry::empty(rooms.clone(), &mut metrics).unwrap(); - let server = launch(config, players, rooms, metrics, storage.clone()).await.unwrap(); + let mut server = TestServer::start().await?; // test scenario - storage.create_user("tester").await?; - storage.set_password("tester", "password").await?; + server.storage.create_user("tester").await?; + server.storage.set_password("tester", "password").await?; - let mut stream = TcpStream::connect(server.addr).await?; + let mut stream = TcpStream::connect(server.server.addr).await?; let mut s = TestScope::new(&mut stream); tracing::info!("TCP connection established"); @@ -169,7 +189,7 @@ async fn scenario_basic() -> Result<()> { .with_no_client_auth(), )); tracing::info!("Initiating TLS connection..."); - let mut stream = connector.connect(ServerName::IpAddress(server.addr.ip()), stream).await?; + let mut stream = connector.connect(ServerName::IpAddress(server.server.addr.ip()), stream).await?; tracing::info!("TLS connection established"); let mut s = TestScopeTls::new(&mut stream, buffer); @@ -183,33 +203,20 @@ async fn scenario_basic() -> Result<()> { // wrap up - server.terminate().await?; + server.server.terminate().await?; Ok(()) } #[tokio::test] async fn scenario_basic_without_headers() -> Result<()> { - tracing_subscriber::fmt::try_init(); - let config = ServerConfig { - listen_on: "127.0.0.1:0".parse().unwrap(), - cert: "tests/certs/xmpp.pem".parse().unwrap(), - key: "tests/certs/xmpp.key".parse().unwrap(), - }; - let mut metrics = MetricsRegistry::new(); - let mut storage = Storage::open(StorageConfig { - db_path: ":memory:".into(), - }) - .await?; - let rooms = RoomRegistry::new(&mut metrics, storage.clone()).unwrap(); - let players = PlayerRegistry::empty(rooms.clone(), &mut metrics).unwrap(); - let server = launch(config, players, rooms, metrics, storage.clone()).await.unwrap(); + let mut server = TestServer::start().await?; // test scenario - storage.create_user("tester").await?; - storage.set_password("tester", "password").await?; + server.storage.create_user("tester").await?; + server.storage.set_password("tester", "password").await?; - let mut stream = TcpStream::connect(server.addr).await?; + let mut stream = TcpStream::connect(server.server.addr).await?; let mut s = TestScope::new(&mut stream); tracing::info!("TCP connection established"); @@ -233,7 +240,7 @@ async fn scenario_basic_without_headers() -> Result<()> { .with_no_client_auth(), )); tracing::info!("Initiating TLS connection..."); - let mut stream = connector.connect(ServerName::IpAddress(server.addr.ip()), stream).await?; + let mut stream = connector.connect(ServerName::IpAddress(server.server.addr.ip()), stream).await?; tracing::info!("TLS connection established"); let mut s = TestScopeTls::new(&mut stream, buffer); @@ -246,33 +253,20 @@ async fn scenario_basic_without_headers() -> Result<()> { // wrap up - server.terminate().await?; + server.server.terminate().await?; Ok(()) } #[tokio::test] async fn terminate_socket() -> Result<()> { - tracing_subscriber::fmt::try_init(); - let config = ServerConfig { - listen_on: "127.0.0.1:0".parse().unwrap(), - cert: "tests/certs/xmpp.pem".parse().unwrap(), - key: "tests/certs/xmpp.key".parse().unwrap(), - }; - let mut metrics = MetricsRegistry::new(); - let mut storage = Storage::open(StorageConfig { - db_path: ":memory:".into(), - }) - .await?; - let rooms = RoomRegistry::new(&mut metrics, storage.clone()).unwrap(); - let players = PlayerRegistry::empty(rooms.clone(), &mut metrics).unwrap(); - let server = launch(config, players, rooms, metrics, storage.clone()).await.unwrap(); - let address: SocketAddr = ("127.0.0.1:0".parse().unwrap()); + let mut server = TestServer::start().await?; + // test scenario - storage.create_user("tester").await?; - storage.set_password("tester", "password").await?; + server.storage.create_user("tester").await?; + server.storage.set_password("tester", "password").await?; - let mut stream = TcpStream::connect(server.addr).await?; + let mut stream = TcpStream::connect(server.server.addr).await?; let mut s = TestScope::new(&mut stream); tracing::info!("TCP connection established"); @@ -298,10 +292,10 @@ async fn terminate_socket() -> Result<()> { )); tracing::info!("Initiating TLS connection..."); - let mut stream = connector.connect(ServerName::IpAddress(server.addr.ip()), stream).await?; + let mut stream = connector.connect(ServerName::IpAddress(server.server.addr.ip()), stream).await?; tracing::info!("TLS connection established"); - server.terminate().await?; + server.server.terminate().await?; assert_eq!(stream.read_u8().await.unwrap_err().kind(), ErrorKind::UnexpectedEof); diff --git a/src/main.rs b/src/main.rs index 8111074..0d03a89 100644 --- a/src/main.rs +++ b/src/main.rs @@ -52,7 +52,7 @@ async fn main() -> Result<()> { let mut metrics = MetricsRegistry::new(); let storage = Storage::open(storage_config).await?; let rooms = RoomRegistry::new(&mut metrics, storage.clone())?; - let mut players = PlayerRegistry::empty(rooms.clone(), &mut metrics)?; + let mut players = PlayerRegistry::empty(rooms.clone(), storage.clone(), &mut metrics)?; let telemetry_terminator = http::launch(telemetry_config, metrics.clone(), rooms.clone(), storage.clone()).await?; let irc = projection_irc::launch( irc_config,