forked from lavina/lavina
1
0
Fork 0

feat(xmpp): handle sending messages to muc

This commit is contained in:
Nikita Vilunov 2023-04-09 23:31:43 +02:00
parent 2b54260f0b
commit 58582f4e51
6 changed files with 224 additions and 48 deletions

View File

@ -19,12 +19,12 @@ use tokio::sync::mpsc::channel;
use tokio_rustls::rustls::{Certificate, PrivateKey}; use tokio_rustls::rustls::{Certificate, PrivateKey};
use tokio_rustls::TlsAcceptor; use tokio_rustls::TlsAcceptor;
use crate::core::player::PlayerRegistry; use crate::core::player::{PlayerConnection, PlayerId, PlayerRegistry};
use crate::core::room::RoomRegistry; use crate::core::room::{RoomId, RoomRegistry};
use crate::prelude::*; use crate::prelude::*;
use crate::protos::xmpp; use crate::protos::xmpp;
use crate::protos::xmpp::bind::{BindResponse, Jid, Name, Resource, Server}; use crate::protos::xmpp::bind::{BindResponse, Jid, Name, Resource, Server};
use crate::protos::xmpp::client::{Iq, Presence}; use crate::protos::xmpp::client::{Iq, Message, MessageType, Presence};
use crate::protos::xmpp::disco::*; use crate::protos::xmpp::disco::*;
use crate::protos::xmpp::roster::RosterQuery; use crate::protos::xmpp::roster::RosterQuery;
use crate::protos::xmpp::session::Session; use crate::protos::xmpp::session::Session;
@ -46,6 +46,13 @@ struct LoadedConfig {
key: PrivateKey, key: PrivateKey,
} }
struct Authenticated {
player_id: PlayerId,
xmpp_name: Name,
xmpp_resource: Resource,
xmpp_muc_name: Resource,
}
pub async fn launch( pub async fn launch(
config: ServerConfig, config: ServerConfig,
players: PlayerRegistry, players: PlayerRegistry,
@ -136,7 +143,7 @@ async fn handle_socket(
config: Arc<LoadedConfig>, config: Arc<LoadedConfig>,
mut stream: TcpStream, mut stream: TcpStream,
socket_addr: &SocketAddr, socket_addr: &SocketAddr,
players: PlayerRegistry, mut players: PlayerRegistry,
rooms: RoomRegistry, rooms: RoomRegistry,
termination: Deferred<()>, // TODO use it to stop the connection gracefully termination: Deferred<()>, // TODO use it to stop the connection gracefully
) -> Result<()> { ) -> Result<()> {
@ -162,8 +169,18 @@ async fn handle_socket(
let mut xml_reader = NsReader::from_reader(BufReader::new(a)); let mut xml_reader = NsReader::from_reader(BufReader::new(a));
let mut xml_writer = Writer::new(b); let mut xml_writer = Writer::new(b);
socket_auth(&mut xml_reader, &mut xml_writer, &mut reader_buf).await?; let authenticated = socket_auth(&mut xml_reader, &mut xml_writer, &mut reader_buf).await?;
socket_final(&mut xml_reader, &mut xml_writer, &mut reader_buf).await?; let mut connection = players
.connect_to_player(authenticated.player_id.clone())
.await;
socket_final(
&mut xml_reader,
&mut xml_writer,
&mut reader_buf,
&authenticated,
&mut connection,
)
.await?;
let a = xml_reader.into_inner().into_inner(); let a = xml_reader.into_inner().into_inner();
let b = xml_writer.into_inner(); let b = xml_writer.into_inner();
@ -209,7 +226,7 @@ async fn socket_auth(
xml_reader: &mut NsReader<(impl AsyncBufRead + Unpin)>, xml_reader: &mut NsReader<(impl AsyncBufRead + Unpin)>,
xml_writer: &mut Writer<(impl AsyncWrite + Unpin)>, xml_writer: &mut Writer<(impl AsyncWrite + Unpin)>,
reader_buf: &mut Vec<u8>, reader_buf: &mut Vec<u8>,
) -> Result<()> { ) -> Result<Authenticated> {
read_xml_header(xml_reader, reader_buf).await?; read_xml_header(xml_reader, reader_buf).await?;
let _ = ClientStreamStart::parse(xml_reader, reader_buf).await?; let _ = ClientStreamStart::parse(xml_reader, reader_buf).await?;
@ -235,13 +252,20 @@ async fn socket_auth(
let _ = xmpp::sasl::Auth::parse(xml_reader, reader_buf).await?; let _ = xmpp::sasl::Auth::parse(xml_reader, reader_buf).await?;
xmpp::sasl::Success.write_xml(xml_writer).await?; xmpp::sasl::Success.write_xml(xml_writer).await?;
Ok(()) Ok(Authenticated {
player_id: PlayerId::from_bytes(b"darova".to_vec())?,
xmpp_name: Name("darova".to_owned()),
xmpp_resource: Resource("darova".to_owned()),
xmpp_muc_name: Resource("darova".to_owned()),
})
} }
async fn socket_final( async fn socket_final(
xml_reader: &mut NsReader<(impl AsyncBufRead + Unpin)>, xml_reader: &mut NsReader<(impl AsyncBufRead + Unpin)>,
xml_writer: &mut Writer<(impl AsyncWrite + Unpin)>, xml_writer: &mut Writer<(impl AsyncWrite + Unpin)>,
reader_buf: &mut Vec<u8>, reader_buf: &mut Vec<u8>,
authenticated: &Authenticated,
user_handle: &mut PlayerConnection,
) -> Result<()> { ) -> Result<()> {
read_xml_header(xml_reader, reader_buf).await?; read_xml_header(xml_reader, reader_buf).await?;
let _ = ClientStreamStart::parse(xml_reader, reader_buf).await?; let _ = ClientStreamStart::parse(xml_reader, reader_buf).await?;
@ -282,7 +306,7 @@ async fn socket_final(
Continuation::Final(res) => { Continuation::Final(res) => {
let res = res?; let res = res?;
dbg!(&res); dbg!(&res);
let stop = handle_packet(&mut events, res); let stop = handle_packet(&mut events, res, authenticated, user_handle).await?;
for i in &events { for i in &events {
xml_writer.write_event_async(i).await?; xml_writer.write_event_async(i).await?;
} }
@ -299,18 +323,96 @@ async fn socket_final(
Ok(()) Ok(())
} }
fn handle_packet(output: &mut Vec<Event<'static>>, packet: ClientPacket) -> bool { async fn handle_packet(
match packet { output: &mut Vec<Event<'static>>,
packet: ClientPacket,
user: &Authenticated,
user_handle: &mut PlayerConnection,
) -> Result<bool> {
Ok(match packet {
proto::ClientPacket::Iq(iq) => { proto::ClientPacket::Iq(iq) => {
handle_iq(output, iq); handle_iq(output, iq);
false false
} }
proto::ClientPacket::Message(_) => todo!(), proto::ClientPacket::Message(m) => {
if let Some(Jid {
name: Some(name),
server,
resource: _,
}) = m.to
{
if server.0 == "rooms.localhost" && m.r#type == MessageType::Groupchat {
user_handle
.send_message(
RoomId::from_bytes(name.0.clone().into_bytes())?,
m.body.clone(),
)
.await?;
Message {
to: Some(Jid {
name: Some(user.xmpp_name.clone()),
server: Server("localhost".into()),
resource: Some(user.xmpp_resource.clone()),
}),
from: Some(Jid {
name: Some(name),
server: Server("rooms.localhost".into()),
resource: Some(user.xmpp_muc_name.clone()),
}),
id: m.id,
r#type: xmpp::client::MessageType::Groupchat,
lang: None,
subject: None,
body: m.body,
}
.serialize(output);
false
} else {
todo!()
}
} else {
todo!()
}
}
proto::ClientPacket::Presence(p) => { proto::ClientPacket::Presence(p) => {
let response = Presence::<()> { let response = if p.to.is_none() {
to: Some("darova@localhost/kek".to_string()), Presence::<()> {
from: Some("darova@localhost/kek".to_string()), to: Some(Jid {
..Default::default() name: Some(user.xmpp_name.clone()),
server: Server("localhost".into()),
resource: Some(user.xmpp_resource.clone()),
}),
from: Some(Jid {
name: Some(user.xmpp_name.clone()),
server: Server("localhost".into()),
resource: Some(user.xmpp_resource.clone()),
}),
..Default::default()
}
} else if let Some(Jid {
name: Some(name),
server,
resource: Some(resource),
}) = p.to
{
let a = user_handle
.join_room(RoomId::from_bytes(name.0.clone().into_bytes())?)
.await?;
Presence::<()> {
to: Some(Jid {
name: Some(user.xmpp_name.clone()),
server: Server("localhost".into()),
resource: Some(user.xmpp_resource.clone()),
}),
from: Some(Jid {
name: Some(name),
server: Server("rooms.localhost".into()),
resource: Some(user.xmpp_muc_name.clone()),
}),
..Default::default()
}
} else {
Presence::<()>::default()
}; };
response.serialize(output); response.serialize(output);
false false
@ -319,7 +421,7 @@ fn handle_packet(output: &mut Vec<Event<'static>>, packet: ClientPacket) -> bool
ServerStreamEnd.serialize(output); ServerStreamEnd.serialize(output);
true true
} }
} })
} }
fn handle_iq(output: &mut Vec<Event<'static>>, iq: Iq<IqClientBody>) { fn handle_iq(output: &mut Vec<Event<'static>>, iq: Iq<IqClientBody>) {

View File

@ -60,7 +60,7 @@ impl FromXml for ClientPacket {
fn parse() -> Self::P { fn parse() -> Self::P {
|(namespace, event): (ResolveResult<'static>, &'static Event<'static>)| -> Result<Self> { |(namespace, event): (ResolveResult<'static>, &'static Event<'static>)| -> Result<Self> {
match event { match event {
Event::Start(bytes) => { Event::Start(bytes) | Event::Empty(bytes) => {
let name = bytes.name(); let name = bytes.name();
match_parser!(name, namespace, event; match_parser!(name, namespace, event;
Iq::<IqClientBody>, Iq::<IqClientBody>,

View File

@ -11,16 +11,16 @@ pub const XMLNS: &'static str = "urn:ietf:params:xml:ns:xmpp-bind";
// TODO remove `pub` in newtypes, introduce validation // TODO remove `pub` in newtypes, introduce validation
#[derive(PartialEq, Eq, Debug)] #[derive(PartialEq, Eq, Debug, Clone)]
pub struct Name(pub String); pub struct Name(pub String);
#[derive(PartialEq, Eq, Debug)] #[derive(PartialEq, Eq, Debug, Clone)]
pub struct Server(pub String); pub struct Server(pub String);
#[derive(PartialEq, Eq, Debug)] #[derive(PartialEq, Eq, Debug, Clone)]
pub struct Resource(pub String); pub struct Resource(pub String);
#[derive(PartialEq, Eq, Debug)] #[derive(PartialEq, Eq, Debug, Clone)]
pub struct Jid { pub struct Jid {
pub name: Option<Name>, pub name: Option<Name>,
pub server: Server, pub server: Server,

View File

@ -6,13 +6,15 @@ use quick_xml::name::{QName, ResolveResult};
use crate::prelude::*; use crate::prelude::*;
use crate::util::xml::*; use crate::util::xml::*;
use super::bind::Jid;
pub const XMLNS: &'static str = "jabber:client"; pub const XMLNS: &'static str = "jabber:client";
#[derive(PartialEq, Eq, Debug)] #[derive(PartialEq, Eq, Debug)]
pub struct Message { pub struct Message {
pub from: Option<String>, pub from: Option<Jid>,
pub id: Option<String>, pub id: Option<String>,
pub to: Option<String>, pub to: Option<Jid>,
// default is Normal // default is Normal
pub r#type: MessageType, pub r#type: MessageType,
pub lang: Option<String>, pub lang: Option<String>,
@ -47,9 +49,9 @@ enum MessageParserInner {
} }
#[derive(Default)] #[derive(Default)]
struct MessageParserState { struct MessageParserState {
from: Option<String>, from: Option<Jid>,
id: Option<String>, id: Option<String>,
to: Option<String>, to: Option<Jid>,
r#type: MessageType, r#type: MessageType,
lang: Option<String>, lang: Option<String>,
subject: Option<String>, subject: Option<String>,
@ -73,13 +75,15 @@ impl Parser for MessageParser {
let attr = fail_fast!(attr); let attr = fail_fast!(attr);
if attr.key.0 == b"from" { if attr.key.0 == b"from" {
let value = fail_fast!(std::str::from_utf8(&*attr.value)); let value = fail_fast!(std::str::from_utf8(&*attr.value));
state.from = Some(value.to_string()) let value = fail_fast!(Jid::from_string(value));
state.from = Some(value)
} else if attr.key.0 == b"id" { } else if attr.key.0 == b"id" {
let value = fail_fast!(std::str::from_utf8(&*attr.value)); let value = fail_fast!(std::str::from_utf8(&*attr.value));
state.id = Some(value.to_string()) state.id = Some(value.to_string())
} else if attr.key.0 == b"to" { } else if attr.key.0 == b"to" {
let value = fail_fast!(std::str::from_utf8(&*attr.value)); let value = fail_fast!(std::str::from_utf8(&*attr.value));
state.to = Some(value.to_string()) let value = fail_fast!(Jid::from_string(value));
state.to = Some(value)
} else if attr.key.0 == b"type" { } else if attr.key.0 == b"type" {
let value = fail_fast!(MessageType::from_str(&*attr.value)); let value = fail_fast!(MessageType::from_str(&*attr.value));
state.r#type = value; state.r#type = value;
@ -141,6 +145,39 @@ impl Parser for MessageParser {
} }
} }
impl ToXml for Message {
fn serialize(&self, events: &mut Vec<Event<'static>>) {
let mut bytes = BytesStart::new(format!(r#"message xmlns="{}""#, XMLNS));
if let Some(from) = &self.from {
bytes.push_attribute(Attribute {
key: QName(b"from"),
value: from.to_string().into_bytes().into(),
});
}
if let Some(to) = &self.to {
bytes.push_attribute(Attribute {
key: QName(b"to"),
value: to.to_string().into_bytes().into(),
});
}
if let Some(id) = &self.id {
bytes.push_attribute(Attribute {
key: QName(b"id"),
value: id.clone().into_bytes().into(),
});
}
bytes.push_attribute(Attribute {
key: QName(b"type"),
value: self.r#type.as_str().as_bytes().into(),
});
events.push(Event::Start(bytes));
events.push(Event::Start(BytesStart::new("body")));
events.push(Event::Text(BytesText::new(&self.body).into_owned()));
events.push(Event::End(BytesEnd::new("body")));
events.push(Event::End(BytesEnd::new("message")));
}
}
#[derive(PartialEq, Eq, Debug)] #[derive(PartialEq, Eq, Debug)]
pub enum MessageType { pub enum MessageType {
Chat, Chat,
@ -169,6 +206,16 @@ impl MessageType {
t => Err(ffail!("Unknown message type: {t}")), t => Err(ffail!("Unknown message type: {t}")),
} }
} }
pub fn as_str(&self) -> &'static str {
match self {
MessageType::Chat => "chat",
MessageType::Error => "error",
MessageType::Groupchat => "groupchat",
MessageType::Headline => "headline",
MessageType::Normal => "normal",
}
}
} }
#[derive(PartialEq, Eq, Debug)] #[derive(PartialEq, Eq, Debug)]
@ -344,12 +391,13 @@ impl<T: ToXml> ToXml for Iq<T> {
#[derive(PartialEq, Eq, Debug)] #[derive(PartialEq, Eq, Debug)]
pub struct Presence<T> { pub struct Presence<T> {
pub to: Option<String>, pub to: Option<Jid>,
pub from: Option<String>, pub from: Option<Jid>,
pub priority: Option<PresencePriority>, pub priority: Option<PresencePriority>,
pub show: Option<PresenceShow>, pub show: Option<PresenceShow>,
pub status: Vec<String>, pub status: Vec<String>,
pub custom: Vec<T>, pub custom: Vec<T>,
pub r#type: Option<String>,
} }
impl<T> Default for Presence<T> { impl<T> Default for Presence<T> {
@ -361,6 +409,7 @@ impl<T> Default for Presence<T> {
show: Default::default(), show: Default::default(),
status: Default::default(), status: Default::default(),
custom: Default::default(), custom: Default::default(),
r#type: None,
} }
} }
} }
@ -407,12 +456,33 @@ impl<T: FromXml> FromXml for Presence<T> {
fn parse() -> Self::P { fn parse() -> Self::P {
|(namespace, event): (ResolveResult<'static>, &'static Event<'static>)| -> Result<Self> { |(namespace, event): (ResolveResult<'static>, &'static Event<'static>)| -> Result<Self> {
let _ = match event { let (bytes, end) = match event {
Event::Start(bytes) => bytes, Event::Start(bytes) => (bytes, false),
Event::Empty(_) => return Ok(Presence::default()), Event::Empty(bytes) => (bytes, true),
_ => return Err(ffail!("Unexpected XML event: {event:?}")), _ => return Err(ffail!("Unexpected XML event: {event:?}")),
}; };
let mut p = Presence::<T>::default(); let mut p = Presence::<T>::default();
for attr in bytes.attributes() {
let attr = attr?;
match attr.key.0 {
b"to" => {
let s = std::str::from_utf8(&attr.value)?;
p.to = Some(Jid::from_string(s)?);
}
b"from" => {
let s = std::str::from_utf8(&attr.value)?;
p.to = Some(Jid::from_string(s)?);
}
b"type" => {
let s = std::str::from_utf8(&attr.value)?;
p.r#type = Some(s.into());
}
_ => {}
}
}
if end {
return Ok(p);
}
loop { loop {
let (namespace, event) = yield; let (namespace, event) = yield;
match event { match event {
@ -485,13 +555,13 @@ impl<T: ToXml> ToXml for Presence<T> {
if let Some(ref to) = self.to { if let Some(ref to) = self.to {
start.extend_attributes([Attribute { start.extend_attributes([Attribute {
key: QName(b"to"), key: QName(b"to"),
value: to.as_bytes().into(), value: to.to_string().as_bytes().into(),
}]); }]);
} }
if let Some(ref from) = self.from { if let Some(ref from) = self.from {
start.extend_attributes([Attribute { start.extend_attributes([Attribute {
key: QName(b"from"), key: QName(b"from"),
value: from.as_bytes().into(), value: from.to_string().as_bytes().into(),
}]); }]);
} }
events.push(Event::Start(start)); events.push(Event::Start(start));
@ -509,7 +579,7 @@ impl<T: ToXml> ToXml for Presence<T> {
#[cfg(test)] #[cfg(test)]
mod tests { mod tests {
use crate::protos::xmpp::bind::{BindRequest, Resource}; use crate::protos::xmpp::bind::{BindRequest, Name, Resource, Server};
use super::*; use super::*;
use quick_xml::NsReader; use quick_xml::NsReader;
@ -542,7 +612,11 @@ mod tests {
Message { Message {
from: None, from: None,
id: Some("aacea".to_string()), id: Some("aacea".to_string()),
to: Some("nikita@vlnv.dev".to_string()), to: Some(Jid {
name: Some(Name("nikita".to_owned())),
server: Server("vlnv.dev".to_owned()),
resource: None
}),
r#type: MessageType::Chat, r#type: MessageType::Chat,
lang: None, lang: None,
subject: Some("daa".to_string()), subject: Some("daa".to_string()),

View File

@ -34,7 +34,7 @@ impl FromXml for InfoQuery {
let attr = attr?; let attr = attr?;
match attr.key.0 { match attr.key.0 {
b"node" => { b"node" => {
let s = std::str::from_utf8(bytes)?; let s = std::str::from_utf8(&attr.value)?;
node = Some(s.to_owned()) node = Some(s.to_owned())
} }
_ => {} _ => {}
@ -154,15 +154,15 @@ impl FromXml for Identity {
let attr = attr?; let attr = attr?;
match attr.key.0 { match attr.key.0 {
b"category" => { b"category" => {
let s = std::str::from_utf8(bytes)?; let s = std::str::from_utf8(&attr.value)?;
category = Some(s.to_owned()) category = Some(s.to_owned())
} }
b"name" => { b"name" => {
let s = std::str::from_utf8(bytes)?; let s = std::str::from_utf8(&attr.value)?;
name = Some(s.to_owned()) name = Some(s.to_owned())
} }
b"type" => { b"type" => {
let s = std::str::from_utf8(bytes)?; let s = std::str::from_utf8(&attr.value)?;
r#type = Some(s.to_owned()) r#type = Some(s.to_owned())
} }
_ => {} _ => {}
@ -224,7 +224,7 @@ impl FromXml for Feature {
let attr = attr?; let attr = attr?;
match attr.key.0 { match attr.key.0 {
b"var" => { b"var" => {
let s = std::str::from_utf8(bytes)?; let s = std::str::from_utf8(&attr.value)?;
var = Some(s.to_owned()) var = Some(s.to_owned())
} }
_ => {} _ => {}
@ -359,15 +359,15 @@ impl FromXml for Item {
let attr = attr?; let attr = attr?;
match attr.key.0 { match attr.key.0 {
b"name" => { b"name" => {
let s = std::str::from_utf8(bytes)?; let s = std::str::from_utf8(&attr.value)?;
name = Some(s.to_owned()) name = Some(s.to_owned())
} }
b"node" => { b"node" => {
let s = std::str::from_utf8(bytes)?; let s = std::str::from_utf8(&attr.value)?;
node = Some(s.to_owned()) node = Some(s.to_owned())
} }
b"jid" => { b"jid" => {
let s = std::str::from_utf8(bytes)?; let s = std::str::from_utf8(&attr.value)?;
let s = Jid::from_string(s)?; let s = Jid::from_string(s)?;
jid = Some(s) jid = Some(s)
} }

View File

@ -28,17 +28,17 @@ impl FromXml for History {
let attr = attr?; let attr = attr?;
match attr.key.0 { match attr.key.0 {
b"maxchars" => { b"maxchars" => {
let s = std::str::from_utf8(bytes)?; let s = std::str::from_utf8(&attr.value)?;
let a = s.parse()?; let a = s.parse()?;
history.maxchars = Some(a) history.maxchars = Some(a)
} }
b"maxstanzas" => { b"maxstanzas" => {
let s = std::str::from_utf8(bytes)?; let s = std::str::from_utf8(&attr.value)?;
let a = s.parse()?; let a = s.parse()?;
history.maxstanzas = Some(a) history.maxstanzas = Some(a)
} }
b"seconds" => { b"seconds" => {
let s = std::str::from_utf8(bytes)?; let s = std::str::from_utf8(&attr.value)?;
let a = s.parse()?; let a = s.parse()?;
history.seconds = Some(a) history.seconds = Some(a)
} }