diff --git a/src/projections/xmpp/mod.rs b/src/projections/xmpp/mod.rs index 6dde107..57c1222 100644 --- a/src/projections/xmpp/mod.rs +++ b/src/projections/xmpp/mod.rs @@ -148,10 +148,11 @@ async fn handle_socket( socket_force_tls(&mut buf_reader, &mut buf_writer, &mut reader_buf).await?; - let config = tokio_rustls::rustls::ServerConfig::builder() + let mut config = tokio_rustls::rustls::ServerConfig::builder() .with_safe_defaults() .with_no_client_auth() .with_single_cert(vec![config.cert.clone()], config.key.clone())?; + config.key_log = Arc::new(tokio_rustls::rustls::KeyLogFile::new()); let acceptor = TlsAcceptor::from(Arc::new(config)); let new_stream = acceptor.accept(stream).await?; @@ -164,14 +165,6 @@ async fn handle_socket( socket_auth(&mut xml_reader, &mut xml_writer, &mut reader_buf).await?; socket_final(&mut xml_reader, &mut xml_writer, &mut reader_buf).await?; - loop { - let event = xml_reader.read_event_into_async(&mut reader_buf).await?; - println!("EVENT: {event:?}"); - if event == Event::Eof { - break; - } - } - let a = xml_reader.into_inner().into_inner(); let b = xml_writer.into_inner(); a.unsplit(b).shutdown().await?; @@ -289,22 +282,29 @@ async fn socket_final( Continuation::Final(res) => { let res = res?; dbg!(&res); - handle_packet(&mut events, res); + let stop = handle_packet(&mut events, res); for i in &events { xml_writer.write_event_async(i).await?; } events.clear(); xml_writer.get_mut().flush().await?; + if stop { + break; + } parser = proto::ClientPacket::parse(); } Continuation::Continue(p) => parser = p, } } + Ok(()) } -fn handle_packet(output: &mut Vec>, packet: ClientPacket) { +fn handle_packet(output: &mut Vec>, packet: ClientPacket) -> bool { match packet { - proto::ClientPacket::Iq(iq) => handle_iq(output, iq), + proto::ClientPacket::Iq(iq) => { + handle_iq(output, iq); + false + } proto::ClientPacket::Message(_) => todo!(), proto::ClientPacket::Presence(p) => { let response = Presence::<()> { @@ -313,6 +313,11 @@ fn handle_packet(output: &mut Vec>, packet: ClientPacket) { ..Default::default() }; response.serialize(output); + false + }, + proto::ClientPacket::StreamEnd => { + ServerStreamEnd.serialize(output); + true } } } diff --git a/src/projections/xmpp/proto.rs b/src/projections/xmpp/proto.rs index c3a4c79..aacf279 100644 --- a/src/projections/xmpp/proto.rs +++ b/src/projections/xmpp/proto.rs @@ -51,6 +51,7 @@ pub enum ClientPacket { Iq(Iq), Message(Message), Presence(Presence), + StreamEnd, } impl FromXml for ClientPacket { @@ -58,22 +59,34 @@ impl FromXml for ClientPacket { fn parse() -> Self::P { |(namespace, event): (ResolveResult<'static>, &'static Event<'static>)| -> Result { - let Event::Start(bytes) = event else { - return Err(ffail!("Unexpected XML event: {event:?}")); - }; - let name = bytes.name(); - match_parser!(name, namespace, event; - Iq::, - Presence::, - Message, - { - Err(ffail!( - "Unexpected XML event of name {:?} in namespace {:?}", - name, - namespace - )) + match event { + Event::Start(bytes) => { + let name = bytes.name(); + match_parser!(name, namespace, event; + Iq::, + Presence::, + Message, + { + Err(ffail!( + "Unexpected XML event of name {:?} in namespace {:?}", + name, + namespace + )) + } + ) } - ) + Event::End(bytes) => { + let name = bytes.name(); + if name.local_name().as_ref() == b"stream" { + return Ok(ClientPacket::StreamEnd); + } else { + return Err(ffail!("Unexpected XML event: {event:?}")); + } + } + _ => { + return Err(ffail!("Unexpected XML event: {event:?}")); + } + } } } } diff --git a/src/protos/xmpp/stream.rs b/src/protos/xmpp/stream.rs index a96bad2..1f133a3 100644 --- a/src/protos/xmpp/stream.rs +++ b/src/protos/xmpp/stream.rs @@ -6,6 +6,7 @@ use tokio::io::{AsyncBufRead, AsyncWrite}; use super::skip_text; use crate::prelude::*; +use crate::util::xml::ToXml; pub static XMLNS: &'static str = "http://etherx.jabber.org/streams"; pub static PREFIX: &'static str = "stream"; @@ -109,6 +110,13 @@ impl ServerStreamStart { } } +pub struct ServerStreamEnd; +impl ToXml for ServerStreamEnd { + fn serialize(&self, events: &mut Vec>) { + events.push(Event::End(BytesEnd::new("stream:stream"))); + } +} + pub struct Features { pub start_tls: bool, pub mechanisms: bool,