forked from lavina/lavina
1
0
Fork 0

feat(xmpp): implement socket start negotiation up to auth

This commit is contained in:
Nikita Vilunov 2023-03-05 22:04:28 +01:00
parent 435da6663a
commit 42c22d045f
7 changed files with 348 additions and 91 deletions

74
Cargo.lock generated
View File

@ -238,9 +238,9 @@ dependencies = [
[[package]] [[package]]
name = "h2" name = "h2"
version = "0.3.15" version = "0.3.16"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "5f9f29bc9dda355256b2916cf526ab02ce0aeaaaf2bad60d65ef3f12f11dd0f4" checksum = "5be7b54589b581f624f566bf5d8eb2bab1db736c51528720b6bd36b96b55924d"
dependencies = [ dependencies = [
"bytes", "bytes",
"fnv", "fnv",
@ -272,9 +272,9 @@ dependencies = [
[[package]] [[package]]
name = "http" name = "http"
version = "0.2.8" version = "0.2.9"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "75f43d41e26995c17e71ee126451dd3941010b0514a81a9d11f3b341debc2399" checksum = "bd6effc99afb63425aff9b05836f029929e345a6148a14b7ecd5ab67af944482"
dependencies = [ dependencies = [
"bytes", "bytes",
"fnv", "fnv",
@ -353,9 +353,9 @@ dependencies = [
[[package]] [[package]]
name = "hyper" name = "hyper"
version = "1.0.0-rc.2" version = "1.0.0-rc.3"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "289cfdbf735dea222b0ec6a10224b4d9552c7662bb451d4589cbfda3d407d1a3" checksum = "7b75264b2003a3913f118d35c586e535293b3e22e41f074930762929d071e092"
dependencies = [ dependencies = [
"bytes", "bytes",
"futures-channel", "futures-channel",
@ -406,9 +406,9 @@ checksum = "30e22bd8629359895450b59ea7a776c850561b96a3b1d31321c1949d9e6c9146"
[[package]] [[package]]
name = "itoa" name = "itoa"
version = "1.0.5" version = "1.0.6"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "fad582f4b9e86b6caa621cabeb0963332d92eea04729ab12892c2533951e6440" checksum = "453ad9f582a441959e5f0d088b02ce04cfe8d51a8eaf077f12ac6d3e94164ca6"
[[package]] [[package]]
name = "js-sys" name = "js-sys"
@ -428,7 +428,7 @@ dependencies = [
"figment", "figment",
"futures-util", "futures-util",
"http-body-util", "http-body-util",
"hyper 1.0.0-rc.2", "hyper 1.0.0-rc.3",
"nom", "nom",
"prometheus", "prometheus",
"quick-xml", "quick-xml",
@ -502,7 +502,7 @@ dependencies = [
"libc", "libc",
"log", "log",
"wasi", "wasi",
"windows-sys 0.45.0", "windows-sys",
] ]
[[package]] [[package]]
@ -537,9 +537,9 @@ dependencies = [
[[package]] [[package]]
name = "once_cell" name = "once_cell"
version = "1.17.0" version = "1.17.1"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "6f61fba1741ea2b3d6a1e3178721804bb716a68a6aeba1149b5d52e3d464ea66" checksum = "b7e5500299e16ebb147ae15a00a942af264cf3688f47923b8fc2cd5858f23ad3"
[[package]] [[package]]
name = "overload" name = "overload"
@ -567,7 +567,7 @@ dependencies = [
"libc", "libc",
"redox_syscall", "redox_syscall",
"smallvec", "smallvec",
"windows-sys 0.45.0", "windows-sys",
] ]
[[package]] [[package]]
@ -656,8 +656,7 @@ dependencies = [
[[package]] [[package]]
name = "quick-xml" name = "quick-xml"
version = "0.27.1" version = "0.27.1"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "git+https://github.com/vilunov/quick-xml.git?branch=async-writer#5adef14e6833e9089c4965b22a249bed43742bb2"
checksum = "ffc053f057dd768a56f62cd7e434c42c831d296968997e9ac1f76ea7c2d14c41"
dependencies = [ dependencies = [
"memchr", "memchr",
"tokio", "tokio",
@ -800,9 +799,9 @@ dependencies = [
[[package]] [[package]]
name = "ryu" name = "ryu"
version = "1.0.12" version = "1.0.13"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "7b4b9743ed687d4b4bcedf9ff5eaa7398495ae14e61cba0a295704edbc7decde" checksum = "f91339c0467de62360649f8d3e185ca8de4224ff281f66000de5eb2a77a79041"
[[package]] [[package]]
name = "scopeguard" name = "scopeguard"
@ -842,9 +841,9 @@ dependencies = [
[[package]] [[package]]
name = "serde_json" name = "serde_json"
version = "1.0.93" version = "1.0.94"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "cad406b69c91885b5107daf2c29572f6c8cdb3c66826821e286c533490c0bc76" checksum = "1c533a59c9d8a93a09c6ab31f0fd5e5f4dd1b8fc9434804029839884765d04ea"
dependencies = [ dependencies = [
"itoa", "itoa",
"ryu", "ryu",
@ -894,9 +893,9 @@ dependencies = [
[[package]] [[package]]
name = "slab" name = "slab"
version = "0.4.7" version = "0.4.8"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "4614a76b2a8be0058caa9dbbaf66d988527d86d003c11a94fbd335d7661edcef" checksum = "6528351c9bc8ab22353f9d776db39a20288e8d6c37ef8cfe3317cf875eecfc2d"
dependencies = [ dependencies = [
"autocfg", "autocfg",
] ]
@ -909,9 +908,9 @@ checksum = "a507befe795404456341dfab10cef66ead4c041f62b8b11bbb92bffe5d0953e0"
[[package]] [[package]]
name = "socket2" name = "socket2"
version = "0.4.7" version = "0.4.9"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "02e2d2db9033d13a1567121ddd7a095ee144db4e1ca1b1bda3419bc0da294ebd" checksum = "64a4a911eed85daf18834cfaa86a79b7d266ff93ff5ba14005426219480ed662"
dependencies = [ dependencies = [
"libc", "libc",
"winapi", "winapi",
@ -925,9 +924,9 @@ checksum = "6e63cff320ae2c57904679ba7cb63280a3dc4613885beafb148ee7bf9aa9042d"
[[package]] [[package]]
name = "syn" name = "syn"
version = "1.0.107" version = "1.0.109"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "1f4064b5b16e03ae50984a5a8ed5d4f8803e6bc1fd170a3cda91a1be4b18e3f5" checksum = "72b64191b275b66ffe2469e8af2c1cfe3bafa67b529ead792a6d0160888b4237"
dependencies = [ dependencies = [
"proc-macro2", "proc-macro2",
"quote", "quote",
@ -981,9 +980,9 @@ checksum = "1f3ccbac311fea05f86f61904b462b55fb3df8837a366dfc601a0161d0532f20"
[[package]] [[package]]
name = "tokio" name = "tokio"
version = "1.25.0" version = "1.26.0"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "c8e00990ebabbe4c14c08aca901caed183ecd5c09562a12c824bb53d3c3fd3af" checksum = "03201d01c3c27a29c8a5cee5b55a93ddae1ccf6f08f65365c2c918f8c1b76f64"
dependencies = [ dependencies = [
"autocfg", "autocfg",
"bytes", "bytes",
@ -996,7 +995,7 @@ dependencies = [
"signal-hook-registry", "signal-hook-registry",
"socket2", "socket2",
"tokio-macros", "tokio-macros",
"windows-sys 0.42.0", "windows-sys",
] ]
[[package]] [[package]]
@ -1168,9 +1167,9 @@ checksum = "d54675592c1dbefd78cbd98db9bacd89886e1ca50692a0692baefffdeb92dd58"
[[package]] [[package]]
name = "unicode-ident" name = "unicode-ident"
version = "1.0.6" version = "1.0.7"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "84a22b9f218b40614adcb3f4ff08b703773ad44fa9423e4e0d346d5db86e4ebc" checksum = "775c11906edafc97bc378816b94585fbd9a054eabaf86fdd0ced94af449efab7"
[[package]] [[package]]
name = "unicode-normalization" name = "unicode-normalization"
@ -1340,21 +1339,6 @@ version = "0.4.0"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "712e227841d057c1ee1cd2fb22fa7e5a5461ae8e48fa2ca79ec42cfc1931183f" checksum = "712e227841d057c1ee1cd2fb22fa7e5a5461ae8e48fa2ca79ec42cfc1931183f"
[[package]]
name = "windows-sys"
version = "0.42.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "5a3e1820f08b8513f676f7ab6c1f99ff312fb97b553d30ff4dd86f9f15728aa7"
dependencies = [
"windows_aarch64_gnullvm",
"windows_aarch64_msvc",
"windows_i686_gnu",
"windows_i686_msvc",
"windows_x86_64_gnu",
"windows_x86_64_gnullvm",
"windows_x86_64_msvc",
]
[[package]] [[package]]
name = "windows-sys" name = "windows-sys"
version = "0.45.0" version = "0.45.0"

View File

@ -26,3 +26,6 @@ quick-xml = { version = "0.27.1", features = ["async-tokio"] }
assert_matches = "1.5.0" assert_matches = "1.5.0"
regex = "1.7.1" regex = "1.7.1"
reqwest = { version = "0.11", default-features = false } reqwest = { version = "0.11", default-features = false }
[patch.crates-io]
quick-xml = { git = "https://github.com/vilunov/quick-xml.git", branch = "async-writer" }

View File

@ -13,6 +13,6 @@ pub type Result<T> = std::result::Result<T, anyhow::Error>;
pub type ByteVec = Vec<u8>; pub type ByteVec = Vec<u8>;
pub fn fail(msg: &'static str) -> anyhow::Error { pub fn fail(msg: &str) -> anyhow::Error {
anyhow::Error::msg(msg) anyhow::Error::msg(msg.to_owned())
} }

View File

@ -7,9 +7,11 @@ use std::sync::Arc;
use futures_util::future::join_all; use futures_util::future::join_all;
use prometheus::Registry as MetricsRegistry; use prometheus::Registry as MetricsRegistry;
use quick_xml::events::{BytesDecl, Event};
use quick_xml::{NsReader, Writer};
use rustls_pemfile::{certs, rsa_private_keys}; use rustls_pemfile::{certs, rsa_private_keys};
use serde::Deserialize; use serde::Deserialize;
use tokio::io::{AsyncReadExt, AsyncWriteExt}; use tokio::io::{AsyncWriteExt, BufReader, BufWriter};
use tokio::net::{TcpListener, TcpStream}; use tokio::net::{TcpListener, TcpStream};
use tokio::sync::mpsc::channel; use tokio::sync::mpsc::channel;
use tokio_rustls::rustls::{Certificate, PrivateKey}; use tokio_rustls::rustls::{Certificate, PrivateKey};
@ -18,6 +20,7 @@ use tokio_rustls::TlsAcceptor;
use crate::core::player::PlayerRegistry; use crate::core::player::PlayerRegistry;
use crate::core::room::RoomRegistry; use crate::core::room::RoomRegistry;
use crate::prelude::*; use crate::prelude::*;
use crate::protos::xmpp;
use crate::util::Terminator; use crate::util::Terminator;
#[derive(Deserialize, Debug, Clone)] #[derive(Deserialize, Debug, Clone)]
@ -126,26 +129,47 @@ async fn handle_socket(
rooms: RoomRegistry, rooms: RoomRegistry,
termination: Deferred<()>, // TODO use it to stop the connection gracefully termination: Deferred<()>, // TODO use it to stop the connection gracefully
) -> Result<()> { ) -> Result<()> {
use xmpp::stream::*;
log::debug!("Received an XMPP connection from {socket_addr}"); log::debug!("Received an XMPP connection from {socket_addr}");
// writer.write_all(b"Hi!\n").await?; let mut reader_buf = vec![];
let mut buf = [0; 1024]; let (reader, writer) = stream.split();
stream.write_all(br###"<?xml version='1.0'?> let mut buf_reader = BufReader::new(reader);
<stream:stream id='11698431101746707873' version='1.0' xml:lang='en' xmlns:stream='http://etherx.jabber.org/streams' from='localhost' xmlns='jabber:client'> let mut buf_writer = BufWriter::new(writer);
<stream:features>
<starttls xmlns="urn:ietf:params:xml:ns:xmpp-tls">
<required/>
</starttls>
</stream:features>
"###).await?;
{ {
let i = stream.read(&mut buf).await?; let mut xml_reader = NsReader::from_reader(&mut buf_reader);
match std::str::from_utf8(&buf[0..i]) { let mut xml_writer = Writer::new(&mut buf_writer);
Ok(e) => println!("{} END", e), let aaa = xml_reader.read_event_into_async(&mut reader_buf).await?;
Err(_) => println!("{:?} END", &buf[0..i]), if let Event::Decl(_) = aaa {
// this is <?xml ...> header
} else {
return Err(fail("expected XML header"));
} }
stream let _ = ClientStreamStart::parse(&mut xml_reader, &mut reader_buf).await?;
.write_all(br###"<proceed xmlns="urn:ietf:params:xml:ns:xmpp-tls"/>"###)
xml_writer
.write_event_async(Event::Decl(BytesDecl::new("1.0", None, None)))
.await?; .await?;
xmpp::stream::ServerStreamStart {
from: "localhost".into(),
lang: "en".into(),
version: "1.0".into(),
}
.write_xml(&mut xml_writer)
.await?;
xmpp::stream::Features {
start_tls: true,
mechanisms: false,
bind: false,
}
.write_xml(&mut xml_writer)
.await?;
xml_writer.get_mut().flush().await?;
let StartTLS = StartTLS::parse(&mut xml_reader, &mut reader_buf).await?;
// TODO read <starttls xmlns="urn:ietf:params:xml:ns:xmpp-tls"/>
xmpp::stream::Proceed.write_xml(&mut xml_writer).await?;
xml_writer.get_mut().flush().await?;
} }
let config = tokio_rustls::rustls::ServerConfig::builder() let config = tokio_rustls::rustls::ServerConfig::builder()
@ -153,27 +177,83 @@ async fn handle_socket(
.with_no_client_auth() .with_no_client_auth()
.with_single_cert(vec![config.cert.clone()], config.key.clone())?; .with_single_cert(vec![config.cert.clone()], config.key.clone())?;
let i = stream.read(&mut buf).await?;
match std::str::from_utf8(&buf[0..i]) {
Ok(e) => println!("{} END", e),
Err(_) => println!("{:?} END", &buf[0..i]),
}
let acceptor = TlsAcceptor::from(Arc::new(config)); let acceptor = TlsAcceptor::from(Arc::new(config));
let mut new_stream = acceptor.accept(stream).await?; let new_stream = acceptor.accept(stream).await?;
log::debug!("TLS connection established"); log::debug!("TLS connection established");
let (a, b) = tokio::io::split(new_stream);
let buf_reader = BufReader::new(a);
let mut xml_reader = NsReader::from_reader(buf_reader);
let mut xml_writer = Writer::new(b);
{
if let Event::Decl(_) = xml_reader.read_event_into_async(&mut reader_buf).await? {
// this is <?xml ...> header
} else {
return Err(fail("expected XML header"));
}
let _ = ClientStreamStart::parse(&mut xml_reader, &mut reader_buf).await?;
xml_writer
.write_event_async(Event::Decl(BytesDecl::new("1.0", None, None)))
.await?;
xmpp::stream::ServerStreamStart {
from: "localhost".into(),
lang: "en".into(),
version: "1.0".into(),
}
.write_xml(&mut xml_writer)
.await?;
xmpp::stream::Features {
start_tls: false,
mechanisms: true,
bind: false,
}
.write_xml(&mut xml_writer)
.await?;
xml_writer.get_mut().flush().await?;
let _ = xmpp::sasl::Auth::parse(&mut xml_reader, &mut reader_buf).await?;
xmpp::sasl::Success.write_xml(&mut xml_writer).await?;
}
{
if let Event::Decl(_) = xml_reader.read_event_into_async(&mut reader_buf).await? {
// this is <?xml ...> header
} else {
return Err(fail("expected XML header"));
}
let _ = ClientStreamStart::parse(&mut xml_reader, &mut reader_buf).await?;
xml_writer
.write_event_async(Event::Decl(BytesDecl::new("1.0", None, None)))
.await?;
xmpp::stream::ServerStreamStart {
from: "localhost".into(),
lang: "en".into(),
version: "1.0".into(),
}
.write_xml(&mut xml_writer)
.await?;
xmpp::stream::Features {
start_tls: false,
mechanisms: false,
bind: true,
}
.write_xml(&mut xml_writer)
.await?;
xml_writer.get_mut().flush().await?;
}
loop { loop {
let i = new_stream.read(&mut buf).await?; let event = xml_reader.read_event_into_async(&mut reader_buf).await?;
if i == 0 { println!("EVENT: {event:?}");
if event == Event::Eof {
break; break;
} }
match std::str::from_utf8(&buf[0..i]) {
Ok(e) => println!("{} END", e),
Err(_) => println!("{:?} END", &buf[0..i]),
}
} }
new_stream.shutdown().await?;
let a = xml_reader.into_inner().into_inner();
let b = xml_writer.into_inner();
a.unsplit(b).shutdown().await?;
Ok(()) Ok(())
} }

View File

@ -1,2 +1,21 @@
pub mod client; pub mod client;
pub mod sasl;
pub mod stream; pub mod stream;
// Implemented as a macro instead of a fn due to borrowck limitations
macro_rules! skip_text {
($reader: ident, $buf: ident) => {
loop {
use quick_xml::events::Event;
$buf.clear();
let res = $reader.read_event_into_async($buf).await?;
if let Event::Text(_) = res {
continue;
} else {
break res;
}
}
};
}
pub(super) use skip_text;

78
src/protos/xmpp/sasl.rs Normal file
View File

@ -0,0 +1,78 @@
use std::borrow::Borrow;
use quick_xml::{
events::{BytesStart, Event},
NsReader, Writer,
};
use tokio::io::{AsyncBufRead, AsyncWrite};
use super::skip_text;
use crate::prelude::*;
pub enum Mechanism {
Plain,
}
impl Mechanism {
pub fn to_str(&self) -> &'static str {
match self {
Mechanism::Plain => "PLAIN",
}
}
pub fn from_str(input: &[u8]) -> Result<Mechanism> {
match input {
b"PLAIN" => Ok(Mechanism::Plain),
_ => Err(fail(format!("unknown auth mechanism: {input:?}").as_str())),
}
}
}
pub struct Auth {
pub mechanism: Mechanism,
pub body: Vec<u8>,
}
impl Auth {
pub async fn parse(
reader: &mut NsReader<impl AsyncBufRead + Unpin>,
buf: &mut Vec<u8>,
) -> Result<Auth> {
let event = skip_text!(reader, buf);
let mechanism = if let Event::Start(bytes) = event {
let mut mechanism = None;
for attr in bytes.attributes() {
let attr = attr?;
if attr.key.0 == b"mechanism" {
mechanism = Some(attr.value)
}
}
if let Some(mechanism) = mechanism {
Mechanism::from_str(mechanism.borrow())?
} else {
return Err(fail("expected mechanism attribute in <auth>"));
}
} else {
return Err(fail("expected start of <auth>"));
};
let body = if let Event::Text(text) = reader.read_event_into_async(buf).await? {
text.into_inner().into_owned()
} else {
return Err(fail("expected text body in <auth>"));
};
if let Event::End(_) = reader.read_event_into_async(buf).await? {
//TODO
} else {
return Err(fail("expected end of <auth>"));
};
Ok(Auth { mechanism, body })
}
}
pub struct Success;
impl Success {
pub async fn write_xml(&self, writer: &mut Writer<impl AsyncWrite + Unpin>) -> Result<()> {
let event = BytesStart::new(r#"success xmlns="urn:ietf:params:xml:ns:xmpp-sasl""#);
writer.write_event_async(Event::Empty(event)).await?;
Ok(())
}
}

View File

@ -1,15 +1,16 @@
use std::io::Write;
use quick_xml::events::attributes::Attribute; use quick_xml::events::attributes::Attribute;
use quick_xml::events::{BytesStart, Event}; use quick_xml::events::{BytesEnd, BytesStart, BytesText, Event};
use quick_xml::name::{Namespace, QName, ResolveResult}; use quick_xml::name::{Namespace, QName, ResolveResult};
use quick_xml::writer::Writer; use quick_xml::{NsReader, Writer};
use quick_xml::{NsReader, Result}; use tokio::io::{AsyncBufRead, AsyncWrite};
use tokio::io::AsyncBufRead;
use super::skip_text;
use crate::prelude::*;
pub static XMLNS: &'static str = "http://etherx.jabber.org/streams"; pub static XMLNS: &'static str = "http://etherx.jabber.org/streams";
pub static PREFIX: &'static str = "stream"; pub static PREFIX: &'static str = "stream";
pub static XMLNS_XML: &'static str = "http://www.w3.org/XML/1998/namespace"; pub static XMLNS_XML: &'static str = "http://www.w3.org/XML/1998/namespace";
pub static XMLNS_TLS: &'static str = "urn:ietf:params:xml:ns:xmpp-tls";
#[derive(Debug, PartialEq, Eq)] #[derive(Debug, PartialEq, Eq)]
pub struct ClientStreamStart { pub struct ClientStreamStart {
@ -22,7 +23,8 @@ impl ClientStreamStart {
reader: &mut NsReader<impl AsyncBufRead + Unpin>, reader: &mut NsReader<impl AsyncBufRead + Unpin>,
buf: &mut Vec<u8>, buf: &mut Vec<u8>,
) -> Result<ClientStreamStart> { ) -> Result<ClientStreamStart> {
if let Event::Start(e) = reader.read_event_into_async(buf).await? { let incoming = skip_text!(reader, buf);
if let Event::Start(e) = incoming {
let (ns, local) = reader.resolve_element(e.name()); let (ns, local) = reader.resolve_element(e.name());
if ns != ResolveResult::Bound(Namespace(XMLNS.as_bytes())) { if ns != ResolveResult::Bound(Namespace(XMLNS.as_bytes())) {
return Err(panic!()); return Err(panic!());
@ -61,6 +63,7 @@ impl ClientStreamStart {
version: version.unwrap(), version: version.unwrap(),
}) })
} else { } else {
log::error!("WAT: {incoming:?}");
Err(panic!()) Err(panic!())
} }
} }
@ -72,7 +75,7 @@ pub struct ServerStreamStart {
pub version: String, pub version: String,
} }
impl ServerStreamStart { impl ServerStreamStart {
pub fn write(&self, writer: &mut Writer<impl Write>) -> Result<()> { pub async fn write_xml(&self, writer: &mut Writer<impl AsyncWrite + Unpin>) -> Result<()> {
let mut event = BytesStart::new("stream:stream"); let mut event = BytesStart::new("stream:stream");
let attributes = [ let attributes = [
Attribute { Attribute {
@ -97,7 +100,97 @@ impl ServerStreamStart {
}, },
]; ];
event.extend_attributes(attributes.into_iter()); event.extend_attributes(attributes.into_iter());
writer.write_event(Event::Start(event))?; writer.write_event_async(Event::Start(event)).await?;
Ok(())
}
}
pub struct Features {
pub start_tls: bool,
pub mechanisms: bool,
pub bind: bool,
}
impl Features {
pub async fn write_xml(&self, writer: &mut Writer<impl AsyncWrite + Unpin>) -> Result<()> {
writer
.write_event_async(Event::Start(BytesStart::new("stream:features")))
.await?;
if self.start_tls {
writer
.write_event_async(Event::Start(BytesStart::new(
r#"starttls xmlns="urn:ietf:params:xml:ns:xmpp-tls""#,
)))
.await?;
writer
.write_event_async(Event::Empty(BytesStart::new("required")))
.await?;
writer
.write_event_async(Event::End(BytesEnd::new("starttls")))
.await?;
}
if self.mechanisms {
writer
.write_event_async(Event::Start(BytesStart::new(
r#"mechanisms xmlns="urn:ietf:params:xml:ns:xmpp-sasl""#,
)))
.await?;
writer
.write_event_async(Event::Start(BytesStart::new(r#"mechanism"#)))
.await?;
writer
.write_event_async(Event::Text(BytesText::new("PLAIN")))
.await?;
writer
.write_event_async(Event::End(BytesEnd::new("mechanism")))
.await?;
writer
.write_event_async(Event::End(BytesEnd::new("mechanisms")))
.await?;
}
if self.bind {
writer
.write_event_async(Event::Empty(BytesStart::new(
r#"bind xmlns="urn:ietf:params:xml:ns:xmpp-bind""#,
)))
.await?;
}
writer
.write_event_async(Event::End(BytesEnd::new("stream:features")))
.await?;
Ok(())
}
}
pub struct StartTLS;
impl StartTLS {
pub async fn parse(
reader: &mut NsReader<impl AsyncBufRead + Unpin>,
buf: &mut Vec<u8>,
) -> Result<StartTLS> {
let incoming = skip_text!(reader, buf);
if let Event::Empty(e) = incoming {
if e.name().0 == b"starttls" {
Ok(StartTLS)
} else {
Err(fail("starttls expected"))
}
} else {
log::error!("WAT: {incoming:?}");
Err(panic!())
}
}
}
pub struct Proceed;
impl Proceed {
pub async fn write_xml(&self, writer: &mut Writer<impl AsyncWrite + Unpin>) -> Result<()> {
let mut event = BytesStart::new("proceed");
let attributes = [Attribute {
key: QName(b"xmlns"),
value: XMLNS_TLS.as_bytes().into(),
}];
event.extend_attributes(attributes.into_iter());
writer.write_event_async(Event::Empty(event)).await?;
Ok(()) Ok(())
} }
} }
@ -124,8 +217,8 @@ mod test {
) )
} }
#[test] #[tokio::test]
fn server_stream_start_write() { async fn server_stream_start_write() {
let input = ServerStreamStart { let input = ServerStreamStart {
from: "vlnv.dev".to_owned(), from: "vlnv.dev".to_owned(),
lang: "en".to_owned(), lang: "en".to_owned(),
@ -134,7 +227,7 @@ mod test {
let expected = r###"<stream:stream from="vlnv.dev" version="1.0" xmlns="jabber:client" xmlns:stream="http://etherx.jabber.org/streams" xml:lang="en">"###; let expected = r###"<stream:stream from="vlnv.dev" version="1.0" xmlns="jabber:client" xmlns:stream="http://etherx.jabber.org/streams" xml:lang="en">"###;
let mut output: Vec<u8> = vec![]; let mut output: Vec<u8> = vec![];
let mut writer = Writer::new(&mut output); let mut writer = Writer::new(&mut output);
input.write(&mut writer).unwrap(); input.write_xml(&mut writer).await.unwrap();
assert_eq!(std::str::from_utf8(&output).unwrap(), expected); assert_eq!(std::str::from_utf8(&output).unwrap(), expected);
} }
} }