diff --git a/crates/projection-xmpp/src/lib.rs b/crates/projection-xmpp/src/lib.rs index 3b71b84..e3cb4e5 100644 --- a/crates/projection-xmpp/src/lib.rs +++ b/crates/projection-xmpp/src/lib.rs @@ -229,7 +229,6 @@ async fn socket_force_tls( use proto_xmpp::tls::*; let xml_reader = &mut NsReader::from_reader(reader); let xml_writer = &mut Writer::new(writer); - read_xml_header(xml_reader, reader_buf).await?; let _ = ClientStreamStart::parse(xml_reader, reader_buf).await?; let event = Event::Decl(BytesDecl::new("1.0", None, None)); @@ -261,7 +260,6 @@ async fn socket_auth( reader_buf: &mut Vec, storage: &mut Storage, ) -> Result { - read_xml_header(xml_reader, reader_buf).await?; let _ = ClientStreamStart::parse(xml_reader, reader_buf).await?; xml_writer.write_event_async(Event::Decl(BytesDecl::new("1.0", None, None))).await?; @@ -328,7 +326,6 @@ async fn socket_final( user_handle: &mut PlayerConnection, rooms: &RoomRegistry, ) -> Result<()> { - read_xml_header(xml_reader, reader_buf).await?; let _ = ClientStreamStart::parse(xml_reader, reader_buf).await?; xml_writer.write_event_async(Event::Decl(BytesDecl::new("1.0", None, None))).await?; @@ -420,7 +417,7 @@ struct XmppConnection<'a> { impl<'a> XmppConnection<'a> { async fn handle_packet(&mut self, output: &mut Vec>, packet: ClientPacket) -> Result { let res = match packet { - proto::ClientPacket::Iq(iq) => { + ClientPacket::Iq(iq) => { self.handle_iq(output, iq).await; false } @@ -428,11 +425,11 @@ impl<'a> XmppConnection<'a> { self.handle_message(output, m).await?; false } - proto::ClientPacket::Presence(p) => { + ClientPacket::Presence(p) => { self.handle_presence(output, p).await?; false } - proto::ClientPacket::StreamEnd => { + ClientPacket::StreamEnd => { ServerStreamEnd.serialize(output); true } @@ -440,25 +437,3 @@ impl<'a> XmppConnection<'a> { Ok(res) } } - -async fn read_xml_header( - xml_reader: &mut NsReader<(impl AsyncBufRead + Unpin)>, - reader_buf: &mut Vec, -) -> Result<()> { - if let Event::Decl(bytes) = xml_reader.read_event_into_async(reader_buf).await? { - // this is header - if let Some(encoding) = bytes.encoding() { - let encoding = encoding?; - if &*encoding == b"UTF-8" { - Ok(()) - } else { - Err(anyhow!("Unsupported encoding: {encoding:?}")) - } - } else { - // Err(fail("No XML encoding provided")) - Ok(()) - } - } else { - Err(anyhow!("Expected XML header")) - } -} diff --git a/crates/projection-xmpp/tests/lib.rs b/crates/projection-xmpp/tests/lib.rs index 0dae478..9dfae4c 100644 --- a/crates/projection-xmpp/tests/lib.rs +++ b/crates/projection-xmpp/tests/lib.rs @@ -187,6 +187,69 @@ async fn scenario_basic() -> Result<()> { Ok(()) } +#[tokio::test] +async fn scenario_basic_without_headers() -> Result<()> { + tracing_subscriber::fmt::try_init(); + let config = ServerConfig { + listen_on: "127.0.0.1:0".parse().unwrap(), + cert: "tests/certs/xmpp.pem".parse().unwrap(), + key: "tests/certs/xmpp.key".parse().unwrap(), + }; + let mut metrics = MetricsRegistry::new(); + let mut storage = Storage::open(StorageConfig { + db_path: ":memory:".into(), + }) + .await?; + let rooms = RoomRegistry::new(&mut metrics, storage.clone()).unwrap(); + let players = PlayerRegistry::empty(rooms.clone(), &mut metrics).unwrap(); + let server = launch(config, players, rooms, metrics, storage.clone()).await.unwrap(); + + // test scenario + + storage.create_user("tester").await?; + storage.set_password("tester", "password").await?; + + let mut stream = TcpStream::connect(server.addr).await?; + let mut s = TestScope::new(&mut stream); + tracing::info!("TCP connection established"); + + s.send(r#""#).await?; + assert_matches!(s.next_xml_event().await?, Event::Decl(_) => {}); + assert_matches!(s.next_xml_event().await?, Event::Start(b) => assert_eq!(b.local_name().into_inner(), b"stream")); + assert_matches!(s.next_xml_event().await?, Event::Start(b) => assert_eq!(b.local_name().into_inner(), b"features")); + assert_matches!(s.next_xml_event().await?, Event::Start(b) => assert_eq!(b.local_name().into_inner(), b"starttls")); + assert_matches!(s.next_xml_event().await?, Event::Empty(b) => assert_eq!(b.local_name().into_inner(), b"required")); + assert_matches!(s.next_xml_event().await?, Event::End(b) => assert_eq!(b.local_name().into_inner(), b"starttls")); + assert_matches!(s.next_xml_event().await?, Event::End(b) => assert_eq!(b.local_name().into_inner(), b"features")); + s.send(r#""#).await?; + assert_matches!(s.next_xml_event().await?, Event::Empty(b) => assert_eq!(b.local_name().into_inner(), b"proceed")); + let buffer = s.buffer; + tracing::info!("TLS feature negotiation complete"); + + let connector = TlsConnector::from(Arc::new( + ClientConfig::builder() + .with_safe_defaults() + .with_custom_certificate_verifier(Arc::new(IgnoreCertVerification)) + .with_no_client_auth(), + )); + tracing::info!("Initiating TLS connection..."); + let mut stream = connector.connect(ServerName::IpAddress(server.addr.ip()), stream).await?; + tracing::info!("TLS connection established"); + + let mut s = TestScopeTls::new(&mut stream, buffer); + + s.send(r#""#).await?; + assert_matches!(s.next_xml_event().await?, Event::Decl(_) => {}); + assert_matches!(s.next_xml_event().await?, Event::Start(b) => assert_eq!(b.local_name().into_inner(), b"stream")); + + stream.shutdown().await?; + + // wrap up + + server.terminate().await?; + Ok(()) +} + #[tokio::test] async fn terminate_socket() -> Result<()> { tracing_subscriber::fmt::try_init(); diff --git a/crates/proto-xmpp/src/stream.rs b/crates/proto-xmpp/src/stream.rs index d85df07..8f46f31 100644 --- a/crates/proto-xmpp/src/stream.rs +++ b/crates/proto-xmpp/src/stream.rs @@ -24,7 +24,17 @@ impl ClientStreamStart { reader: &mut NsReader, buf: &mut Vec, ) -> Result { - let incoming = skip_text!(reader, buf); + let mut incoming = skip_text!(reader, buf); + if let Event::Decl(bytes) = incoming { + // this is header + if let Some(encoding) = bytes.encoding() { + let encoding = encoding?; + if &*encoding != b"UTF-8" { + return Err(anyhow!("Unsupported encoding: {encoding:?}")); + } + } + incoming = skip_text!(reader, buf); + } if let Event::Start(e) = incoming { let (ns, local) = reader.resolve_element(e.name()); if ns != ResolveResult::Bound(Namespace(XMLNS.as_bytes())) {