diff --git a/config.toml b/config.toml index 6104dce..4765aa0 100644 --- a/config.toml +++ b/config.toml @@ -9,6 +9,7 @@ server_name = "irc.localhost" listen_on = "127.0.0.1:5222" cert = "./certs/xmpp.pem" key = "./certs/xmpp.key" +hostname = "localhost" [storage] db_path = "db.sqlite" diff --git a/crates/lavina-core/src/player.rs b/crates/lavina-core/src/player.rs index eec22f8..0486808 100644 --- a/crates/lavina-core/src/player.rs +++ b/crates/lavina-core/src/player.rs @@ -7,16 +7,16 @@ //! //! A player actor is a serial handler of commands from a single player. It is preferable to run all per-player validations in the player actor, //! so that they don't overload the room actor. -use std::{ - collections::{HashMap, HashSet}, - sync::{Arc, RwLock}, -}; +use std::collections::{HashMap, HashSet}; +use std::sync::Arc; use prometheus::{IntGauge, Registry as MetricsRegistry}; use serde::Serialize; use tokio::sync::mpsc::{channel, Receiver, Sender}; +use tokio::sync::RwLock; use crate::prelude::*; +use crate::repo::Storage; use crate::room::{RoomHandle, RoomId, RoomInfo, RoomRegistry}; use crate::table::{AnonTable, Key as AnonKey}; @@ -208,36 +208,41 @@ pub enum Updates { #[derive(Clone)] pub struct PlayerRegistry(Arc>); impl PlayerRegistry { - pub fn empty(room_registry: RoomRegistry, metrics: &mut MetricsRegistry) -> Result { + pub fn empty( + room_registry: RoomRegistry, + storage: Storage, + metrics: &mut MetricsRegistry, + ) -> Result { let metric_active_players = IntGauge::new("chat_players_active", "Number of alive player actors")?; metrics.register(Box::new(metric_active_players.clone()))?; let inner = PlayerRegistryInner { room_registry, + storage, players: HashMap::new(), metric_active_players, }; Ok(PlayerRegistry(Arc::new(RwLock::new(inner)))) } - pub async fn get_or_create_player(&mut self, id: PlayerId) -> PlayerHandle { - let mut inner = self.0.write().unwrap(); - if let Some((handle, _)) = inner.players.get(&id) { + pub async fn get_or_launch_player(&mut self, id: &PlayerId) -> PlayerHandle { + let mut inner = self.0.write().await; + if let Some((handle, _)) = inner.players.get(id) { handle.clone() } else { - let (handle, fiber) = Player::launch(id.clone(), inner.room_registry.clone()); - inner.players.insert(id, (handle.clone(), fiber)); + let (handle, fiber) = Player::launch(id.clone(), inner.room_registry.clone(), inner.storage.clone()).await; + inner.players.insert(id.clone(), (handle.clone(), fiber)); inner.metric_active_players.inc(); handle } } - pub async fn connect_to_player(&mut self, id: PlayerId) -> PlayerConnection { - let player_handle = self.get_or_create_player(id).await; + pub async fn connect_to_player(&mut self, id: &PlayerId) -> PlayerConnection { + let player_handle = self.get_or_launch_player(id).await; player_handle.subscribe().await } pub async fn shutdown_all(&mut self) -> Result<()> { - let mut inner = self.0.write().unwrap(); + let mut inner = self.0.write().await; for (i, (k, j)) in inner.players.drain() { k.send(ActorCommand::Stop).await; drop(k); @@ -252,6 +257,8 @@ impl PlayerRegistry { /// The player registry state representation. struct PlayerRegistryInner { room_registry: RoomRegistry, + storage: Storage, + /// Active player actors. players: HashMap)>, metric_active_players: IntGauge, } @@ -259,32 +266,49 @@ struct PlayerRegistryInner { /// Player actor inner state representation. struct Player { player_id: PlayerId, + storage_id: u32, connections: AnonTable>, my_rooms: HashMap, banned_from: HashSet, rx: Receiver, handle: PlayerHandle, rooms: RoomRegistry, + storage: Storage, } impl Player { - fn launch(player_id: PlayerId, rooms: RoomRegistry) -> (PlayerHandle, JoinHandle) { + async fn launch(player_id: PlayerId, rooms: RoomRegistry, storage: Storage) -> (PlayerHandle, JoinHandle) { let (tx, rx) = channel(32); let handle = PlayerHandle { tx }; let handle_clone = handle.clone(); + let storage_id = storage.retrieve_user_id_by_name(player_id.as_inner()).await.unwrap().unwrap(); let player = Player { player_id, + storage_id, + // connections are empty when the actor is just started connections: AnonTable::new(), + // room handlers will be loaded later in the started task my_rooms: HashMap::new(), - banned_from: HashSet::from([RoomId::from("Empty").unwrap()]), + // TODO implement and load bans + banned_from: HashSet::new(), rx, handle, rooms, + storage, }; let fiber = tokio::task::spawn(player.main_loop()); (handle_clone, fiber) } async fn main_loop(mut self) -> Self { + let rooms = self.storage.get_rooms_of_a_user(self.storage_id).await.unwrap(); + for room_id in rooms { + let room = self.rooms.get_room(&room_id).await; + if let Some(room) = room { + self.my_rooms.insert(room_id, room); + } else { + tracing::error!("Room #{room_id:?} not found"); + } + } while let Some(cmd) = self.rx.recv().await { match cmd { ActorCommand::AddConnection { sender, promise } => { @@ -372,7 +396,8 @@ impl Player { todo!(); } }; - room.subscribe(self.player_id.clone(), self.handle.clone()).await; + room.add_member(&self.player_id, self.storage_id).await; + room.subscribe(&self.player_id, self.handle.clone()).await; self.my_rooms.insert(room_id.clone(), room.clone()); let room_info = room.get_room_info().await; let update = Updates::RoomJoined { @@ -387,6 +412,7 @@ impl Player { let room = self.my_rooms.remove(&room_id); if let Some(room) = room { room.unsubscribe(&self.player_id).await; + room.remove_member(&self.player_id, self.storage_id).await; } let update = Updates::RoomLeft { room_id, @@ -396,12 +422,11 @@ impl Player { } async fn send_message(&mut self, connection_id: ConnectionId, room_id: RoomId, body: Str) { - let room = self.rooms.get_room(&room_id).await; - if let Some(room) = room { - room.send_message(self.player_id.clone(), body.clone()).await; - } else { + let Some(room) = self.my_rooms.get(&room_id) else { tracing::info!("no room found"); - } + return; + }; + room.send_message(&self.player_id, body.clone()).await; let update = Updates::NewMessage { room_id, author_id: self.player_id.clone(), @@ -411,12 +436,11 @@ impl Player { } async fn change_topic(&mut self, connection_id: ConnectionId, room_id: RoomId, new_topic: Str) { - let room = self.rooms.get_room(&room_id).await; - if let Some(mut room) = room { - room.set_topic(self.player_id.clone(), new_topic.clone()).await; - } else { + let Some(room) = self.my_rooms.get(&room_id) else { tracing::info!("no room found"); - } + return; + }; + room.set_topic(&self.player_id, new_topic.clone()).await; let update = Updates::RoomTopicChanged { room_id, new_topic }; self.broadcast_update(update, connection_id).await; } diff --git a/crates/lavina-core/src/repo/mod.rs b/crates/lavina-core/src/repo/mod.rs index d8af5f7..e8e3854 100644 --- a/crates/lavina-core/src/repo/mod.rs +++ b/crates/lavina-core/src/repo/mod.rs @@ -12,6 +12,7 @@ use tokio::sync::Mutex; use crate::prelude::*; mod room; +mod user; #[derive(Deserialize, Debug, Clone)] pub struct StorageConfig { @@ -50,7 +51,7 @@ impl Storage { Ok(res) } - pub async fn retrieve_room_by_name(&mut self, name: &str) -> Result> { + pub async fn retrieve_room_by_name(&self, name: &str) -> Result> { let mut executor = self.conn.lock().await; let res = sqlx::query_as( "select id, name, topic, message_count diff --git a/crates/lavina-core/src/repo/room.rs b/crates/lavina-core/src/repo/room.rs index b3ffc0f..96b89f2 100644 --- a/crates/lavina-core/src/repo/room.rs +++ b/crates/lavina-core/src/repo/room.rs @@ -3,6 +3,34 @@ use anyhow::Result; use crate::repo::Storage; impl Storage { + pub async fn add_room_member(&self, room_id: u32, player_id: u32) -> Result<()> { + let mut executor = self.conn.lock().await; + sqlx::query( + "insert into memberships(user_id, room_id, status) + values (?, ?, 1);", + ) + .bind(player_id) + .bind(room_id) + .execute(&mut *executor) + .await?; + + Ok(()) + } + + pub async fn remove_room_member(&self, room_id: u32, player_id: u32) -> Result<()> { + let mut executor = self.conn.lock().await; + sqlx::query( + "delete from memberships + where user_id = ? and room_id = ?;", + ) + .bind(player_id) + .bind(room_id) + .execute(&mut *executor) + .await?; + + Ok(()) + } + pub async fn set_room_topic(&mut self, id: u32, topic: &str) -> Result<()> { let mut executor = self.conn.lock().await; sqlx::query( diff --git a/crates/lavina-core/src/repo/user.rs b/crates/lavina-core/src/repo/user.rs new file mode 100644 index 0000000..d836b8f --- /dev/null +++ b/crates/lavina-core/src/repo/user.rs @@ -0,0 +1,30 @@ +use anyhow::Result; + +use crate::repo::Storage; +use crate::room::RoomId; + +impl Storage { + pub async fn retrieve_user_id_by_name(&self, name: &str) -> Result> { + let mut executor = self.conn.lock().await; + let res: Option<(u32,)> = sqlx::query_as("select u.id from users u where u.name = ?;") + .bind(name) + .fetch_optional(&mut *executor) + .await?; + + Ok(res.map(|(id,)| id)) + } + + pub async fn get_rooms_of_a_user(&self, user_id: u32) -> Result> { + let mut executor = self.conn.lock().await; + let res: Vec<(String,)> = sqlx::query_as( + "select r.name + from memberships m inner join rooms r on m.room_id = r.id + where m.user_id = ?;", + ) + .bind(user_id) + .fetch_all(&mut *executor) + .await?; + + res.into_iter().map(|(room_id,)| RoomId::from(room_id)).collect() + } +} diff --git a/crates/lavina-core/src/room.rs b/crates/lavina-core/src/room.rs index 2173dff..a5e2dab 100644 --- a/crates/lavina-core/src/room.rs +++ b/crates/lavina-core/src/room.rs @@ -1,4 +1,5 @@ //! Domain of rooms — chats with multiple participants. +use std::collections::HashSet; use std::{collections::HashMap, hash::Hash, sync::Arc}; use prometheus::{IntGauge, Registry as MetricRegistry}; @@ -48,27 +49,9 @@ impl RoomRegistry { pub async fn get_or_create_room(&mut self, room_id: RoomId) -> Result { let mut inner = self.0.write().await; - if let Some(room_handle) = inner.rooms.get(&room_id) { - // room was already loaded into memory - log::debug!("Room {} was loaded already", &room_id.0); + if let Some(room_handle) = inner.get_or_load_room(&room_id).await? { Ok(room_handle.clone()) - } else if let Some(stored_room) = inner.storage.retrieve_room_by_name(&*room_id.0).await? { - // room exists, but was not loaded - log::debug!("Loading room {}...", &room_id.0); - let room = Room { - storage_id: stored_room.id, - room_id: room_id.clone(), - subscriptions: HashMap::new(), // TODO figure out how to populate subscriptions - topic: stored_room.topic.into(), - message_count: stored_room.message_count, - storage: inner.storage.clone(), - }; - let room_handle = RoomHandle(Arc::new(AsyncRwLock::new(room))); - inner.rooms.insert(room_id, room_handle.clone()); - inner.metric_active_rooms.inc(); - Ok(room_handle) } else { - // room does not exist, create it and load log::debug!("Creating room {}...", &room_id.0); let topic = "New room"; let id = inner.storage.create_new_room(&*room_id.0, &*topic).await?; @@ -76,6 +59,7 @@ impl RoomRegistry { storage_id: id, room_id: room_id.clone(), subscriptions: HashMap::new(), + members: HashSet::new(), topic: topic.into(), message_count: 0, storage: inner.storage.clone(), @@ -88,9 +72,8 @@ impl RoomRegistry { } pub async fn get_room(&self, room_id: &RoomId) -> Option { - let inner = self.0.read().await; - let res = inner.rooms.get(room_id); - res.map(|r| r.clone()) + let mut inner = self.0.write().await; + inner.get_or_load_room(room_id).await.unwrap() } pub async fn get_all_rooms(&self) -> Vec { @@ -113,17 +96,66 @@ struct RoomRegistryInner { storage: Storage, } +impl RoomRegistryInner { + async fn get_or_load_room(&mut self, room_id: &RoomId) -> Result> { + if let Some(room_handle) = self.rooms.get(room_id) { + log::debug!("Room {} was loaded already", &room_id.0); + Ok(Some(room_handle.clone())) + } else if let Some(stored_room) = self.storage.retrieve_room_by_name(&*room_id.0).await? { + log::debug!("Loading room {}...", &room_id.0); + let room = Room { + storage_id: stored_room.id, + room_id: room_id.clone(), + subscriptions: HashMap::new(), + members: HashSet::new(), // TODO load members from storage + topic: stored_room.topic.into(), + message_count: stored_room.message_count, + storage: self.storage.clone(), + }; + let room_handle = RoomHandle(Arc::new(AsyncRwLock::new(room))); + self.rooms.insert(room_id.clone(), room_handle.clone()); + self.metric_active_rooms.inc(); + Ok(Some(room_handle)) + } else { + tracing::debug!("Room {} does not exist", &room_id.0); + Ok(None) + } + } +} + #[derive(Clone)] pub struct RoomHandle(Arc>); impl RoomHandle { - pub async fn subscribe(&self, player_id: PlayerId, player_handle: PlayerHandle) { + pub async fn subscribe(&self, player_id: &PlayerId, player_handle: PlayerHandle) { let mut lock = self.0.write().await; - lock.add_subscriber(player_id, player_handle).await; + tracing::info!("Adding a subscriber to a room"); + lock.subscriptions.insert(player_id.clone(), player_handle); + } + + pub async fn add_member(&self, player_id: &PlayerId, player_storage_id: u32) { + let mut lock = self.0.write().await; + tracing::info!("Adding a new member to a room"); + let room_storage_id = lock.storage_id; + lock.storage.add_room_member(room_storage_id, player_storage_id).await.unwrap(); + lock.members.insert(player_id.clone()); + let update = Updates::RoomJoined { + room_id: lock.room_id.clone(), + new_member_id: player_id.clone(), + }; + lock.broadcast_update(update, player_id).await; } pub async fn unsubscribe(&self, player_id: &PlayerId) { let mut lock = self.0.write().await; lock.subscriptions.remove(player_id); + } + + pub async fn remove_member(&self, player_id: &PlayerId, player_storage_id: u32) { + let mut lock = self.0.write().await; + tracing::info!("Removing a member from a room"); + let room_storage_id = lock.storage_id; + lock.storage.remove_room_member(room_storage_id, player_storage_id).await.unwrap(); + lock.members.remove(player_id); let update = Updates::RoomLeft { room_id: lock.room_id.clone(), former_member_id: player_id.clone(), @@ -131,7 +163,7 @@ impl RoomHandle { lock.broadcast_update(update, player_id).await; } - pub async fn send_message(&self, player_id: PlayerId, body: Str) { + pub async fn send_message(&self, player_id: &PlayerId, body: Str) { let mut lock = self.0.write().await; let res = lock.send_message(player_id, body).await; if let Err(err) = res { @@ -148,7 +180,7 @@ impl RoomHandle { } } - pub async fn set_topic(&mut self, changer_id: PlayerId, new_topic: Str) { + pub async fn set_topic(&self, changer_id: &PlayerId, new_topic: Str) { let mut lock = self.0.write().await; let storage_id = lock.storage_id; lock.topic = new_topic.clone(); @@ -157,7 +189,7 @@ impl RoomHandle { room_id: lock.room_id.clone(), new_topic: new_topic.clone(), }; - lock.broadcast_update(update, &changer_id).await; + lock.broadcast_update(update, changer_id).await; } } @@ -168,23 +200,15 @@ struct Room { room_id: RoomId, /// Player actors on the local node which are subscribed to this room's updates. subscriptions: HashMap, + /// Members of the room. + members: HashSet, /// The total number of messages. Used to calculate the id of the new message. message_count: u32, topic: Str, storage: Storage, } impl Room { - async fn add_subscriber(&mut self, player_id: PlayerId, player_handle: PlayerHandle) { - tracing::info!("Adding a subscriber to room"); - self.subscriptions.insert(player_id.clone(), player_handle); - let update = Updates::RoomJoined { - room_id: self.room_id.clone(), - new_member_id: player_id.clone(), - }; - self.broadcast_update(update, &player_id).await; - } - - async fn send_message(&mut self, author_id: PlayerId, body: Str) -> Result<()> { + async fn send_message(&mut self, author_id: &PlayerId, body: Str) -> Result<()> { tracing::info!("Adding a message to room"); self.storage.insert_message(self.storage_id, self.message_count, &body, &*author_id.as_inner()).await?; self.message_count += 1; @@ -193,7 +217,7 @@ impl Room { author_id: author_id.clone(), body, }; - self.broadcast_update(update, &author_id).await; + self.broadcast_update(update, author_id).await; Ok(()) } diff --git a/crates/projection-irc/src/lib.rs b/crates/projection-irc/src/lib.rs index ee47f9b..e52e92a 100644 --- a/crates/projection-irc/src/lib.rs +++ b/crates/projection-irc/src/lib.rs @@ -392,7 +392,7 @@ async fn handle_registered_socket<'a>( log::info!("Handling registered user: {user:?}"); let player_id = PlayerId::from(user.nickname.clone())?; - let mut connection = players.connect_to_player(player_id.clone()).await; + let mut connection = players.connect_to_player(&player_id).await; let text: Str = format!("Welcome to {} Server", &config.server_name).into(); ServerMessage { diff --git a/crates/projection-irc/tests/lib.rs b/crates/projection-irc/tests/lib.rs index 36b9040..3618467 100644 --- a/crates/projection-irc/tests/lib.rs +++ b/crates/projection-irc/tests/lib.rs @@ -111,7 +111,36 @@ impl TestServer { }) .await?; let rooms = RoomRegistry::new(&mut metrics, storage.clone()).unwrap(); - let players = PlayerRegistry::empty(rooms.clone(), &mut metrics).unwrap(); + let players = PlayerRegistry::empty(rooms.clone(), storage.clone(), &mut metrics).unwrap(); + let server = launch(config, players.clone(), rooms.clone(), metrics.clone(), storage.clone()).await.unwrap(); + Ok(TestServer { + metrics, + storage, + rooms, + players, + server, + }) + } + + async fn reboot(mut self) -> Result { + let config = ServerConfig { + listen_on: "127.0.0.1:0".parse().unwrap(), + server_name: "testserver".into(), + }; + let TestServer { + mut metrics, + mut storage, + rooms, + mut players, + server, + } = self; + server.terminate().await?; + players.shutdown_all().await.unwrap(); + drop(players); + drop(rooms); + let mut metrics = MetricsRegistry::new(); + let rooms = RoomRegistry::new(&mut metrics, storage.clone()).unwrap(); + let players = PlayerRegistry::empty(rooms.clone(), storage.clone(), &mut metrics).unwrap(); let server = launch(config, players.clone(), rooms.clone(), metrics.clone(), storage.clone()).await.unwrap(); Ok(TestServer { metrics, @@ -152,6 +181,76 @@ async fn scenario_basic() -> Result<()> { Ok(()) } +#[tokio::test] +async fn scenario_join_and_reboot() -> Result<()> { + let mut server = TestServer::start().await?; + + // test scenario + + server.storage.create_user("tester").await?; + server.storage.set_password("tester", "password").await?; + + let mut stream = TcpStream::connect(server.server.addr).await?; + let mut s = TestScope::new(&mut stream); + + // Open a connection and join a channel + + s.send("PASS password").await?; + s.send("NICK tester").await?; + s.send("USER UserName 0 * :Real Name").await?; + s.expect_server_introduction("tester").await?; + s.expect_nothing().await?; + s.send("JOIN #test").await?; + s.expect(":tester JOIN #test").await?; + s.expect(":testserver 332 tester #test :New room").await?; + s.expect(":testserver 353 tester = #test :tester").await?; + s.expect(":testserver 366 tester #test :End of /NAMES list").await?; + s.send("PRIVMSG #test :Hello").await?; + s.send("QUIT :Leaving").await?; + s.expect(":testserver ERROR :Leaving the server").await?; + s.expect_eof().await?; + stream.shutdown().await?; + + // Open a new connection and expect to be force-joined to the channel + + let mut stream = TcpStream::connect(server.server.addr).await?; + let mut s = TestScope::new(&mut stream); + + async fn test(s: &mut TestScope<'_>) -> Result<()> { + s.send("PASS password").await?; + s.send("NICK tester").await?; + s.send("USER UserName 0 * :Real Name").await?; + s.expect_server_introduction("tester").await?; + s.expect(":tester JOIN #test").await?; + s.expect(":testserver 332 tester #test :New room").await?; + s.expect(":testserver 353 tester = #test :tester").await?; + s.expect(":testserver 366 tester #test :End of /NAMES list").await?; + s.send("QUIT :Leaving").await?; + s.expect(":testserver ERROR :Leaving the server").await?; + s.expect_eof().await?; + Ok(()) + } + test(&mut s).await?; + stream.shutdown().await?; + + // Reboot the server + + let server = server.reboot().await?; + + // Open a new connection and expect to be force-joined to the channel + + let mut stream = TcpStream::connect(server.server.addr).await?; + let mut s = TestScope::new(&mut stream); + + test(&mut s).await?; + stream.shutdown().await?; + + // wrap up + + server.server.terminate().await?; + Ok(()) +} + #[tokio::test] async fn scenario_force_join_msg() -> Result<()> { let mut server = TestServer::start().await?; @@ -407,7 +506,6 @@ async fn scenario_cap_sasl_fail() -> Result<()> { #[tokio::test] async fn terminate_socket_scenario() -> Result<()> { let mut server = TestServer::start().await?; - let address: SocketAddr = ("127.0.0.1:0".parse().unwrap()); // test scenario diff --git a/crates/projection-xmpp/src/iq.rs b/crates/projection-xmpp/src/iq.rs index 01135b1..6766e19 100644 --- a/crates/projection-xmpp/src/iq.rs +++ b/crates/projection-xmpp/src/iq.rs @@ -25,7 +25,7 @@ impl<'a> XmppConnection<'a> { r#type: IqType::Result, body: BindResponse(Jid { name: Some(self.user.xmpp_name.clone()), - server: Server("localhost".into()), + server: Server(self.hostname.clone()), resource: Some(self.user.xmpp_resource.clone()), }), }; @@ -52,7 +52,7 @@ impl<'a> XmppConnection<'a> { req.serialize(output); } IqClientBody::DiscoInfo(info) => { - let response = disco_info(iq.to.as_deref(), &info); + let response = self.disco_info(iq.to.as_deref(), &info); let req = Iq { from: iq.to, id: iq.id, @@ -63,7 +63,7 @@ impl<'a> XmppConnection<'a> { req.serialize(output); } IqClientBody::DiscoItem(item) => { - let response = disco_items(iq.to.as_deref(), &item, self.rooms).await; + let response = self.disco_items(iq.to.as_deref(), &item, self.rooms).await; let req = Iq { from: iq.to, id: iq.id, @@ -87,78 +87,79 @@ impl<'a> XmppConnection<'a> { } } } -} -fn disco_info(to: Option<&str>, req: &InfoQuery) -> InfoQuery { - let identity; - let feature; - match to { - Some("localhost") => { - identity = vec![Identity { - category: "server".into(), - name: None, - r#type: "im".into(), - }]; - feature = vec![ - Feature::new("http://jabber.org/protocol/disco#info"), - Feature::new("http://jabber.org/protocol/disco#items"), - Feature::new("iq"), - Feature::new("presence"), - ] - } - Some("rooms.localhost") => { - identity = vec![Identity { - category: "conference".into(), - name: Some("Chat rooms".into()), - r#type: "text".into(), - }]; - feature = vec![ - Feature::new("http://jabber.org/protocol/disco#info"), - Feature::new("http://jabber.org/protocol/disco#items"), - Feature::new("http://jabber.org/protocol/muc"), - ] - } - _ => { - identity = vec![]; - feature = vec![]; - } - }; - InfoQuery { - node: None, - identity, - feature, - } -} + fn disco_info(&self, to: Option<&str>, req: &InfoQuery) -> InfoQuery { + let identity; + let feature; -async fn disco_items(to: Option<&str>, req: &ItemQuery, rooms: &RoomRegistry) -> ItemQuery { - let item = match to { - Some("localhost") => { - vec![Item { - jid: Jid { + match to { + Some(r) if r == &*self.hostname => { + identity = vec![Identity { + category: "server".into(), name: None, - server: Server("rooms.localhost".into()), - resource: None, - }, - name: None, - node: None, - }] + r#type: "im".into(), + }]; + feature = vec![ + Feature::new("http://jabber.org/protocol/disco#info"), + Feature::new("http://jabber.org/protocol/disco#items"), + Feature::new("iq"), + Feature::new("presence"), + ] + } + Some(r) if r == &*self.hostname_rooms => { + identity = vec![Identity { + category: "conference".into(), + name: Some("Chat rooms".into()), + r#type: "text".into(), + }]; + feature = vec![ + Feature::new("http://jabber.org/protocol/disco#info"), + Feature::new("http://jabber.org/protocol/disco#items"), + Feature::new("http://jabber.org/protocol/muc"), + ] + } + _ => { + identity = vec![]; + feature = vec![]; + } + }; + InfoQuery { + node: None, + identity, + feature, } - Some("rooms.localhost") => { - let room_list = rooms.get_all_rooms().await; - room_list - .into_iter() - .map(|room_info| Item { + } + + async fn disco_items(&self, to: Option<&str>, req: &ItemQuery, rooms: &RoomRegistry) -> ItemQuery { + let item = match to { + Some(r) if r == &*self.hostname => { + vec![Item { jid: Jid { - name: Some(Name(room_info.id.into_inner())), - server: Server("rooms.localhost".into()), + name: None, + server: Server(self.hostname_rooms.clone()), resource: None, }, name: None, node: None, - }) - .collect() - } - _ => vec![], - }; - ItemQuery { item } + }] + } + Some(r) if r == &*self.hostname_rooms => { + let room_list = rooms.get_all_rooms().await; + room_list + .into_iter() + .map(|room_info| Item { + jid: Jid { + name: Some(Name(room_info.id.into_inner())), + server: Server(self.hostname_rooms.clone()), + resource: None, + }, + name: None, + node: None, + }) + .collect() + } + _ => vec![], + }; + ItemQuery { item } + } } diff --git a/crates/projection-xmpp/src/lib.rs b/crates/projection-xmpp/src/lib.rs index c659e5b..30e0a3c 100644 --- a/crates/projection-xmpp/src/lib.rs +++ b/crates/projection-xmpp/src/lib.rs @@ -9,7 +9,6 @@ use std::net::SocketAddr; use std::path::PathBuf; use std::sync::Arc; -use anyhow::anyhow; use futures_util::future::join_all; use prometheus::Registry as MetricsRegistry; use quick_xml::events::{BytesDecl, Event}; @@ -44,6 +43,7 @@ pub struct ServerConfig { pub listen_on: SocketAddr, pub cert: PathBuf, pub key: PathBuf, + pub hostname: Str, } struct LoadedConfig { @@ -125,11 +125,12 @@ pub async fn launch( let players = players.clone(); let rooms = rooms.clone(); let storage = storage.clone(); + let hostname = config.hostname.clone(); let terminator = Terminator::spawn(|termination| { let stopped_tx = stopped_tx.clone(); let loaded_config = loaded_config.clone(); async move { - match handle_socket(loaded_config, stream, &socket_addr, players, rooms, storage, termination).await { + match handle_socket(loaded_config, stream, &socket_addr, players, rooms, storage, hostname, termination).await { Ok(_) => log::info!("Connection terminated"), Err(err) => log::warn!("Connection failed: {err}"), } @@ -164,12 +165,13 @@ pub async fn launch( } async fn handle_socket( - config: Arc, + cert_config: Arc, mut stream: TcpStream, socket_addr: &SocketAddr, mut players: PlayerRegistry, rooms: RoomRegistry, mut storage: Storage, + hostname: Str, termination: Deferred<()>, // TODO use it to stop the connection gracefully ) -> Result<()> { log::info!("Received an XMPP connection from {socket_addr}"); @@ -178,12 +180,12 @@ async fn handle_socket( let mut buf_reader = BufReader::new(reader); let mut buf_writer = BufWriter::new(writer); - socket_force_tls(&mut buf_reader, &mut buf_writer, &mut reader_buf).await?; + socket_force_tls(&mut buf_reader, &mut buf_writer, &mut reader_buf, &hostname).await?; let mut config = tokio_rustls::rustls::ServerConfig::builder() .with_safe_defaults() .with_no_client_auth() - .with_single_cert(vec![config.cert.clone()], config.key.clone())?; + .with_single_cert(vec![cert_config.cert.clone()], cert_config.key.clone())?; config.key_log = Arc::new(tokio_rustls::rustls::KeyLogFile::new()); log::debug!("Accepting TLS connection..."); @@ -202,10 +204,10 @@ async fn handle_socket( log::info!("Socket handling was terminated"); return Ok(()) }, - authenticated = socket_auth(&mut xml_reader, &mut xml_writer, &mut reader_buf, &mut storage) => { + authenticated = socket_auth(&mut xml_reader, &mut xml_writer, &mut reader_buf, &mut storage, &hostname) => { match authenticated { Ok(authenticated) => { - let mut connection = players.connect_to_player(authenticated.player_id.clone()).await; + let mut connection = players.connect_to_player(&authenticated.player_id).await; socket_final( &mut xml_reader, &mut xml_writer, @@ -213,6 +215,7 @@ async fn handle_socket( &authenticated, &mut connection, &rooms, + &hostname, ) .await?; }, @@ -233,16 +236,18 @@ async fn socket_force_tls( reader: &mut (impl AsyncBufRead + Unpin), writer: &mut (impl AsyncWrite + Unpin), reader_buf: &mut Vec, + hostname: &Str, ) -> Result<()> { use proto_xmpp::tls::*; let xml_reader = &mut NsReader::from_reader(reader); let xml_writer = &mut Writer::new(writer); + // TODO validate the server hostname received in the stream start let _ = ClientStreamStart::parse(xml_reader, reader_buf).await?; let event = Event::Decl(BytesDecl::new("1.0", None, None)); xml_writer.write_event_async(event).await?; let msg = ServerStreamStart { - from: "localhost".into(), + from: hostname.to_string(), lang: "en".into(), id: uuid::Uuid::new_v4().to_string(), version: "1.0".into(), @@ -267,12 +272,14 @@ async fn socket_auth( xml_writer: &mut Writer<(impl AsyncWrite + Unpin)>, reader_buf: &mut Vec, storage: &mut Storage, + hostname: &Str, ) -> Result { + // TODO validate the server hostname received in the stream start let _ = ClientStreamStart::parse(xml_reader, reader_buf).await?; xml_writer.write_event_async(Event::Decl(BytesDecl::new("1.0", None, None))).await?; ServerStreamStart { - from: "localhost".into(), + from: hostname.to_string(), lang: "en".into(), id: uuid::Uuid::new_v4().to_string(), version: "1.0".into(), @@ -335,12 +342,14 @@ async fn socket_final( authenticated: &Authenticated, user_handle: &mut PlayerConnection, rooms: &RoomRegistry, + hostname: &Str, ) -> Result<()> { + // TODO validate the server hostname received in the stream start let _ = ClientStreamStart::parse(xml_reader, reader_buf).await?; xml_writer.write_event_async(Event::Decl(BytesDecl::new("1.0", None, None))).await?; ServerStreamStart { - from: "localhost".into(), + from: hostname.to_string(), lang: "en".into(), id: uuid::Uuid::new_v4().to_string(), version: "1.0".into(), @@ -366,6 +375,8 @@ async fn socket_final( user: authenticated, user_handle, rooms, + hostname: hostname.clone(), + hostname_rooms: format!("rooms.{}", hostname).into(), }; let should_recreate_xml_future = select! { biased; @@ -422,6 +433,8 @@ struct XmppConnection<'a> { user: &'a Authenticated, user_handle: &'a mut PlayerConnection, rooms: &'a RoomRegistry, + hostname: Str, + hostname_rooms: Str, } impl<'a> XmppConnection<'a> { diff --git a/crates/projection-xmpp/src/message.rs b/crates/projection-xmpp/src/message.rs index 44aab05..a737b2b 100644 --- a/crates/projection-xmpp/src/message.rs +++ b/crates/projection-xmpp/src/message.rs @@ -18,17 +18,17 @@ impl<'a> XmppConnection<'a> { resource: _, }) = m.to { - if server.0.as_ref() == "rooms.localhost" && m.r#type == MessageType::Groupchat { + if server.0.as_ref() == &*self.hostname_rooms && m.r#type == MessageType::Groupchat { self.user_handle.send_message(RoomId::from(name.0.clone())?, m.body.clone().into()).await?; Message::<()> { to: Some(Jid { name: Some(self.user.xmpp_name.clone()), - server: Server("localhost".into()), + server: Server(self.hostname.clone()), resource: Some(self.user.xmpp_resource.clone()), }), from: Some(Jid { name: Some(name), - server: Server("rooms.localhost".into()), + server: Server(self.hostname_rooms.clone()), resource: Some(self.user.xmpp_muc_name.clone()), }), id: m.id, diff --git a/crates/projection-xmpp/src/presence.rs b/crates/projection-xmpp/src/presence.rs index eabf0fd..6f9540e 100644 --- a/crates/projection-xmpp/src/presence.rs +++ b/crates/projection-xmpp/src/presence.rs @@ -16,12 +16,12 @@ impl<'a> XmppConnection<'a> { Presence::<()> { to: Some(Jid { name: Some(self.user.xmpp_name.clone()), - server: Server("localhost".into()), + server: Server(self.hostname.clone()), resource: Some(self.user.xmpp_resource.clone()), }), from: Some(Jid { name: Some(self.user.xmpp_name.clone()), - server: Server("localhost".into()), + server: Server(self.hostname.clone()), resource: Some(self.user.xmpp_resource.clone()), }), ..Default::default() @@ -36,12 +36,12 @@ impl<'a> XmppConnection<'a> { Presence::<()> { to: Some(Jid { name: Some(self.user.xmpp_name.clone()), - server: Server("localhost".into()), + server: Server(self.hostname.clone()), resource: Some(self.user.xmpp_resource.clone()), }), from: Some(Jid { name: Some(name.clone()), - server: Server("rooms.localhost".into()), + server: Server(self.hostname_rooms.clone()), resource: Some(self.user.xmpp_muc_name.clone()), }), ..Default::default() diff --git a/crates/projection-xmpp/src/updates.rs b/crates/projection-xmpp/src/updates.rs index c211be8..0161b3f 100644 --- a/crates/projection-xmpp/src/updates.rs +++ b/crates/projection-xmpp/src/updates.rs @@ -21,12 +21,12 @@ impl<'a> XmppConnection<'a> { Message::<()> { to: Some(Jid { name: Some(self.user.xmpp_name.clone()), - server: Server("localhost".into()), + server: Server(self.hostname.clone()), resource: Some(self.user.xmpp_resource.clone()), }), from: Some(Jid { name: Some(Name(room_id.into_inner().into())), - server: Server("rooms.localhost".into()), + server: Server(self.hostname_rooms.clone()), resource: Some(Resource(author_id.into_inner().into())), }), id: None, diff --git a/crates/projection-xmpp/tests/lib.rs b/crates/projection-xmpp/tests/lib.rs index 7ce583c..29d0368 100644 --- a/crates/projection-xmpp/tests/lib.rs +++ b/crates/projection-xmpp/tests/lib.rs @@ -135,6 +135,7 @@ impl TestServer { listen_on: "127.0.0.1:0".parse().unwrap(), cert: "tests/certs/xmpp.pem".parse().unwrap(), key: "tests/certs/xmpp.key".parse().unwrap(), + hostname: "localhost".into(), }; let mut metrics = MetricsRegistry::new(); let mut storage = Storage::open(StorageConfig { @@ -142,7 +143,7 @@ impl TestServer { }) .await?; let rooms = RoomRegistry::new(&mut metrics, storage.clone()).unwrap(); - let players = PlayerRegistry::empty(rooms.clone(), &mut metrics).unwrap(); + let players = PlayerRegistry::empty(rooms.clone(), storage.clone(), &mut metrics).unwrap(); let server = launch(config, players.clone(), rooms.clone(), metrics.clone(), storage.clone()).await.unwrap(); Ok(TestServer { metrics, diff --git a/docs/running.md b/docs/running.md index 61f3067..ad422a1 100644 --- a/docs/running.md +++ b/docs/running.md @@ -19,6 +19,7 @@ server_name = "irc.localhost" listen_on = "127.0.0.1:5222" cert = "./certs/xmpp.pem" key = "./certs/xmpp.key" +hostname = "localhost" [storage] db_path = "db.sqlite" diff --git a/src/main.rs b/src/main.rs index 8111074..0d03a89 100644 --- a/src/main.rs +++ b/src/main.rs @@ -52,7 +52,7 @@ async fn main() -> Result<()> { let mut metrics = MetricsRegistry::new(); let storage = Storage::open(storage_config).await?; let rooms = RoomRegistry::new(&mut metrics, storage.clone())?; - let mut players = PlayerRegistry::empty(rooms.clone(), &mut metrics)?; + let mut players = PlayerRegistry::empty(rooms.clone(), storage.clone(), &mut metrics)?; let telemetry_terminator = http::launch(telemetry_config, metrics.clone(), rooms.clone(), storage.clone()).await?; let irc = projection_irc::launch( irc_config,