From 42c22d045ff56a762d10431cfddd77cf38934847 Mon Sep 17 00:00:00 2001 From: Nikita Vilunov Date: Sun, 5 Mar 2023 22:04:28 +0100 Subject: [PATCH] feat(xmpp): implement socket start negotiation up to auth --- Cargo.lock | 74 ++++++++---------- Cargo.toml | 3 + src/prelude.rs | 4 +- src/projections/xmpp/mod.rs | 144 ++++++++++++++++++++++++++++-------- src/protos/xmpp/mod.rs | 19 +++++ src/protos/xmpp/sasl.rs | 78 +++++++++++++++++++ src/protos/xmpp/stream.rs | 117 ++++++++++++++++++++++++++--- 7 files changed, 348 insertions(+), 91 deletions(-) create mode 100644 src/protos/xmpp/sasl.rs diff --git a/Cargo.lock b/Cargo.lock index 8f58a48..c0f9d4d 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -238,9 +238,9 @@ dependencies = [ [[package]] name = "h2" -version = "0.3.15" +version = "0.3.16" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "5f9f29bc9dda355256b2916cf526ab02ce0aeaaaf2bad60d65ef3f12f11dd0f4" +checksum = "5be7b54589b581f624f566bf5d8eb2bab1db736c51528720b6bd36b96b55924d" dependencies = [ "bytes", "fnv", @@ -272,9 +272,9 @@ dependencies = [ [[package]] name = "http" -version = "0.2.8" +version = "0.2.9" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "75f43d41e26995c17e71ee126451dd3941010b0514a81a9d11f3b341debc2399" +checksum = "bd6effc99afb63425aff9b05836f029929e345a6148a14b7ecd5ab67af944482" dependencies = [ "bytes", "fnv", @@ -353,9 +353,9 @@ dependencies = [ [[package]] name = "hyper" -version = "1.0.0-rc.2" +version = "1.0.0-rc.3" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "289cfdbf735dea222b0ec6a10224b4d9552c7662bb451d4589cbfda3d407d1a3" +checksum = "7b75264b2003a3913f118d35c586e535293b3e22e41f074930762929d071e092" dependencies = [ "bytes", "futures-channel", @@ -406,9 +406,9 @@ checksum = "30e22bd8629359895450b59ea7a776c850561b96a3b1d31321c1949d9e6c9146" [[package]] name = "itoa" -version = "1.0.5" +version = "1.0.6" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "fad582f4b9e86b6caa621cabeb0963332d92eea04729ab12892c2533951e6440" +checksum = "453ad9f582a441959e5f0d088b02ce04cfe8d51a8eaf077f12ac6d3e94164ca6" [[package]] name = "js-sys" @@ -428,7 +428,7 @@ dependencies = [ "figment", "futures-util", "http-body-util", - "hyper 1.0.0-rc.2", + "hyper 1.0.0-rc.3", "nom", "prometheus", "quick-xml", @@ -502,7 +502,7 @@ dependencies = [ "libc", "log", "wasi", - "windows-sys 0.45.0", + "windows-sys", ] [[package]] @@ -537,9 +537,9 @@ dependencies = [ [[package]] name = "once_cell" -version = "1.17.0" +version = "1.17.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "6f61fba1741ea2b3d6a1e3178721804bb716a68a6aeba1149b5d52e3d464ea66" +checksum = "b7e5500299e16ebb147ae15a00a942af264cf3688f47923b8fc2cd5858f23ad3" [[package]] name = "overload" @@ -567,7 +567,7 @@ dependencies = [ "libc", "redox_syscall", "smallvec", - "windows-sys 0.45.0", + "windows-sys", ] [[package]] @@ -656,8 +656,7 @@ dependencies = [ [[package]] name = "quick-xml" version = "0.27.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "ffc053f057dd768a56f62cd7e434c42c831d296968997e9ac1f76ea7c2d14c41" +source = "git+https://github.com/vilunov/quick-xml.git?branch=async-writer#5adef14e6833e9089c4965b22a249bed43742bb2" dependencies = [ "memchr", "tokio", @@ -800,9 +799,9 @@ dependencies = [ [[package]] name = "ryu" -version = "1.0.12" +version = "1.0.13" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "7b4b9743ed687d4b4bcedf9ff5eaa7398495ae14e61cba0a295704edbc7decde" +checksum = "f91339c0467de62360649f8d3e185ca8de4224ff281f66000de5eb2a77a79041" [[package]] name = "scopeguard" @@ -842,9 +841,9 @@ dependencies = [ [[package]] name = "serde_json" -version = "1.0.93" +version = "1.0.94" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "cad406b69c91885b5107daf2c29572f6c8cdb3c66826821e286c533490c0bc76" +checksum = "1c533a59c9d8a93a09c6ab31f0fd5e5f4dd1b8fc9434804029839884765d04ea" dependencies = [ "itoa", "ryu", @@ -894,9 +893,9 @@ dependencies = [ [[package]] name = "slab" -version = "0.4.7" +version = "0.4.8" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "4614a76b2a8be0058caa9dbbaf66d988527d86d003c11a94fbd335d7661edcef" +checksum = "6528351c9bc8ab22353f9d776db39a20288e8d6c37ef8cfe3317cf875eecfc2d" dependencies = [ "autocfg", ] @@ -909,9 +908,9 @@ checksum = "a507befe795404456341dfab10cef66ead4c041f62b8b11bbb92bffe5d0953e0" [[package]] name = "socket2" -version = "0.4.7" +version = "0.4.9" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "02e2d2db9033d13a1567121ddd7a095ee144db4e1ca1b1bda3419bc0da294ebd" +checksum = "64a4a911eed85daf18834cfaa86a79b7d266ff93ff5ba14005426219480ed662" dependencies = [ "libc", "winapi", @@ -925,9 +924,9 @@ checksum = "6e63cff320ae2c57904679ba7cb63280a3dc4613885beafb148ee7bf9aa9042d" [[package]] name = "syn" -version = "1.0.107" +version = "1.0.109" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "1f4064b5b16e03ae50984a5a8ed5d4f8803e6bc1fd170a3cda91a1be4b18e3f5" +checksum = "72b64191b275b66ffe2469e8af2c1cfe3bafa67b529ead792a6d0160888b4237" dependencies = [ "proc-macro2", "quote", @@ -981,9 +980,9 @@ checksum = "1f3ccbac311fea05f86f61904b462b55fb3df8837a366dfc601a0161d0532f20" [[package]] name = "tokio" -version = "1.25.0" +version = "1.26.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "c8e00990ebabbe4c14c08aca901caed183ecd5c09562a12c824bb53d3c3fd3af" +checksum = "03201d01c3c27a29c8a5cee5b55a93ddae1ccf6f08f65365c2c918f8c1b76f64" dependencies = [ "autocfg", "bytes", @@ -996,7 +995,7 @@ dependencies = [ "signal-hook-registry", "socket2", "tokio-macros", - "windows-sys 0.42.0", + "windows-sys", ] [[package]] @@ -1168,9 +1167,9 @@ checksum = "d54675592c1dbefd78cbd98db9bacd89886e1ca50692a0692baefffdeb92dd58" [[package]] name = "unicode-ident" -version = "1.0.6" +version = "1.0.7" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "84a22b9f218b40614adcb3f4ff08b703773ad44fa9423e4e0d346d5db86e4ebc" +checksum = "775c11906edafc97bc378816b94585fbd9a054eabaf86fdd0ced94af449efab7" [[package]] name = "unicode-normalization" @@ -1340,21 +1339,6 @@ version = "0.4.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "712e227841d057c1ee1cd2fb22fa7e5a5461ae8e48fa2ca79ec42cfc1931183f" -[[package]] -name = "windows-sys" -version = "0.42.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "5a3e1820f08b8513f676f7ab6c1f99ff312fb97b553d30ff4dd86f9f15728aa7" -dependencies = [ - "windows_aarch64_gnullvm", - "windows_aarch64_msvc", - "windows_i686_gnu", - "windows_i686_msvc", - "windows_x86_64_gnu", - "windows_x86_64_gnullvm", - "windows_x86_64_msvc", -] - [[package]] name = "windows-sys" version = "0.45.0" diff --git a/Cargo.toml b/Cargo.toml index 9efe7dc..42e279f 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -26,3 +26,6 @@ quick-xml = { version = "0.27.1", features = ["async-tokio"] } assert_matches = "1.5.0" regex = "1.7.1" reqwest = { version = "0.11", default-features = false } + +[patch.crates-io] +quick-xml = { git = "https://github.com/vilunov/quick-xml.git", branch = "async-writer" } diff --git a/src/prelude.rs b/src/prelude.rs index 1345d13..a7b8ddb 100644 --- a/src/prelude.rs +++ b/src/prelude.rs @@ -13,6 +13,6 @@ pub type Result = std::result::Result; pub type ByteVec = Vec; -pub fn fail(msg: &'static str) -> anyhow::Error { - anyhow::Error::msg(msg) +pub fn fail(msg: &str) -> anyhow::Error { + anyhow::Error::msg(msg.to_owned()) } diff --git a/src/projections/xmpp/mod.rs b/src/projections/xmpp/mod.rs index fe2d624..e31f2a0 100644 --- a/src/projections/xmpp/mod.rs +++ b/src/projections/xmpp/mod.rs @@ -7,9 +7,11 @@ use std::sync::Arc; use futures_util::future::join_all; use prometheus::Registry as MetricsRegistry; +use quick_xml::events::{BytesDecl, Event}; +use quick_xml::{NsReader, Writer}; use rustls_pemfile::{certs, rsa_private_keys}; use serde::Deserialize; -use tokio::io::{AsyncReadExt, AsyncWriteExt}; +use tokio::io::{AsyncWriteExt, BufReader, BufWriter}; use tokio::net::{TcpListener, TcpStream}; use tokio::sync::mpsc::channel; use tokio_rustls::rustls::{Certificate, PrivateKey}; @@ -18,6 +20,7 @@ use tokio_rustls::TlsAcceptor; use crate::core::player::PlayerRegistry; use crate::core::room::RoomRegistry; use crate::prelude::*; +use crate::protos::xmpp; use crate::util::Terminator; #[derive(Deserialize, Debug, Clone)] @@ -126,26 +129,47 @@ async fn handle_socket( rooms: RoomRegistry, termination: Deferred<()>, // TODO use it to stop the connection gracefully ) -> Result<()> { + use xmpp::stream::*; log::debug!("Received an XMPP connection from {socket_addr}"); - // writer.write_all(b"Hi!\n").await?; - let mut buf = [0; 1024]; - stream.write_all(br###" - - - - - - - "###).await?; + let mut reader_buf = vec![]; + let (reader, writer) = stream.split(); + let mut buf_reader = BufReader::new(reader); + let mut buf_writer = BufWriter::new(writer); + { - let i = stream.read(&mut buf).await?; - match std::str::from_utf8(&buf[0..i]) { - Ok(e) => println!("{} END", e), - Err(_) => println!("{:?} END", &buf[0..i]), + let mut xml_reader = NsReader::from_reader(&mut buf_reader); + let mut xml_writer = Writer::new(&mut buf_writer); + let aaa = xml_reader.read_event_into_async(&mut reader_buf).await?; + if let Event::Decl(_) = aaa { + // this is header + } else { + return Err(fail("expected XML header")); } - stream - .write_all(br###""###) + let _ = ClientStreamStart::parse(&mut xml_reader, &mut reader_buf).await?; + + xml_writer + .write_event_async(Event::Decl(BytesDecl::new("1.0", None, None))) .await?; + xmpp::stream::ServerStreamStart { + from: "localhost".into(), + lang: "en".into(), + version: "1.0".into(), + } + .write_xml(&mut xml_writer) + .await?; + xmpp::stream::Features { + start_tls: true, + mechanisms: false, + bind: false, + } + .write_xml(&mut xml_writer) + .await?; + xml_writer.get_mut().flush().await?; + + let StartTLS = StartTLS::parse(&mut xml_reader, &mut reader_buf).await?; + // TODO read + xmpp::stream::Proceed.write_xml(&mut xml_writer).await?; + xml_writer.get_mut().flush().await?; } let config = tokio_rustls::rustls::ServerConfig::builder() @@ -153,27 +177,83 @@ async fn handle_socket( .with_no_client_auth() .with_single_cert(vec![config.cert.clone()], config.key.clone())?; - let i = stream.read(&mut buf).await?; - - match std::str::from_utf8(&buf[0..i]) { - Ok(e) => println!("{} END", e), - Err(_) => println!("{:?} END", &buf[0..i]), - } - let acceptor = TlsAcceptor::from(Arc::new(config)); - let mut new_stream = acceptor.accept(stream).await?; + let new_stream = acceptor.accept(stream).await?; log::debug!("TLS connection established"); + let (a, b) = tokio::io::split(new_stream); + let buf_reader = BufReader::new(a); + + let mut xml_reader = NsReader::from_reader(buf_reader); + let mut xml_writer = Writer::new(b); + + { + if let Event::Decl(_) = xml_reader.read_event_into_async(&mut reader_buf).await? { + // this is header + } else { + return Err(fail("expected XML header")); + } + let _ = ClientStreamStart::parse(&mut xml_reader, &mut reader_buf).await?; + + xml_writer + .write_event_async(Event::Decl(BytesDecl::new("1.0", None, None))) + .await?; + xmpp::stream::ServerStreamStart { + from: "localhost".into(), + lang: "en".into(), + version: "1.0".into(), + } + .write_xml(&mut xml_writer) + .await?; + xmpp::stream::Features { + start_tls: false, + mechanisms: true, + bind: false, + } + .write_xml(&mut xml_writer) + .await?; + xml_writer.get_mut().flush().await?; + + let _ = xmpp::sasl::Auth::parse(&mut xml_reader, &mut reader_buf).await?; + xmpp::sasl::Success.write_xml(&mut xml_writer).await?; + } + { + if let Event::Decl(_) = xml_reader.read_event_into_async(&mut reader_buf).await? { + // this is header + } else { + return Err(fail("expected XML header")); + } + let _ = ClientStreamStart::parse(&mut xml_reader, &mut reader_buf).await?; + + xml_writer + .write_event_async(Event::Decl(BytesDecl::new("1.0", None, None))) + .await?; + xmpp::stream::ServerStreamStart { + from: "localhost".into(), + lang: "en".into(), + version: "1.0".into(), + } + .write_xml(&mut xml_writer) + .await?; + xmpp::stream::Features { + start_tls: false, + mechanisms: false, + bind: true, + } + .write_xml(&mut xml_writer) + .await?; + xml_writer.get_mut().flush().await?; + } loop { - let i = new_stream.read(&mut buf).await?; - if i == 0 { + let event = xml_reader.read_event_into_async(&mut reader_buf).await?; + println!("EVENT: {event:?}"); + if event == Event::Eof { break; } - match std::str::from_utf8(&buf[0..i]) { - Ok(e) => println!("{} END", e), - Err(_) => println!("{:?} END", &buf[0..i]), - } } - new_stream.shutdown().await?; + + let a = xml_reader.into_inner().into_inner(); + let b = xml_writer.into_inner(); + a.unsplit(b).shutdown().await?; Ok(()) } diff --git a/src/protos/xmpp/mod.rs b/src/protos/xmpp/mod.rs index 93d19c0..e1a8a99 100644 --- a/src/protos/xmpp/mod.rs +++ b/src/protos/xmpp/mod.rs @@ -1,2 +1,21 @@ pub mod client; +pub mod sasl; pub mod stream; + +// Implemented as a macro instead of a fn due to borrowck limitations +macro_rules! skip_text { + ($reader: ident, $buf: ident) => { + loop { + use quick_xml::events::Event; + $buf.clear(); + let res = $reader.read_event_into_async($buf).await?; + if let Event::Text(_) = res { + continue; + } else { + break res; + } + } + }; +} + +pub(super) use skip_text; diff --git a/src/protos/xmpp/sasl.rs b/src/protos/xmpp/sasl.rs new file mode 100644 index 0000000..e78f605 --- /dev/null +++ b/src/protos/xmpp/sasl.rs @@ -0,0 +1,78 @@ +use std::borrow::Borrow; + +use quick_xml::{ + events::{BytesStart, Event}, + NsReader, Writer, +}; +use tokio::io::{AsyncBufRead, AsyncWrite}; + +use super::skip_text; +use crate::prelude::*; + +pub enum Mechanism { + Plain, +} +impl Mechanism { + pub fn to_str(&self) -> &'static str { + match self { + Mechanism::Plain => "PLAIN", + } + } + + pub fn from_str(input: &[u8]) -> Result { + match input { + b"PLAIN" => Ok(Mechanism::Plain), + _ => Err(fail(format!("unknown auth mechanism: {input:?}").as_str())), + } + } +} + +pub struct Auth { + pub mechanism: Mechanism, + pub body: Vec, +} +impl Auth { + pub async fn parse( + reader: &mut NsReader, + buf: &mut Vec, + ) -> Result { + let event = skip_text!(reader, buf); + let mechanism = if let Event::Start(bytes) = event { + let mut mechanism = None; + for attr in bytes.attributes() { + let attr = attr?; + if attr.key.0 == b"mechanism" { + mechanism = Some(attr.value) + } + } + if let Some(mechanism) = mechanism { + Mechanism::from_str(mechanism.borrow())? + } else { + return Err(fail("expected mechanism attribute in ")); + } + } else { + return Err(fail("expected start of ")); + }; + let body = if let Event::Text(text) = reader.read_event_into_async(buf).await? { + text.into_inner().into_owned() + } else { + return Err(fail("expected text body in ")); + }; + if let Event::End(_) = reader.read_event_into_async(buf).await? { + //TODO + } else { + return Err(fail("expected end of ")); + }; + + Ok(Auth { mechanism, body }) + } +} + +pub struct Success; +impl Success { + pub async fn write_xml(&self, writer: &mut Writer) -> Result<()> { + let event = BytesStart::new(r#"success xmlns="urn:ietf:params:xml:ns:xmpp-sasl""#); + writer.write_event_async(Event::Empty(event)).await?; + Ok(()) + } +} diff --git a/src/protos/xmpp/stream.rs b/src/protos/xmpp/stream.rs index 5e7f244..978721b 100644 --- a/src/protos/xmpp/stream.rs +++ b/src/protos/xmpp/stream.rs @@ -1,15 +1,16 @@ -use std::io::Write; - use quick_xml::events::attributes::Attribute; -use quick_xml::events::{BytesStart, Event}; +use quick_xml::events::{BytesEnd, BytesStart, BytesText, Event}; use quick_xml::name::{Namespace, QName, ResolveResult}; -use quick_xml::writer::Writer; -use quick_xml::{NsReader, Result}; -use tokio::io::AsyncBufRead; +use quick_xml::{NsReader, Writer}; +use tokio::io::{AsyncBufRead, AsyncWrite}; + +use super::skip_text; +use crate::prelude::*; pub static XMLNS: &'static str = "http://etherx.jabber.org/streams"; pub static PREFIX: &'static str = "stream"; pub static XMLNS_XML: &'static str = "http://www.w3.org/XML/1998/namespace"; +pub static XMLNS_TLS: &'static str = "urn:ietf:params:xml:ns:xmpp-tls"; #[derive(Debug, PartialEq, Eq)] pub struct ClientStreamStart { @@ -22,7 +23,8 @@ impl ClientStreamStart { reader: &mut NsReader, buf: &mut Vec, ) -> Result { - if let Event::Start(e) = reader.read_event_into_async(buf).await? { + let incoming = skip_text!(reader, buf); + if let Event::Start(e) = incoming { let (ns, local) = reader.resolve_element(e.name()); if ns != ResolveResult::Bound(Namespace(XMLNS.as_bytes())) { return Err(panic!()); @@ -61,6 +63,7 @@ impl ClientStreamStart { version: version.unwrap(), }) } else { + log::error!("WAT: {incoming:?}"); Err(panic!()) } } @@ -72,7 +75,7 @@ pub struct ServerStreamStart { pub version: String, } impl ServerStreamStart { - pub fn write(&self, writer: &mut Writer) -> Result<()> { + pub async fn write_xml(&self, writer: &mut Writer) -> Result<()> { let mut event = BytesStart::new("stream:stream"); let attributes = [ Attribute { @@ -97,7 +100,97 @@ impl ServerStreamStart { }, ]; event.extend_attributes(attributes.into_iter()); - writer.write_event(Event::Start(event))?; + writer.write_event_async(Event::Start(event)).await?; + Ok(()) + } +} + +pub struct Features { + pub start_tls: bool, + pub mechanisms: bool, + pub bind: bool, +} +impl Features { + pub async fn write_xml(&self, writer: &mut Writer) -> Result<()> { + writer + .write_event_async(Event::Start(BytesStart::new("stream:features"))) + .await?; + if self.start_tls { + writer + .write_event_async(Event::Start(BytesStart::new( + r#"starttls xmlns="urn:ietf:params:xml:ns:xmpp-tls""#, + ))) + .await?; + writer + .write_event_async(Event::Empty(BytesStart::new("required"))) + .await?; + writer + .write_event_async(Event::End(BytesEnd::new("starttls"))) + .await?; + } + if self.mechanisms { + writer + .write_event_async(Event::Start(BytesStart::new( + r#"mechanisms xmlns="urn:ietf:params:xml:ns:xmpp-sasl""#, + ))) + .await?; + writer + .write_event_async(Event::Start(BytesStart::new(r#"mechanism"#))) + .await?; + writer + .write_event_async(Event::Text(BytesText::new("PLAIN"))) + .await?; + writer + .write_event_async(Event::End(BytesEnd::new("mechanism"))) + .await?; + writer + .write_event_async(Event::End(BytesEnd::new("mechanisms"))) + .await?; + } + if self.bind { + writer + .write_event_async(Event::Empty(BytesStart::new( + r#"bind xmlns="urn:ietf:params:xml:ns:xmpp-bind""#, + ))) + .await?; + } + writer + .write_event_async(Event::End(BytesEnd::new("stream:features"))) + .await?; + Ok(()) + } +} + +pub struct StartTLS; +impl StartTLS { + pub async fn parse( + reader: &mut NsReader, + buf: &mut Vec, + ) -> Result { + let incoming = skip_text!(reader, buf); + if let Event::Empty(e) = incoming { + if e.name().0 == b"starttls" { + Ok(StartTLS) + } else { + Err(fail("starttls expected")) + } + } else { + log::error!("WAT: {incoming:?}"); + Err(panic!()) + } + } +} + +pub struct Proceed; +impl Proceed { + pub async fn write_xml(&self, writer: &mut Writer) -> Result<()> { + let mut event = BytesStart::new("proceed"); + let attributes = [Attribute { + key: QName(b"xmlns"), + value: XMLNS_TLS.as_bytes().into(), + }]; + event.extend_attributes(attributes.into_iter()); + writer.write_event_async(Event::Empty(event)).await?; Ok(()) } } @@ -124,8 +217,8 @@ mod test { ) } - #[test] - fn server_stream_start_write() { + #[tokio::test] + async fn server_stream_start_write() { let input = ServerStreamStart { from: "vlnv.dev".to_owned(), lang: "en".to_owned(), @@ -134,7 +227,7 @@ mod test { let expected = r###""###; let mut output: Vec = vec![]; let mut writer = Writer::new(&mut output); - input.write(&mut writer).unwrap(); + input.write_xml(&mut writer).await.unwrap(); assert_eq!(std::str::from_utf8(&output).unwrap(), expected); } }