diff --git a/Cargo.lock b/Cargo.lock index e2d52e1..f3bbd73 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -1062,6 +1062,7 @@ version = "0.0.3-dev" dependencies = [ "anyhow", "assert_matches", + "chrono", "clap", "derive_more", "figment", diff --git a/Cargo.toml b/Cargo.toml index 093e62c..8e026da 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -64,6 +64,7 @@ opentelemetry-semantic-conventions = "0.14.0" opentelemetry_sdk = { version = "0.22.1", features = ["rt-tokio"] } opentelemetry-otlp = "0.15.0" tracing-opentelemetry = "0.23.0" +chrono.workspace = true [dev-dependencies] assert_matches.workspace = true diff --git a/crates/lavina-core/src/player.rs b/crates/lavina-core/src/player.rs index 8c89b86..21914c4 100644 --- a/crates/lavina-core/src/player.rs +++ b/crates/lavina-core/src/player.rs @@ -316,7 +316,7 @@ impl PlayerRegistry { } #[tracing::instrument(skip(self), name = "PlayerRegistry::get_or_launch_player")] - pub async fn get_or_launch_player(&mut self, id: &PlayerId) -> PlayerHandle { + pub async fn get_or_launch_player(&self, id: &PlayerId) -> PlayerHandle { let inner = self.0.read().await; if let Some((handle, _)) = inner.players.get(id) { handle.clone() @@ -341,7 +341,7 @@ impl PlayerRegistry { } #[tracing::instrument(skip(self), name = "PlayerRegistry::connect_to_player")] - pub async fn connect_to_player(&mut self, id: &PlayerId) -> PlayerConnection { + pub async fn connect_to_player(&self, id: &PlayerId) -> PlayerConnection { let player_handle = self.get_or_launch_player(id).await; player_handle.subscribe().await } diff --git a/crates/mgmt-api/src/lib.rs b/crates/mgmt-api/src/lib.rs index c21ff85..0ffbfdb 100644 --- a/crates/mgmt-api/src/lib.rs +++ b/crates/mgmt-api/src/lib.rs @@ -1,5 +1,7 @@ use serde::{Deserialize, Serialize}; +pub mod rooms; + #[derive(Serialize, Deserialize)] pub struct ErrorResponse<'a> { pub code: &'a str, diff --git a/crates/mgmt-api/src/rooms.rs b/crates/mgmt-api/src/rooms.rs new file mode 100644 index 0000000..c091467 --- /dev/null +++ b/crates/mgmt-api/src/rooms.rs @@ -0,0 +1,24 @@ +use serde::{Deserialize, Serialize}; + +#[derive(Serialize, Deserialize)] +pub struct SendMessageReq<'a> { + pub room_id: &'a str, + pub author_id: &'a str, + pub message: &'a str, +} + +#[derive(Serialize, Deserialize)] +pub struct SetTopicReq<'a> { + pub room_id: &'a str, + pub author_id: &'a str, + pub topic: &'a str, +} + +pub mod paths { + pub const SEND_MESSAGE: &'static str = "/mgmt/rooms/send_message"; + pub const SET_TOPIC: &'static str = "/mgmt/rooms/set_topic"; +} + +pub mod errors { + pub const ROOM_NOT_FOUND: &'static str = "room_not_found"; +} diff --git a/crates/projection-xmpp/src/iq.rs b/crates/projection-xmpp/src/iq.rs index 0031c36..fdff68d 100644 --- a/crates/projection-xmpp/src/iq.rs +++ b/crates/projection-xmpp/src/iq.rs @@ -4,16 +4,16 @@ use quick_xml::events::Event; use lavina_core::room::{RoomId, RoomRegistry}; use proto_xmpp::bind::{BindResponse, Jid, Name, Server}; -use proto_xmpp::client::{Iq, IqError, IqErrorType, IqType}; +use proto_xmpp::client::{Iq, IqError, IqErrorType, IqType, Message, MessageType}; use proto_xmpp::disco::{Feature, Identity, InfoQuery, Item, ItemQuery}; +use proto_xmpp::mam::{Fin, Set}; use proto_xmpp::roster::RosterQuery; use proto_xmpp::session::Session; +use proto_xmpp::xml::ToXml; use crate::proto::IqClientBody; use crate::XmppConnection; -use proto_xmpp::xml::ToXml; - impl<'a> XmppConnection<'a> { pub async fn handle_iq(&self, output: &mut Vec>, iq: Iq) { match iq.body { @@ -87,6 +87,18 @@ impl<'a> XmppConnection<'a> { }; req.serialize(output); } + IqClientBody::MessageArchiveRequest(_) => { + let response = Iq { + from: iq.to, + id: iq.id, + to: None, + r#type: IqType::Result, + body: Fin { + set: Set { count: Some(0) }, + }, + }; + response.serialize(output); + } _ => { let req = Iq { from: None, diff --git a/crates/projection-xmpp/src/proto.rs b/crates/projection-xmpp/src/proto.rs index 0b157a6..2943509 100644 --- a/crates/projection-xmpp/src/proto.rs +++ b/crates/projection-xmpp/src/proto.rs @@ -7,6 +7,7 @@ use lavina_core::prelude::*; use proto_xmpp::bind::BindRequest; use proto_xmpp::client::{Iq, Message, Presence}; use proto_xmpp::disco::{InfoQuery, ItemQuery}; +use proto_xmpp::mam::MessageArchiveRequest; use proto_xmpp::roster::RosterQuery; use proto_xmpp::session::Session; use proto_xmpp::xml::*; @@ -18,6 +19,7 @@ pub enum IqClientBody { Roster(RosterQuery), DiscoInfo(InfoQuery), DiscoItem(ItemQuery), + MessageArchiveRequest(MessageArchiveRequest), Unknown(Ignore), } @@ -38,6 +40,7 @@ impl FromXml for IqClientBody { RosterQuery, InfoQuery, ItemQuery, + MessageArchiveRequest, { delegate_parsing!(Ignore, namespace, event).into() } diff --git a/crates/projection-xmpp/tests/lib.rs b/crates/projection-xmpp/tests/lib.rs index dd537a1..c1ea13a 100644 --- a/crates/projection-xmpp/tests/lib.rs +++ b/crates/projection-xmpp/tests/lib.rs @@ -1,4 +1,5 @@ use std::io::ErrorKind; +use std::str::from_utf8; use std::sync::Arc; use std::time::Duration; @@ -6,6 +7,7 @@ use anyhow::Result; use assert_matches::*; use prometheus::Registry as MetricsRegistry; use quick_xml::events::Event; +use quick_xml::name::LocalName; use quick_xml::NsReader; use tokio::io::{AsyncBufReadExt, AsyncReadExt, AsyncWriteExt, BufReader}; use tokio::io::{ReadHalf as GenericReadHalf, WriteHalf as GenericWriteHalf}; @@ -22,6 +24,10 @@ use lavina_core::LavinaCore; use projection_xmpp::{launch, RunningServer, ServerConfig}; use proto_xmpp::xml::{Continuation, FromXml, Parser}; +fn element_name<'a>(local_name: &LocalName<'a>) -> &'a str { + from_utf8(local_name.into_inner()).unwrap() +} + pub async fn read_irc_message(reader: &mut BufReader>, buf: &mut Vec) -> Result { let mut size = 0; let res = reader.read_until(b'\n', buf).await?; @@ -55,19 +61,13 @@ impl<'a> TestScope<'a> { Ok(event) } - async fn read(&mut self) -> Result { - self.buffer.clear(); - let (ns, event) = self.reader.read_resolved_event_into_async(&mut self.buffer).await?; - let mut parser: Continuation<_, std::result::Result> = T::parse().consume(ns, &event); - loop { - match parser { - Continuation::Final(res) => return Ok(res?), - Continuation::Continue(next) => { - let (ns, event) = self.reader.read_resolved_event_into_async(&mut self.buffer).await?; - parser = next.consume(ns, &event); - } - } - } + async fn expect_starttls_required(&mut self) -> Result<()> { + assert_matches!(self.next_xml_event().await?, Event::Start(b) => assert_eq!(element_name(&b.local_name()), "features")); + assert_matches!(self.next_xml_event().await?, Event::Start(b) => assert_eq!(element_name(&b.local_name()), "starttls")); + assert_matches!(self.next_xml_event().await?, Event::Empty(b) => assert_eq!(element_name(&b.local_name()), "required")); + assert_matches!(self.next_xml_event().await?, Event::End(b) => assert_eq!(element_name(&b.local_name()), "starttls")); + assert_matches!(self.next_xml_event().await?, Event::End(b) => assert_eq!(element_name(&b.local_name()), "features")); + Ok(()) } } @@ -98,6 +98,24 @@ impl<'a> TestScopeTls<'a> { Ok(()) } + async fn expect_auth_mechanisms(&mut self) -> Result<()> { + assert_matches!(self.next_xml_event().await?, Event::Start(b) => assert_eq!(element_name(&b.local_name()), "features")); + assert_matches!(self.next_xml_event().await?, Event::Start(b) => assert_eq!(element_name(&b.local_name()), "mechanisms")); + assert_matches!(self.next_xml_event().await?, Event::Start(b) => assert_eq!(element_name(&b.local_name()), "mechanism")); + assert_matches!(self.next_xml_event().await?, Event::Text(b) => assert_eq!(&*b, b"PLAIN")); + assert_matches!(self.next_xml_event().await?, Event::End(b) => assert_eq!(element_name(&b.local_name()), "mechanism")); + assert_matches!(self.next_xml_event().await?, Event::End(b) => assert_eq!(element_name(&b.local_name()), "mechanisms")); + assert_matches!(self.next_xml_event().await?, Event::End(b) => assert_eq!(element_name(&b.local_name()), "features")); + Ok(()) + } + + async fn expect_bind_feature(&mut self) -> Result<()> { + assert_matches!(self.next_xml_event().await?, Event::Start(b) => assert_eq!(element_name(&b.local_name()), "features")); + assert_matches!(self.next_xml_event().await?, Event::Empty(b) => assert_eq!(element_name(&b.local_name()), "bind")); + assert_matches!(self.next_xml_event().await?, Event::End(b) => assert_eq!(element_name(&b.local_name()), "features")); + Ok(()) + } + async fn next_xml_event(&mut self) -> Result> { self.buffer.clear(); let event = self.reader.read_event_into_async(&mut self.buffer); @@ -107,6 +125,7 @@ impl<'a> TestScopeTls<'a> { } struct IgnoreCertVerification; + impl ServerCertVerifier for IgnoreCertVerification { fn verify_server_cert( &self, @@ -127,6 +146,7 @@ struct TestServer { core: LavinaCore, server: RunningServer, } + impl TestServer { async fn start() -> Result { let _ = tracing_subscriber::fmt::try_init(); @@ -175,14 +195,10 @@ async fn scenario_basic() -> Result<()> { s.send(r#""#).await?; s.send(r#""#).await?; assert_matches!(s.next_xml_event().await?, Event::Decl(_) => {}); - assert_matches!(s.next_xml_event().await?, Event::Start(b) => assert_eq!(b.local_name().into_inner(), b"stream")); - assert_matches!(s.next_xml_event().await?, Event::Start(b) => assert_eq!(b.local_name().into_inner(), b"features")); - assert_matches!(s.next_xml_event().await?, Event::Start(b) => assert_eq!(b.local_name().into_inner(), b"starttls")); - assert_matches!(s.next_xml_event().await?, Event::Empty(b) => assert_eq!(b.local_name().into_inner(), b"required")); - assert_matches!(s.next_xml_event().await?, Event::End(b) => assert_eq!(b.local_name().into_inner(), b"starttls")); - assert_matches!(s.next_xml_event().await?, Event::End(b) => assert_eq!(b.local_name().into_inner(), b"features")); + assert_matches!(s.next_xml_event().await?, Event::Start(b) => assert_eq!(element_name(&b.local_name()), "stream")); + s.expect_starttls_required().await?; s.send(r#""#).await?; - assert_matches!(s.next_xml_event().await?, Event::Empty(b) => assert_eq!(b.local_name().into_inner(), b"proceed")); + assert_matches!(s.next_xml_event().await?, Event::Empty(b) => assert_eq!(element_name(&b.local_name()), "proceed")); let buffer = s.buffer; tracing::info!("TLS feature negotiation complete"); @@ -201,35 +217,26 @@ async fn scenario_basic() -> Result<()> { s.send(r#""#).await?; s.send(r#""#).await?; assert_matches!(s.next_xml_event().await?, Event::Decl(_) => {}); - assert_matches!(s.next_xml_event().await?, Event::Start(b) => assert_eq!(b.local_name().into_inner(), b"stream")); + assert_matches!(s.next_xml_event().await?, Event::Start(b) => assert_eq!(element_name(&b.local_name()), "stream")); + s.expect_auth_mechanisms().await?; - assert_matches!(s.next_xml_event().await?, Event::Start(b) => assert_eq!(b.local_name().into_inner(), b"features")); - assert_matches!(s.next_xml_event().await?, Event::Start(b) => assert_eq!(b.local_name().into_inner(), b"mechanisms")); - assert_matches!(s.next_xml_event().await?, Event::Start(b) => assert_eq!(b.local_name().into_inner(), b"mechanism")); - assert_matches!(s.next_xml_event().await?, Event::Text(b) => assert_eq!(&*b, b"PLAIN")); - assert_matches!(s.next_xml_event().await?, Event::End(b) => assert_eq!(b.local_name().into_inner(), b"mechanism")); - assert_matches!(s.next_xml_event().await?, Event::End(b) => assert_eq!(b.local_name().into_inner(), b"mechanisms")); - assert_matches!(s.next_xml_event().await?, Event::End(b) => assert_eq!(b.local_name().into_inner(), b"features")); - - // base64-encoded b"\x00tester\x00password" + // base64-encoded "\x00tester\x00password" s.send(r#"AHRlc3RlcgBwYXNzd29yZA=="#) .await?; - assert_matches!(s.next_xml_event().await?, Event::Empty(b) => assert_eq!(b.local_name().into_inner(), b"success")); + assert_matches!(s.next_xml_event().await?, Event::Empty(b) => assert_eq!(element_name(&b.local_name()), "success")); s.send(r#""#).await?; s.send(r#""#).await?; assert_matches!(s.next_xml_event().await?, Event::Decl(_) => {}); - assert_matches!(s.next_xml_event().await?, Event::Start(b) => assert_eq!(b.local_name().into_inner(), b"stream")); - assert_matches!(s.next_xml_event().await?, Event::Start(b) => assert_eq!(b.local_name().into_inner(), b"features")); - assert_matches!(s.next_xml_event().await?, Event::Empty(b) => assert_eq!(b.local_name().into_inner(), b"bind")); - assert_matches!(s.next_xml_event().await?, Event::End(b) => assert_eq!(b.local_name().into_inner(), b"features")); + assert_matches!(s.next_xml_event().await?, Event::Start(b) => assert_eq!(element_name(&b.local_name()), "stream")); + s.expect_bind_feature().await?; s.send(r#"kek"#).await?; - assert_matches!(s.next_xml_event().await?, Event::Start(b) => assert_eq!(b.local_name().into_inner(), b"iq")); - assert_matches!(s.next_xml_event().await?, Event::Start(b) => assert_eq!(b.local_name().into_inner(), b"bind")); - assert_matches!(s.next_xml_event().await?, Event::Start(b) => assert_eq!(b.local_name().into_inner(), b"jid")); + assert_matches!(s.next_xml_event().await?, Event::Start(b) => assert_eq!(element_name(&b.local_name()), "iq")); + assert_matches!(s.next_xml_event().await?, Event::Start(b) => assert_eq!(element_name(&b.local_name()), "bind")); + assert_matches!(s.next_xml_event().await?, Event::Start(b) => assert_eq!(element_name(&b.local_name()), "jid")); assert_matches!(s.next_xml_event().await?, Event::Text(b) => assert_eq!(&*b, b"tester@localhost/tester")); - assert_matches!(s.next_xml_event().await?, Event::End(b) => assert_eq!(b.local_name().into_inner(), b"jid")); - assert_matches!(s.next_xml_event().await?, Event::End(b) => assert_eq!(b.local_name().into_inner(), b"bind")); - assert_matches!(s.next_xml_event().await?, Event::End(b) => assert_eq!(b.local_name().into_inner(), b"iq")); + assert_matches!(s.next_xml_event().await?, Event::End(b) => assert_eq!(element_name(&b.local_name()), "jid")); + assert_matches!(s.next_xml_event().await?, Event::End(b) => assert_eq!(element_name(&b.local_name()), "bind")); + assert_matches!(s.next_xml_event().await?, Event::End(b) => assert_eq!(element_name(&b.local_name()), "iq")); s.send(r#"Logged out"#).await?; stream.shutdown().await?; @@ -256,14 +263,10 @@ async fn scenario_wrong_password() -> Result<()> { s.send(r#""#).await?; s.send(r#""#).await?; assert_matches!(s.next_xml_event().await?, Event::Decl(_) => {}); - assert_matches!(s.next_xml_event().await?, Event::Start(b) => assert_eq!(b.local_name().into_inner(), b"stream")); - assert_matches!(s.next_xml_event().await?, Event::Start(b) => assert_eq!(b.local_name().into_inner(), b"features")); - assert_matches!(s.next_xml_event().await?, Event::Start(b) => assert_eq!(b.local_name().into_inner(), b"starttls")); - assert_matches!(s.next_xml_event().await?, Event::Empty(b) => assert_eq!(b.local_name().into_inner(), b"required")); - assert_matches!(s.next_xml_event().await?, Event::End(b) => assert_eq!(b.local_name().into_inner(), b"starttls")); - assert_matches!(s.next_xml_event().await?, Event::End(b) => assert_eq!(b.local_name().into_inner(), b"features")); + assert_matches!(s.next_xml_event().await?, Event::Start(b) => assert_eq!(element_name(&b.local_name()), "stream")); + s.expect_starttls_required().await?; s.send(r#""#).await?; - assert_matches!(s.next_xml_event().await?, Event::Empty(b) => assert_eq!(b.local_name().into_inner(), b"proceed")); + assert_matches!(s.next_xml_event().await?, Event::Empty(b) => assert_eq!(element_name(&b.local_name()), "proceed")); let buffer = s.buffer; tracing::info!("TLS feature negotiation complete"); @@ -282,22 +285,14 @@ async fn scenario_wrong_password() -> Result<()> { s.send(r#""#).await?; s.send(r#""#).await?; assert_matches!(s.next_xml_event().await?, Event::Decl(_) => {}); - assert_matches!(s.next_xml_event().await?, Event::Start(b) => assert_eq!(b.local_name().into_inner(), b"stream")); - - assert_matches!(s.next_xml_event().await?, Event::Start(b) => assert_eq!(b.local_name().into_inner(), b"features")); - assert_matches!(s.next_xml_event().await?, Event::Start(b) => assert_eq!(b.local_name().into_inner(), b"mechanisms")); - assert_matches!(s.next_xml_event().await?, Event::Start(b) => assert_eq!(b.local_name().into_inner(), b"mechanism")); - assert_matches!(s.next_xml_event().await?, Event::Text(b) => assert_eq!(&*b, b"PLAIN")); - assert_matches!(s.next_xml_event().await?, Event::End(b) => assert_eq!(b.local_name().into_inner(), b"mechanism")); - assert_matches!(s.next_xml_event().await?, Event::End(b) => assert_eq!(b.local_name().into_inner(), b"mechanisms")); - assert_matches!(s.next_xml_event().await?, Event::End(b) => assert_eq!(b.local_name().into_inner(), b"features")); - - // base64-encoded b"\x00tester\x00password2" + assert_matches!(s.next_xml_event().await?, Event::Start(b) => assert_eq!(element_name(&b.local_name()), "stream")); + s.expect_auth_mechanisms().await?; + // base64-encoded "\x00tester\x00password2" s.send(r#"AHRlc3RlcgBwYXNzd29yZDI="#) .await?; - assert_matches!(s.next_xml_event().await?, Event::Start(b) => assert_eq!(b.local_name().into_inner(), b"failure")); - assert_matches!(s.next_xml_event().await?, Event::Empty(b) => assert_eq!(b.local_name().into_inner(), b"not-authorized")); - assert_matches!(s.next_xml_event().await?, Event::End(b) => assert_eq!(b.local_name().into_inner(), b"failure")); + assert_matches!(s.next_xml_event().await?, Event::Start(b) => assert_eq!(element_name(&b.local_name()), "failure")); + assert_matches!(s.next_xml_event().await?, Event::Empty(b) => assert_eq!(element_name(&b.local_name()), "not-authorized")); + assert_matches!(s.next_xml_event().await?, Event::End(b) => assert_eq!(element_name(&b.local_name()), "failure")); let _ = stream.shutdown().await; @@ -322,14 +317,10 @@ async fn scenario_basic_without_headers() -> Result<()> { s.send(r#""#).await?; assert_matches!(s.next_xml_event().await?, Event::Decl(_) => {}); - assert_matches!(s.next_xml_event().await?, Event::Start(b) => assert_eq!(b.local_name().into_inner(), b"stream")); - assert_matches!(s.next_xml_event().await?, Event::Start(b) => assert_eq!(b.local_name().into_inner(), b"features")); - assert_matches!(s.next_xml_event().await?, Event::Start(b) => assert_eq!(b.local_name().into_inner(), b"starttls")); - assert_matches!(s.next_xml_event().await?, Event::Empty(b) => assert_eq!(b.local_name().into_inner(), b"required")); - assert_matches!(s.next_xml_event().await?, Event::End(b) => assert_eq!(b.local_name().into_inner(), b"starttls")); - assert_matches!(s.next_xml_event().await?, Event::End(b) => assert_eq!(b.local_name().into_inner(), b"features")); + assert_matches!(s.next_xml_event().await?, Event::Start(b) => assert_eq!(element_name(&b.local_name()), "stream")); + s.expect_starttls_required().await?; s.send(r#""#).await?; - assert_matches!(s.next_xml_event().await?, Event::Empty(b) => assert_eq!(b.local_name().into_inner(), b"proceed")); + assert_matches!(s.next_xml_event().await?, Event::Empty(b) => assert_eq!(element_name(&b.local_name()), "proceed")); let buffer = s.buffer; tracing::info!("TLS feature negotiation complete"); @@ -347,7 +338,7 @@ async fn scenario_basic_without_headers() -> Result<()> { s.send(r#""#).await?; assert_matches!(s.next_xml_event().await?, Event::Decl(_) => {}); - assert_matches!(s.next_xml_event().await?, Event::Start(b) => assert_eq!(b.local_name().into_inner(), b"stream")); + assert_matches!(s.next_xml_event().await?, Event::Start(b) => assert_eq!(element_name(&b.local_name()), "stream")); stream.shutdown().await?; @@ -374,14 +365,10 @@ async fn terminate_socket() -> Result<()> { s.send(r#""#).await?; assert_matches!(s.next_xml_event().await?, Event::Decl(_) => {}); - assert_matches!(s.next_xml_event().await?, Event::Start(b) => assert_eq!(b.local_name().into_inner(), b"stream")); - assert_matches!(s.next_xml_event().await?, Event::Start(b) => assert_eq!(b.local_name().into_inner(), b"features")); - assert_matches!(s.next_xml_event().await?, Event::Start(b) => assert_eq!(b.local_name().into_inner(), b"starttls")); - assert_matches!(s.next_xml_event().await?, Event::Empty(b) => assert_eq!(b.local_name().into_inner(), b"required")); - assert_matches!(s.next_xml_event().await?, Event::End(b) => assert_eq!(b.local_name().into_inner(), b"starttls")); - assert_matches!(s.next_xml_event().await?, Event::End(b) => assert_eq!(b.local_name().into_inner(), b"features")); + assert_matches!(s.next_xml_event().await?, Event::Start(b) => assert_eq!(element_name(&b.local_name()), "stream")); + s.expect_starttls_required().await?; s.send(r#""#).await?; - assert_matches!(s.next_xml_event().await?, Event::Empty(b) => assert_eq!(b.local_name().into_inner(), b"proceed")); + assert_matches!(s.next_xml_event().await?, Event::Empty(b) => assert_eq!(element_name(&b.local_name()), "proceed")); let connector = TlsConnector::from(Arc::new( ClientConfig::builder() @@ -400,3 +387,89 @@ async fn terminate_socket() -> Result<()> { Ok(()) } + +#[tokio::test] +async fn test_message_archive_request() -> Result<()> { + let mut server = TestServer::start().await?; + + // test scenario + + server.storage.create_user("tester").await?; + Authenticator::new(&server.storage).set_password("tester", "password").await?; + + let mut stream = TcpStream::connect(server.server.addr).await?; + let mut s = TestScope::new(&mut stream); + tracing::info!("TCP connection established"); + + s.send(r#""#).await?; + s.send(r#""#).await?; + assert_matches!(s.next_xml_event().await?, Event::Decl(_) => {}); + assert_matches!(s.next_xml_event().await?, Event::Start(b) => assert_eq!(element_name(&b.local_name()), "stream")); + s.expect_starttls_required().await?; + s.send(r#""#).await?; + assert_matches!(s.next_xml_event().await?, Event::Empty(b) => assert_eq!(element_name(&b.local_name()), "proceed")); + let buffer = s.buffer; + tracing::info!("TLS feature negotiation complete"); + + let connector = TlsConnector::from(Arc::new( + ClientConfig::builder() + .with_safe_defaults() + .with_custom_certificate_verifier(Arc::new(IgnoreCertVerification)) + .with_no_client_auth(), + )); + tracing::info!("Initiating TLS connection..."); + let mut stream = connector.connect(ServerName::IpAddress(server.server.addr.ip()), stream).await?; + tracing::info!("TLS connection established"); + + let mut s = TestScopeTls::new(&mut stream, buffer); + + s.send(r#""#).await?; + s.send(r#""#).await?; + assert_matches!(s.next_xml_event().await?, Event::Decl(_) => {}); + assert_matches!(s.next_xml_event().await?, Event::Start(b) => assert_eq!(element_name(&b.local_name()), "stream")); + s.expect_auth_mechanisms().await?; + + // base64-encoded "\x00tester\x00password" + s.send(r#"AHRlc3RlcgBwYXNzd29yZA=="#) + .await?; + assert_matches!(s.next_xml_event().await?, Event::Empty(b) => assert_eq!(element_name(&b.local_name()), "success")); + s.send(r#""#).await?; + s.send(r#""#).await?; + assert_matches!(s.next_xml_event().await?, Event::Decl(_) => {}); + assert_matches!(s.next_xml_event().await?, Event::Start(b) => assert_eq!(element_name(&b.local_name()), "stream")); + s.expect_bind_feature().await?; + s.send(r#"kek"#).await?; + assert_matches!(s.next_xml_event().await?, Event::Start(b) => assert_eq!(element_name(&b.local_name()), "iq")); + assert_matches!(s.next_xml_event().await?, Event::Start(b) => assert_eq!(element_name(&b.local_name()), "bind")); + assert_matches!(s.next_xml_event().await?, Event::Start(b) => assert_eq!(element_name(&b.local_name()), "jid")); + assert_matches!(s.next_xml_event().await?, Event::Text(b) => assert_eq!(&*b, b"tester@localhost/tester")); + assert_matches!(s.next_xml_event().await?, Event::End(b) => assert_eq!(element_name(&b.local_name()), "jid")); + assert_matches!(s.next_xml_event().await?, Event::End(b) => assert_eq!(element_name(&b.local_name()), "bind")); + assert_matches!(s.next_xml_event().await?, Event::End(b) => assert_eq!(element_name(&b.local_name()), "iq")); + + s.send(r#""#).await?; + assert_matches!(s.next_xml_event().await?, Event::Start(b) => { + assert_eq!(element_name(&b.local_name()), "iq") + }); + assert_matches!(s.next_xml_event().await?, Event::Start(b) => { + assert_eq!(element_name(&b.local_name()), "fin") + }); + assert_matches!(s.next_xml_event().await?, Event::Start(b) => { + assert_eq!(element_name(&b.local_name()), "set") + }); + assert_matches!(s.next_xml_event().await?, Event::Start(b) => { + assert_eq!(element_name(&b.local_name()), "count") + }); + assert_matches!(s.next_xml_event().await?, Event::Text(b) => { + assert_eq!(&*b, b"0") + }); + + s.send(r#"Logged out"#).await?; + + stream.shutdown().await?; + + // wrap up + + server.shutdown().await?; + Ok(()) +} diff --git a/crates/proto-xmpp/src/bind.rs b/crates/proto-xmpp/src/bind.rs index d27a00e..d546141 100644 --- a/crates/proto-xmpp/src/bind.rs +++ b/crates/proto-xmpp/src/bind.rs @@ -74,8 +74,8 @@ impl Jid { pub struct BindRequest(pub Resource); impl FromXmlTag for BindRequest { - const NS: &'static str = XMLNS; const NAME: &'static str = "bind"; + const NS: &'static str = XMLNS; } impl FromXml for BindRequest { diff --git a/crates/proto-xmpp/src/client.rs b/crates/proto-xmpp/src/client.rs index 05807bd..08d203a 100644 --- a/crates/proto-xmpp/src/client.rs +++ b/crates/proto-xmpp/src/client.rs @@ -658,7 +658,7 @@ mod tests { #[tokio::test] async fn parse_message() { - let input = r#"daabbb"#; + let input = r#"daabbb"#; let result: Message = crate::xml::parse(input).unwrap(); assert_eq!( result, @@ -666,8 +666,8 @@ mod tests { from: None, id: Some("aacea".to_string()), to: Some(Jid { - name: Some(Name("nikita".into())), - server: Server("vlnv.dev".into()), + name: Some(Name("chelik".into())), + server: Server("xmpp.ru".into()), resource: None }), r#type: MessageType::Chat, @@ -681,7 +681,7 @@ mod tests { #[tokio::test] async fn parse_message_empty_custom() { - let input = r#"daabbb"#; + let input = r#"daabbb"#; let result: Message = crate::xml::parse(input).unwrap(); assert_eq!( result, @@ -689,8 +689,8 @@ mod tests { from: None, id: Some("aacea".to_string()), to: Some(Jid { - name: Some(Name("nikita".into())), - server: Server("vlnv.dev".into()), + name: Some(Name("chelik".into())), + server: Server("xmpp.ru".into()), resource: None }), r#type: MessageType::Chat, diff --git a/crates/proto-xmpp/src/lib.rs b/crates/proto-xmpp/src/lib.rs index d3e25ba..71e8a94 100644 --- a/crates/proto-xmpp/src/lib.rs +++ b/crates/proto-xmpp/src/lib.rs @@ -3,6 +3,7 @@ pub mod bind; pub mod client; pub mod disco; +pub mod mam; pub mod muc; mod prelude; pub mod roster; diff --git a/crates/proto-xmpp/src/mam.rs b/crates/proto-xmpp/src/mam.rs new file mode 100644 index 0000000..c8151f2 --- /dev/null +++ b/crates/proto-xmpp/src/mam.rs @@ -0,0 +1,225 @@ +use anyhow::{anyhow, Result}; +use quick_xml::events::{BytesEnd, BytesStart, BytesText, Event}; +use quick_xml::name::{Namespace, ResolveResult}; +use std::io::Read; + +use crate::xml::*; + +pub const MAM_XMLNS: &'static str = "urn:xmpp:mam:2"; +pub const DATA_XMLNS: &'static str = "jabber:x:data"; +pub const RESULT_SET_XMLNS: &'static str = "http://jabber.org/protocol/rsm"; + +#[derive(PartialEq, Eq, Debug, Clone)] +pub struct MessageArchiveRequest { + pub x: Option, +} + +#[derive(PartialEq, Eq, Debug, Clone)] +pub struct X { + pub fields: Vec, +} + +#[derive(PartialEq, Eq, Debug, Clone)] +pub struct Field { + pub values: Vec, +} + +// Message archive response styled as a result set. +#[derive(PartialEq, Eq, Debug, Clone)] +pub struct Fin { + pub set: Set, +} + +#[derive(PartialEq, Eq, Debug, Clone)] +pub struct Set { + pub count: Option, +} + +impl ToXml for Fin { + fn serialize(&self, events: &mut Vec>) { + let fin_bytes = BytesStart::new(format!(r#"fin xmlns="{}" complete=True"#, MAM_XMLNS)); + let set_bytes = BytesStart::new(format!(r#"set xmlns="{}""#, RESULT_SET_XMLNS)); + events.push(Event::Start(fin_bytes)); + events.push(Event::Start(set_bytes)); + + if let &Some(count) = &self.set.count { + events.push(Event::Start(BytesStart::new("count"))); + events.push(Event::Text(BytesText::new(count.to_string().as_str()).into_owned())); + events.push(Event::End(BytesEnd::new("count"))); + } + events.push(Event::End(BytesEnd::new("set"))); + events.push(Event::End(BytesEnd::new("fin"))); + } +} + +impl FromXmlTag for X { + const NAME: &'static str = "x"; + const NS: &'static str = DATA_XMLNS; +} + +impl FromXmlTag for MessageArchiveRequest { + const NAME: &'static str = "query"; + const NS: &'static str = MAM_XMLNS; +} + +impl FromXml for X { + type P = impl Parser>; + + fn parse() -> Self::P { + |(mut namespace, mut event): (ResolveResult<'static>, &'static Event<'static>)| -> Result { + println!("X::parse {:?}", event); + + let bytes = match event { + Event::Start(bytes) if bytes.name().0 == X::NAME.as_bytes() => bytes, + Event::Empty(bytes) if bytes.name().0 == X::NAME.as_bytes() => return Ok(X { fields: vec![] }), + _ => return Err(anyhow!("Unexpected XML event: {event:?}")), + }; + let mut fields = vec![]; + loop { + (namespace, event) = yield; + match event { + Event::Start(_) => { + // start of + let mut values = vec![]; + loop { + (namespace, event) = yield; + match event { + Event::Start(bytes) if bytes.name().0 == b"value" => { + // start of + } + Event::End(bytes) if bytes.name().0 == b"field" => { + // end of + break; + } + _ => return Err(anyhow!("Unexpected XML event: {event:?}")), + } + (namespace, event) = yield; + let text: String = match event { + Event::Text(bytes) => { + // text inside + String::from_utf8(bytes.to_vec())? + } + _ => return Err(anyhow!("Unexpected XML event: {event:?}")), + }; + (namespace, event) = yield; + match event { + Event::End(bytes) if bytes.name().0 == b"value" => { + // end of + } + _ => return Err(anyhow!("Unexpected XML event: {event:?}")), + } + values.push(text); + } + fields.push(Field { values }) + } + Event::End(bytes) if bytes.name().0 == X::NAME.as_bytes() => { + // end of + return Ok(X { fields }); + } + _ => return Err(anyhow!("Unexpected XML event: {event:?}")), + } + } + } + } +} + +impl FromXml for MessageArchiveRequest { + type P = impl Parser>; + + fn parse() -> Self::P { + |(mut namespace, mut event): (ResolveResult<'static>, &'static Event<'static>)| -> Result { + println!("MessageArchiveRequest::parse {:?}", event); + + let bytes = match event { + Event::Empty(_) => return Ok(MessageArchiveRequest { x: None }), + Event::Start(bytes) => bytes, + _ => return Err(anyhow!("Unexpected XML event: {event:?}")), + }; + if bytes.name().0 != MessageArchiveRequest::NAME.as_bytes() { + return Err(anyhow!("Unexpected XML tag: {:?}", bytes.name())); + } + let ResolveResult::Bound(Namespace(ns)) = namespace else { + return Err(anyhow!("No namespace provided")); + }; + if ns != MAM_XMLNS.as_bytes() { + return Err(anyhow!("Incorrect namespace")); + } + (namespace, event) = yield; + match event { + Event::End(bytes) if bytes.name().0 == MessageArchiveRequest::NAME.as_bytes() => { + Ok(MessageArchiveRequest { x: None }) + } + Event::Start(bytes) | Event::Empty(bytes) if bytes.name().0 == X::NAME.as_bytes() => { + let x = delegate_parsing!(X, namespace, event)?; + Ok(MessageArchiveRequest { x: Some(x) }) + } + _ => Err(anyhow!("Unexpected XML event: {event:?}")), + } + } + } +} + +impl MessageArchiveRequest {} + +#[cfg(test)] +mod tests { + use super::*; + use crate::bind::{Jid, Name, Server}; + use crate::client::{Iq, IqType}; + + #[test] + fn test_parse_archive_query() { + let input = r#""#; + + let result: Iq = parse(input).unwrap(); + assert_eq!( + result, + Iq { + from: None, + id: "juliet1".to_string(), + to: Option::from(Jid { + name: None, + server: Server("pubsub.shakespeare.lit".into()), + resource: None, + }), + r#type: IqType::Set, + body: MessageArchiveRequest { x: None }, + } + ); + } + + #[test] + fn test_parse_query_messages_from_jid() { + let input = r#"value1juliet@capulet.lit"#; + + let result: Iq = parse(input).unwrap(); + assert_eq!( + result, + Iq { + from: None, + id: "juliet1".to_string(), + to: None, + r#type: IqType::Set, + body: MessageArchiveRequest { + x: Some(X { + fields: vec![ + Field { + values: vec!["value1".to_string()], + }, + Field { + values: vec!["juliet@capulet.lit".to_string()], + }, + ] + }) + }, + } + ); + } + + #[test] + fn test_parse_query_messages_from_jid_with_unclosed_tag() { + let input = r#"value1juliet@capulet.lit"#; + + assert!(parse::>(input).is_err()) + } +} diff --git a/crates/proto-xmpp/src/roster.rs b/crates/proto-xmpp/src/roster.rs index f7d0305..7eb97d1 100644 --- a/crates/proto-xmpp/src/roster.rs +++ b/crates/proto-xmpp/src/roster.rs @@ -2,7 +2,7 @@ use quick_xml::events::{BytesStart, Event}; use crate::xml::*; use anyhow::{anyhow, Result}; -use quick_xml::name::ResolveResult; +use quick_xml::name::{Namespace, ResolveResult}; pub const XMLNS: &'static str = "jabber:iq:roster"; @@ -14,6 +14,9 @@ impl FromXml for RosterQuery { fn parse() -> Self::P { |(mut namespace, mut event): (ResolveResult<'static>, &'static Event<'static>)| -> Result { + let ResolveResult::Bound(Namespace(ns)) = namespace else { + return Err(anyhow!("No namespace provided")); + }; match event { Event::Start(_) => (), Event::Empty(_) => return Ok(RosterQuery), @@ -38,3 +41,39 @@ impl ToXml for RosterQuery { events.push(Event::Empty(BytesStart::new(format!(r#"query xmlns="{}""#, XMLNS)))); } } + +#[cfg(test)] +mod tests { + use super::*; + use crate::bind::{Jid, Name, Resource, Server}; + use crate::client::{Iq, IqType}; + + #[test] + fn test_parse() { + let input = + r#""#; + + let result: Iq = parse(input).unwrap(); + assert_eq!( + result, + Iq { + from: Option::from(Jid { + name: Option::from(Name("juliet".into())), + server: Server("example.com".into()), + resource: Option::from(Resource("balcony".into())), + }), + id: "bv1bs71f".to_string(), + to: None, + r#type: IqType::Get, + body: RosterQuery, + } + ) + } + + #[test] + fn test_missing_namespace() { + let input = r#""#; + + assert!(parse::>(input).is_err()); + } +} diff --git a/crates/proto-xmpp/src/stream.rs b/crates/proto-xmpp/src/stream.rs index 8f46f31..b12891e 100644 --- a/crates/proto-xmpp/src/stream.rs +++ b/crates/proto-xmpp/src/stream.rs @@ -170,14 +170,14 @@ mod test { #[tokio::test] async fn client_stream_start_correct_parse() { - let input = r###""###; + let input = r###""###; let mut reader = NsReader::from_reader(input.as_bytes()); let mut buf = vec![]; let res = ClientStreamStart::parse(&mut reader, &mut buf).await.unwrap(); assert_eq!( res, ClientStreamStart { - to: "vlnv.dev".to_owned(), + to: "xmpp.ru".to_owned(), lang: Some("en".to_owned()), version: "1.0".to_owned() } @@ -187,12 +187,12 @@ mod test { #[tokio::test] async fn server_stream_start_write() { let input = ServerStreamStart { - from: "vlnv.dev".to_owned(), + from: "xmpp.ru".to_owned(), lang: "en".to_owned(), id: "stream_id".to_owned(), version: "1.0".to_owned(), }; - let expected = r###""###; + let expected = r###""###; let mut output: Vec = vec![]; let mut writer = Writer::new(&mut output); input.write_xml(&mut writer).await.unwrap(); diff --git a/src/http.rs b/src/http.rs index ae64676..f13ee7a 100644 --- a/src/http.rs +++ b/src/http.rs @@ -13,10 +13,10 @@ use serde::{Deserialize, Serialize}; use tokio::net::TcpListener; use lavina_core::auth::{Authenticator, UpdatePasswordResult}; -use lavina_core::player::{PlayerId, PlayerRegistry}; +use lavina_core::player::{PlayerId, PlayerRegistry, SendMessageResult}; use lavina_core::prelude::*; use lavina_core::repo::Storage; -use lavina_core::room::RoomRegistry; +use lavina_core::room::{RoomId, RoomRegistry}; use lavina_core::terminator::Terminator; use lavina_core::LavinaCore; @@ -88,6 +88,8 @@ async fn route( (&Method::POST, paths::CREATE_PLAYER) => endpoint_create_player(request, storage).await.or5xx(), (&Method::POST, paths::STOP_PLAYER) => endpoint_stop_player(request, core.players).await.or5xx(), (&Method::POST, paths::SET_PASSWORD) => endpoint_set_password(request, storage).await.or5xx(), + (&Method::POST, rooms::paths::SEND_MESSAGE) => endpoint_send_room_message(request, core).await.or5xx(), + (&Method::POST, rooms::paths::SET_TOPIC) => endpoint_set_room_topic(request, core).await.or5xx(), _ => endpoint_not_found(), }; Ok(res) @@ -139,9 +141,7 @@ async fn endpoint_stop_player( let Some(()) = players.stop_player(&player_id).await? else { return Ok(player_not_found()); }; - let mut response = Response::new(Full::::default()); - *response.status_mut() = StatusCode::NO_CONTENT; - Ok(response) + Ok(empty_204_request()) } #[tracing::instrument(skip_all)] @@ -160,9 +160,48 @@ async fn endpoint_set_password( return Ok(player_not_found()); } } - let mut response = Response::new(Full::::default()); - *response.status_mut() = StatusCode::NO_CONTENT; - Ok(response) + Ok(empty_204_request()) +} + +async fn endpoint_send_room_message( + request: Request, + mut core: LavinaCore, +) -> Result>> { + let str = request.collect().await?.to_bytes(); + let Ok(req) = serde_json::from_slice::(&str[..]) else { + return Ok(malformed_request()); + }; + let Ok(room_id) = RoomId::from(req.room_id) else { + return Ok(room_not_found()); + }; + let Ok(player_id) = PlayerId::from(req.author_id) else { + return Ok(player_not_found()); + }; + let mut player = core.players.connect_to_player(&player_id).await; + let res = player.send_message(room_id, req.message.into()).await?; + match res { + SendMessageResult::NoSuchRoom => Ok(room_not_found()), + SendMessageResult::Success(_) => Ok(empty_204_request()), + } +} + +async fn endpoint_set_room_topic( + request: Request, + core: LavinaCore, +) -> Result>> { + let str = request.collect().await?.to_bytes(); + let Ok(req) = serde_json::from_slice::(&str[..]) else { + return Ok(malformed_request()); + }; + let Ok(room_id) = RoomId::from(req.room_id) else { + return Ok(room_not_found()); + }; + let Ok(player_id) = PlayerId::from(req.author_id) else { + return Ok(player_not_found()); + }; + let mut player = core.players.connect_to_player(&player_id).await; + player.change_topic(room_id, req.topic.into()).await?; + Ok(empty_204_request()) } fn endpoint_not_found() -> Response> { @@ -188,6 +227,17 @@ fn player_not_found() -> Response> { response } +fn room_not_found() -> Response> { + let payload = ErrorResponse { + code: rooms::errors::ROOM_NOT_FOUND, + message: "No such room exists", + } + .to_body(); + let mut response = Response::new(payload); + *response.status_mut() = StatusCode::UNPROCESSABLE_ENTITY; + response +} + fn malformed_request() -> Response> { let payload = ErrorResponse { code: errors::MALFORMED_REQUEST, @@ -200,6 +250,12 @@ fn malformed_request() -> Response> { return response; } +fn empty_204_request() -> Response> { + let mut response = Response::new(Full::::default()); + *response.status_mut() = StatusCode::NO_CONTENT; + response +} + trait Or5xx { fn or5xx(self) -> Response>; }