diff --git a/crates/lavina-core/src/player.rs b/crates/lavina-core/src/player.rs index 14a50ca..a5db4b7 100644 --- a/crates/lavina-core/src/player.rs +++ b/crates/lavina-core/src/player.rs @@ -112,9 +112,13 @@ impl PlayerConnection { } #[tracing::instrument(skip(self), name = "PlayerConnection::get_room_message_history")] - pub async fn get_room_message_history(&self, room_id: RoomId) -> Result> { + pub async fn get_room_message_history(&self, room_id: &RoomId, limit: u32) -> Result> { let (promise, deferred) = oneshot(); - let cmd = ClientCommand::GetRoomHistory { room_id, promise }; + let cmd = ClientCommand::GetRoomHistory { + room_id: room_id.clone(), + promise, + limit, + }; self.player_handle.send(ActorCommand::ClientCommand(cmd, self.connection_id.clone())).await; Ok(deferred.await?) } @@ -223,6 +227,7 @@ pub enum ClientCommand { GetRoomHistory { room_id: RoomId, promise: Promise>, + limit: u32, }, } @@ -522,8 +527,12 @@ impl Player { let result = self.check_user_existence(recipient).await; let _ = promise.send(result); } - ClientCommand::GetRoomHistory { room_id, promise } => { - let result = self.get_room_history(room_id).await; + ClientCommand::GetRoomHistory { + room_id, + promise, + limit, + } => { + let result = self.get_room_history(room_id, limit).await; let _ = promise.send(result); } } @@ -575,11 +584,11 @@ impl Player { } #[tracing::instrument(skip(self), name = "Player::retrieve_room_history")] - async fn get_room_history(&mut self, room_id: RoomId) -> Vec { + async fn get_room_history(&mut self, room_id: RoomId, limit: u32) -> Vec { let room = self.my_rooms.get(&room_id); if let Some(room) = room { match room { - RoomRef::Local(room) => room.get_message_history(&self.services).await, + RoomRef::Local(room) => room.get_message_history(&self.services, limit).await, RoomRef::Remote { node_id } => { todo!() } diff --git a/crates/lavina-core/src/repo/room.rs b/crates/lavina-core/src/repo/room.rs index bca612b..08ccc42 100644 --- a/crates/lavina-core/src/repo/room.rs +++ b/crates/lavina-core/src/repo/room.rs @@ -30,27 +30,35 @@ impl Storage { } #[tracing::instrument(skip(self), name = "Storage::retrieve_room_message_history")] - pub async fn get_room_message_history(&self, room_id: u32) -> Result> { + pub async fn get_room_message_history(&self, room_id: u32, limit: u32) -> Result> { let mut executor = self.conn.lock().await; let res = sqlx::query_as( " select - messages.id as id, - content, - created_at, - users.name as author_name - from - messages - join - users - on messages.author_id = users.id - where - room_id = ? + * + from ( + select + messages.id as id, + content, + created_at, + users.name as author_name + from + messages + join + users + on messages.author_id = users.id + where + room_id = ? + order by + messages.id desc + limit ? + ) order by - messages.id; + id asc; ", ) .bind(room_id) + .bind(limit) .fetch_all(&mut *executor) .await?; diff --git a/crates/lavina-core/src/room.rs b/crates/lavina-core/src/room.rs index 1b1ce30..3a4df49 100644 --- a/crates/lavina-core/src/room.rs +++ b/crates/lavina-core/src/room.rs @@ -160,8 +160,8 @@ impl RoomHandle { lock.broadcast_update(update, player_id).await; } - pub async fn get_message_history(&self, services: &LavinaCore) -> Vec { - return services.storage.get_room_message_history(self.0.read().await.storage_id).await.unwrap(); + pub async fn get_message_history(&self, services: &Services, limit: u32) -> Vec { + return services.storage.get_room_message_history(self.0.read().await.storage_id, limit).await.unwrap(); } #[tracing::instrument(skip(self), name = "RoomHandle::unsubscribe")] diff --git a/crates/projection-irc/src/cap.rs b/crates/projection-irc/src/cap.rs index 83f1e24..9d96cf2 100644 --- a/crates/projection-irc/src/cap.rs +++ b/crates/projection-irc/src/cap.rs @@ -6,5 +6,6 @@ bitflags! { const None = 0; const Sasl = 1 << 0; const ServerTime = 1 << 1; + const ChatHistory = 1 << 2; } } diff --git a/crates/projection-irc/src/lib.rs b/crates/projection-irc/src/lib.rs index 4bceaf5..46ba1a5 100644 --- a/crates/projection-irc/src/lib.rs +++ b/crates/projection-irc/src/lib.rs @@ -27,8 +27,11 @@ use proto_irc::server::{AwayStatus, ServerMessage, ServerMessageBody}; use proto_irc::user::PrefixedNick; use proto_irc::{Chan, Recipient, Tag}; use sasl::AuthBody; + mod cap; + use handler::Handler; + mod whois; use crate::cap::Capabilities; @@ -140,7 +143,7 @@ impl RegistrationState { sender: Some(config.server_name.clone().into()), body: ServerMessageBody::Cap { target: self.future_nickname.clone().unwrap_or_else(|| "*".into()), - subcmd: CapSubBody::Ls("sasl=PLAIN server-time".into()), + subcmd: CapSubBody::Ls("sasl=PLAIN server-time draft/chathistory".into()), }, } .write_async(writer) @@ -167,6 +170,13 @@ impl RegistrationState { self.enabled_capabilities |= Capabilities::ServerTime; } acked.push(cap); + } else if &*cap.name == "draft/chathistory" { + if cap.to_disable { + self.enabled_capabilities &= !Capabilities::ChatHistory; + } else { + self.enabled_capabilities |= Capabilities::ChatHistory; + } + acked.push(cap); } else { naked.push(cap); } @@ -853,6 +863,46 @@ async fn handle_incoming_message( log::info!("Received QUIT"); return Ok(HandleResult::Leave); } + ClientMessage::ChatHistory { chan, limit } => { + if user.enabled_capabilities.contains(Capabilities::ChatHistory) { + let channel_name = match chan.clone() { + Chan::Global(chan) => chan, + // TODO Respond with an error when a local channel is requested + Chan::Local(chan) => chan, + }; + let room = core.get_room(&RoomId::try_from(channel_name.clone())?).await; + // TODO Handle non-existent room + if let Some(room) = room { + let room_id = &RoomId::try_from(channel_name.clone())?; + let messages = user_handle.get_room_message_history(room_id, limit).await?; + for message in messages { + let mut tags = vec![]; + if user.enabled_capabilities.contains(Capabilities::ServerTime) { + let tag = Tag { + key: "time".into(), + value: Some(message.created_at.to_rfc3339_opts(SecondsFormat::Millis, true).into()), + }; + tags.push(tag); + } + ServerMessage { + tags, + sender: Some(message.author_name.into()), + body: ServerMessageBody::PrivateMessage { + target: Recipient::Chan(chan.clone()), + body: message.content.into(), + }, + } + .write_async(writer) + .await?; + } + writer.flush().await?; + } + } else { + log::warn!( + "Requested chat history for user {user:?} even though the capability was not negotiated" + ); + } + } cmd => { log::warn!("Not implemented handler for client command: {cmd:?}"); } diff --git a/crates/projection-irc/tests/lib.rs b/crates/projection-irc/tests/lib.rs index c94c25b..141dd46 100644 --- a/crates/projection-irc/tests/lib.rs +++ b/crates/projection-irc/tests/lib.rs @@ -95,7 +95,7 @@ impl<'a> TestScope<'a> { } async fn expect_cap_ls(&mut self) -> Result<()> { - self.expect(":testserver CAP * LS :sasl=PLAIN server-time").await?; + self.expect(":testserver CAP * LS :sasl=PLAIN server-time draft/chathistory").await?; Ok(()) } } @@ -104,6 +104,7 @@ struct TestServer { core: LavinaCore, server: RunningServer, } + impl TestServer { async fn start() -> Result { let _ = tracing_subscriber::fmt::try_init(); @@ -187,6 +188,61 @@ async fn scenario_basic() -> Result<()> { Ok(()) } +#[tokio::test] +async fn scenario_basic_with_chathistory() -> Result<()> { + let server = TestServer::start().await?; + + // test scenario + + server.core.create_player(&PlayerId::from("tester")?).await?; + server.core.set_password("tester", "password").await?; + + let mut stream = TcpStream::connect(server.server.addr).await?; + let mut s = TestScope::new(&mut stream); + + s.send("NICK tester").await?; + s.send("CAP REQ :draft/chathistory").await?; + s.send("USER UserName 0 * :Real Name").await?; + s.send("PASS password").await?; + s.send("USER UserName 0 * :Real Name").await?; + s.expect(":testserver CAP tester ACK :draft/chathistory").await?; + s.send("CAP END").await?; + + s.expect_server_introduction("tester").await?; + s.expect_nothing().await?; + + s.send("JOIN #test").await?; + s.expect(":tester JOIN #test").await?; + s.expect(":testserver 332 tester #test :New room").await?; + s.expect(":testserver 353 tester = #test :tester").await?; + s.expect(":testserver 366 tester #test :End of /NAMES list").await?; + + s.send("PRIVMSG #test :Message1").await?; + s.send("PRIVMSG #test :Message2").await?; + s.send("PRIVMSG #test :Message3").await?; + s.send("PRIVMSG #test :Message4").await?; + + s.send("CHATHISTORY LATEST #test * 1").await?; + s.expect(":tester PRIVMSG #test :Message4").await?; + + s.send("CHATHISTORY LATEST #test * 3").await?; + s.expect(":tester PRIVMSG #test :Message2").await?; + s.expect(":tester PRIVMSG #test :Message3").await?; + s.expect(":tester PRIVMSG #test :Message4").await?; + s.expect_nothing().await?; + + s.send("QUIT :Leaving").await?; + s.expect(":testserver ERROR :Leaving the server").await?; + s.expect_eof().await?; + + stream.shutdown().await?; + + // wrap up + + server.shutdown().await; + Ok(()) +} + #[tokio::test] async fn scenario_join_and_reboot() -> Result<()> { let server = TestServer::start().await?; diff --git a/crates/projection-xmpp/src/presence.rs b/crates/projection-xmpp/src/presence.rs index 9e8e2e0..c2a1cb3 100644 --- a/crates/projection-xmpp/src/presence.rs +++ b/crates/projection-xmpp/src/presence.rs @@ -179,7 +179,7 @@ impl<'a> XmppConnection<'a> { #[tracing::instrument(skip(self), name = "XmppConnection::retrieve_message_history")] async fn retrieve_message_history(&self, room_name: &Name) -> Result> { let room_id = RoomId::try_from(room_name.0.clone())?; - let history_messages = self.user_handle.get_room_message_history(room_id).await?; + let history_messages = self.user_handle.get_room_message_history(&room_id, 50).await?; let mut response = vec![]; for history_message in history_messages.into_iter() { diff --git a/crates/proto-irc/src/client.rs b/crates/proto-irc/src/client.rs index 5a62427..ab83822 100644 --- a/crates/proto-irc/src/client.rs +++ b/crates/proto-irc/src/client.rs @@ -65,6 +65,10 @@ pub enum ClientMessage { reason: Str, }, Authenticate(Str), + ChatHistory { + chan: Chan, + limit: u32, + }, } pub mod command_args { @@ -95,6 +99,7 @@ pub fn client_message(input: &str) -> Result { client_message_privmsg, client_message_quit, client_message_authenticate, + client_message_chathistory, )))(input); match res { Ok((_, e)) => Ok(e), @@ -134,6 +139,7 @@ fn client_message_nick(input: &str) -> IResult<&str, ClientMessage> { }, )) } + fn client_message_pass(input: &str) -> IResult<&str, ClientMessage> { let (input, _) = tag("PASS ")(input)?; let (input, r) = opt(tag(":"))(input)?; @@ -172,6 +178,7 @@ fn client_message_user(input: &str) -> IResult<&str, ClientMessage> { }, )) } + fn client_message_join(input: &str) -> IResult<&str, ClientMessage> { let (input, _) = tag("JOIN ")(input)?; let (input, chan) = chan(input)?; @@ -280,6 +287,22 @@ fn client_message_authenticate(input: &str) -> IResult<&str, ClientMessage> { Ok((input, ClientMessage::Authenticate(body.into()))) } +fn client_message_chathistory(input: &str) -> IResult<&str, ClientMessage> { + let (input, _) = tag("CHATHISTORY LATEST ")(input)?; + let (input, chan) = chan(input)?; + + let (input, _) = tag(" * ")(input)?; + let (input, limit) = limit(input)?; + + Ok((input, ClientMessage::ChatHistory { chan, limit })) +} + +fn limit(input: &str) -> IResult<&str, u32> { + let (input, limit) = receiver(input)?; + let limit = limit.parse().unwrap(); + Ok((input, limit)) +} + #[derive(Clone, Debug, PartialEq, Eq)] pub enum CapabilitySubcommand { /// CAP LS {code} @@ -383,6 +406,7 @@ mod test { let result = client_message(input); assert_matches!(result, Ok(result) => assert_eq!(expected, result)); } + #[test] fn test_client_message_pong() { let input = "PONG 1337"; @@ -391,6 +415,7 @@ mod test { let result = client_message(input); assert_matches!(result, Ok(result) => assert_eq!(expected, result)); } + #[test] fn test_client_message_nick() { let input = "NICK SomeNick"; @@ -401,6 +426,7 @@ mod test { let result = client_message(input); assert_matches!(result, Ok(result) => assert_eq!(expected, result)); } + #[test] fn test_client_message_whois() { let test_user = "WHOIS val"; @@ -461,6 +487,7 @@ mod test { assert_matches!(res_more_than_two_params, Ok(result) => assert_eq!(expected_more_than_two_params, result)); assert_matches!(res_none_none_params, Ok(result) => assert_eq!(expected_none_none_params, result)) } + #[test] fn test_client_message_user() { let input = "USER SomeNick 8 * :Real Name"; @@ -472,6 +499,7 @@ mod test { let result = client_message(input); assert_matches!(result, Ok(result) => assert_eq!(expected, result)); } + #[test] fn test_client_message_part() { let input = "PART #chan :Pokasiki !!!"; @@ -483,6 +511,7 @@ mod test { let result = client_message(input); assert_matches!(result, Ok(result) => assert_eq!(expected, result)); } + #[test] fn test_client_message_part_empty() { let input = "PART #chan"; @@ -494,6 +523,7 @@ mod test { let result = client_message(input); assert_matches!(result, Ok(result) => assert_eq!(expected, result)); } + #[test] fn test_client_cap_req() { let input = "CAP REQ :multi-prefix -sasl"; @@ -513,4 +543,16 @@ mod test { let result = client_message(input); assert_matches!(result, Ok(result) => assert_eq!(expected, result)); } + + #[test] + fn test_client_chat_history_latest() { + let input = "CHATHISTORY LATEST #chan * 10"; + let expected = ClientMessage::ChatHistory { + chan: Chan::Global("chan".into()), + limit: 10, + }; + + let result = client_message(input); + assert_matches!(result, Ok(result) => assert_eq!(expected, result)); + } }