forked from lavina/lavina
1
0
Fork 0

Compare commits

...

9 Commits

Author SHA1 Message Date
Nikita Vilunov 1e2a6d5656 xmpp: make xml-headers optional in the c2s stream 2024-04-04 17:49:03 +00:00
Nikita Vilunov d436631450 improve docs and split command handlers into methods (#40) 2024-03-26 16:26:31 +00:00
Nikita Vilunov 878ec33cbb apply uniform formatting 2024-03-20 19:59:15 +01:00
Nikita Vilunov 1d9937319e update dependencies 2024-03-20 19:53:51 +01:00
homycdev 4b1958b5ae irc: remove hardcoded text from welcome messages
- use server name in welcome message
- use app version of crate in app_version field

Reviewed-on: lavina/lavina#35
Co-authored-by: homycdev <abdulkhamid98@gmail.com>
Co-committed-by: homycdev <abdulkhamid98@gmail.com>
2024-03-15 00:54:55 +00:00
JustTestingV c6fb74a848 termination usage for stopping the socket connection gracefully (#34)
Reviewed-on: lavina/lavina#34
Co-authored-by: JustTestingV <JustTestingV@gmail.com>
Co-committed-by: JustTestingV <JustTestingV@gmail.com>
2024-02-18 16:46:29 +00:00
Nikita Vilunov 7613055dde update dependencies 2024-02-08 21:15:22 +01:00
G1ng3r 7ff9ffdcf7 irc: send ERR_SASLFAIL reply for auth fails (#30)
Reviewed-on: lavina/lavina#30
Co-authored-by: G1ng3r <saynulanude@gmail.com>
Co-committed-by: G1ng3r <saynulanude@gmail.com>
2024-02-06 23:08:14 +00:00
JustTestingV 614be92be3 xmpp: no panic! (#29)
Reviewed-on: lavina/lavina#29
Co-authored-by: JustTestingV <JustTestingV@gmail.com>
Co-committed-by: JustTestingV <JustTestingV@gmail.com>
2024-01-22 15:13:19 +00:00
27 changed files with 915 additions and 851 deletions

View File

@ -12,7 +12,7 @@ jobs:
uses: https://github.com/actions-rs/cargo@v1 uses: https://github.com/actions-rs/cargo@v1
with: with:
command: fmt command: fmt
args: "--check -p mgmt-api -p lavina-core -p projection-irc -p projection-xmpp -p sasl" args: "--check --all"
- name: cargo check - name: cargo check
uses: https://github.com/actions-rs/cargo@v1 uses: https://github.com/actions-rs/cargo@v1
with: with:

21
.run/Run lavina.run.xml Normal file
View File

@ -0,0 +1,21 @@
<component name="ProjectRunConfigurationManager">
<configuration default="false" name="Run lavina" type="CargoCommandRunConfiguration" factoryName="Cargo Command">
<option name="command" value="run --package lavina --bin lavina -- --config config.toml" />
<option name="workingDirectory" value="file://$PROJECT_DIR$" />
<envs>
<env name="RUST_LOG" value="debug" />
</envs>
<option name="emulateTerminal" value="true" />
<option name="channel" value="DEFAULT" />
<option name="requiredFeatures" value="true" />
<option name="allFeatures" value="false" />
<option name="withSudo" value="false" />
<option name="buildTarget" value="REMOTE" />
<option name="backtrace" value="FULL" />
<option name="isRedirectInput" value="false" />
<option name="redirectInputPath" value="" />
<method v="2">
<option name="CARGO.BUILD_TASK_PROVIDER" enabled="true" />
</method>
</configuration>
</component>

733
Cargo.lock generated

File diff suppressed because it is too large Load Diff

View File

@ -18,8 +18,8 @@ assert_matches = "1.5.0"
tokio = { version = "1.24.1", features = ["full"] } # async runtime tokio = { version = "1.24.1", features = ["full"] } # async runtime
futures-util = "0.3.25" futures-util = "0.3.25"
anyhow = "1.0.68" # error utils anyhow = "1.0.68" # error utils
nonempty = "0.8.1" nonempty = "0.10.0"
quick-xml = { version = "0.30.0", features = ["async-tokio"] } quick-xml = { version = "0.31.0", features = ["async-tokio"] }
lazy_static = "1.4.0" lazy_static = "1.4.0"
regex = "1.7.1" regex = "1.7.1"
derive_more = "0.99.17" derive_more = "0.99.17"
@ -62,4 +62,4 @@ clap.workspace = true
[dev-dependencies] [dev-dependencies]
assert_matches.workspace = true assert_matches.workspace = true
regex = "1.7.1" regex = "1.7.1"
reqwest = { version = "0.11", default-features = false } reqwest = { version = "0.12.0", default-features = false }

View File

@ -14,10 +14,7 @@ use std::{
use prometheus::{IntGauge, Registry as MetricsRegistry}; use prometheus::{IntGauge, Registry as MetricsRegistry};
use serde::Serialize; use serde::Serialize;
use tokio::{ use tokio::sync::mpsc::{channel, Receiver, Sender};
sync::mpsc::{channel, Receiver, Sender},
task::JoinHandle,
};
use crate::prelude::*; use crate::prelude::*;
use crate::room::{RoomHandle, RoomId, RoomInfo, RoomRegistry}; use crate::room::{RoomHandle, RoomId, RoomInfo, RoomRegistry};
@ -45,58 +42,65 @@ impl PlayerId {
} }
} }
/// Node-local identifier of a connection. It is used to address a connection within a player actor.
#[derive(Debug, Clone, PartialEq, Eq, Hash)] #[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub struct ConnectionId(pub AnonKey); pub struct ConnectionId(pub AnonKey);
/// Representation of an authenticated client connection.
/// The public API available to projections through which all client actions are executed.
///
/// The connection is used to send commands to the player actor and to receive updates that might be sent to the client.
pub struct PlayerConnection { pub struct PlayerConnection {
pub connection_id: ConnectionId, pub connection_id: ConnectionId,
pub receiver: Receiver<Updates>, pub receiver: Receiver<Updates>,
player_handle: PlayerHandle, player_handle: PlayerHandle,
} }
impl PlayerConnection { impl PlayerConnection {
/// Handled in [Player::send_message].
pub async fn send_message(&mut self, room_id: RoomId, body: Str) -> Result<()> { pub async fn send_message(&mut self, room_id: RoomId, body: Str) -> Result<()> {
self.player_handle let (promise, deferred) = oneshot();
.send_message(room_id, self.connection_id.clone(), body) let cmd = ClientCommand::SendMessage { room_id, body, promise };
.await self.player_handle.send(ActorCommand::ClientCommand(cmd, self.connection_id.clone())).await;
Ok(deferred.await?)
} }
/// Handled in [Player::join_room].
pub async fn join_room(&mut self, room_id: RoomId) -> Result<JoinResult> { pub async fn join_room(&mut self, room_id: RoomId) -> Result<JoinResult> {
self.player_handle.join_room(room_id, self.connection_id.clone()).await let (promise, deferred) = oneshot();
let cmd = ClientCommand::JoinRoom { room_id, promise };
self.player_handle.send(ActorCommand::ClientCommand(cmd, self.connection_id.clone())).await;
Ok(deferred.await?)
} }
/// Handled in [Player::change_topic].
pub async fn change_topic(&mut self, room_id: RoomId, new_topic: Str) -> Result<()> { pub async fn change_topic(&mut self, room_id: RoomId, new_topic: Str) -> Result<()> {
let (promise, deferred) = oneshot(); let (promise, deferred) = oneshot();
let cmd = Cmd::ChangeTopic { let cmd = ClientCommand::ChangeTopic {
room_id, room_id,
new_topic, new_topic,
promise, promise,
}; };
self.player_handle self.player_handle.send(ActorCommand::ClientCommand(cmd, self.connection_id.clone())).await;
.send(PlayerCommand::Cmd(cmd, self.connection_id.clone()))
.await;
Ok(deferred.await?) Ok(deferred.await?)
} }
/// Handled in [Player::leave_room].
pub async fn leave_room(&mut self, room_id: RoomId) -> Result<()> { pub async fn leave_room(&mut self, room_id: RoomId) -> Result<()> {
let (promise, deferred) = oneshot(); let (promise, deferred) = oneshot();
self.player_handle let cmd = ClientCommand::LeaveRoom { room_id, promise };
.send(PlayerCommand::Cmd( self.player_handle.send(ActorCommand::ClientCommand(cmd, self.connection_id.clone())).await;
Cmd::LeaveRoom { room_id, promise },
self.connection_id.clone(),
))
.await;
Ok(deferred.await?) Ok(deferred.await?)
} }
pub async fn terminate(self) { pub async fn terminate(self) {
self.player_handle self.player_handle.send(ActorCommand::TerminateConnection(self.connection_id)).await;
.send(PlayerCommand::TerminateConnection(self.connection_id))
.await;
} }
/// Handled in [Player::get_rooms].
pub async fn get_rooms(&self) -> Result<Vec<RoomInfo>> { pub async fn get_rooms(&self) -> Result<Vec<RoomInfo>> {
let (promise, deferred) = oneshot(); let (promise, deferred) = oneshot();
self.player_handle.send(PlayerCommand::GetRooms(promise)).await; let cmd = ClientCommand::GetRooms { promise };
self.player_handle.send(ActorCommand::ClientCommand(cmd, self.connection_id.clone())).await;
Ok(deferred.await?) Ok(deferred.await?)
} }
} }
@ -104,13 +108,13 @@ impl PlayerConnection {
/// Handle to a player actor. /// Handle to a player actor.
#[derive(Clone)] #[derive(Clone)]
pub struct PlayerHandle { pub struct PlayerHandle {
tx: Sender<PlayerCommand>, tx: Sender<ActorCommand>,
} }
impl PlayerHandle { impl PlayerHandle {
pub async fn subscribe(&self) -> PlayerConnection { pub async fn subscribe(&self) -> PlayerConnection {
let (sender, receiver) = channel(32); let (sender, receiver) = channel(32);
let (promise, deferred) = oneshot(); let (promise, deferred) = oneshot();
let cmd = PlayerCommand::AddConnection { sender, promise }; let cmd = ActorCommand::AddConnection { sender, promise };
let _ = self.tx.send(cmd).await; let _ = self.tx.send(cmd).await;
let connection_id = deferred.await.unwrap(); let connection_id = deferred.await.unwrap();
PlayerConnection { PlayerConnection {
@ -120,45 +124,34 @@ impl PlayerHandle {
} }
} }
pub async fn send_message(&self, room_id: RoomId, connection_id: ConnectionId, body: Str) -> Result<()> { async fn send(&self, command: ActorCommand) {
let (promise, deferred) = oneshot(); // TODO either handle the error or doc why it is safe to ignore
let cmd = Cmd::SendMessage { room_id, body, promise };
let _ = self.tx.send(PlayerCommand::Cmd(cmd, connection_id)).await;
Ok(deferred.await?)
}
pub async fn join_room(&self, room_id: RoomId, connection_id: ConnectionId) -> Result<JoinResult> {
let (promise, deferred) = oneshot();
let cmd = Cmd::JoinRoom { room_id, promise };
let _ = self.tx.send(PlayerCommand::Cmd(cmd, connection_id)).await;
Ok(deferred.await?)
}
async fn send(&self, command: PlayerCommand) {
let _ = self.tx.send(command).await; let _ = self.tx.send(command).await;
} }
pub async fn update(&self, update: Updates) { pub async fn update(&self, update: Updates) {
self.send(PlayerCommand::Update(update)).await; self.send(ActorCommand::Update(update)).await;
} }
} }
enum PlayerCommand { /// Messages sent to the player actor.
/** Commands from connections */ enum ActorCommand {
/// Establish a new connection.
AddConnection { AddConnection {
sender: Sender<Updates>, sender: Sender<Updates>,
promise: Promise<ConnectionId>, promise: Promise<ConnectionId>,
}, },
/// Terminate an existing connection.
TerminateConnection(ConnectionId), TerminateConnection(ConnectionId),
Cmd(Cmd, ConnectionId), /// Player-issued command.
/// Query - responds with a list of rooms the player is a member of. ClientCommand(ClientCommand, ConnectionId),
GetRooms(Promise<Vec<RoomInfo>>), /// Update which is sent from a room the player is member of.
/** Events from rooms */
Update(Updates), Update(Updates),
Stop, Stop,
} }
pub enum Cmd { /// Client-issued command sent to the player actor. The actor will respond with by fulfilling the promise.
pub enum ClientCommand {
JoinRoom { JoinRoom {
room_id: RoomId, room_id: RoomId,
promise: Promise<JoinResult>, promise: Promise<JoinResult>,
@ -177,6 +170,9 @@ pub enum Cmd {
new_topic: Str, new_topic: Str,
promise: Promise<()>, promise: Promise<()>,
}, },
GetRooms {
promise: Promise<Vec<RoomInfo>>,
},
} }
pub enum JoinResult { pub enum JoinResult {
@ -243,7 +239,7 @@ impl PlayerRegistry {
pub async fn shutdown_all(&mut self) -> Result<()> { pub async fn shutdown_all(&mut self) -> Result<()> {
let mut inner = self.0.write().unwrap(); let mut inner = self.0.write().unwrap();
for (i, (k, j)) in inner.players.drain() { for (i, (k, j)) in inner.players.drain() {
k.send(PlayerCommand::Stop).await; k.send(ActorCommand::Stop).await;
drop(k); drop(k);
j.await?; j.await?;
log::debug!("Player stopped #{i:?}") log::debug!("Player stopped #{i:?}")
@ -266,7 +262,7 @@ struct Player {
connections: AnonTable<Sender<Updates>>, connections: AnonTable<Sender<Updates>>,
my_rooms: HashMap<RoomId, RoomHandle>, my_rooms: HashMap<RoomId, RoomHandle>,
banned_from: HashSet<RoomId>, banned_from: HashSet<RoomId>,
rx: Receiver<PlayerCommand>, rx: Receiver<ActorCommand>,
handle: PlayerHandle, handle: PlayerHandle,
rooms: RoomRegistry, rooms: RoomRegistry,
} }
@ -291,124 +287,152 @@ impl Player {
async fn main_loop(mut self) -> Self { async fn main_loop(mut self) -> Self {
while let Some(cmd) = self.rx.recv().await { while let Some(cmd) = self.rx.recv().await {
match cmd { match cmd {
PlayerCommand::AddConnection { sender, promise } => { ActorCommand::AddConnection { sender, promise } => {
let connection_id = self.connections.insert(sender); let connection_id = self.connections.insert(sender);
if let Err(connection_id) = promise.send(ConnectionId(connection_id)) { if let Err(connection_id) = promise.send(ConnectionId(connection_id)) {
log::warn!("Connection {connection_id:?} terminated before finalization"); log::warn!("Connection {connection_id:?} terminated before finalization");
self.terminate_connection(connection_id); self.terminate_connection(connection_id);
} }
} }
PlayerCommand::TerminateConnection(connection_id) => { ActorCommand::TerminateConnection(connection_id) => {
self.terminate_connection(connection_id); self.terminate_connection(connection_id);
} }
PlayerCommand::GetRooms(promise) => { ActorCommand::Update(update) => self.handle_update(update).await,
let mut response = vec![]; ActorCommand::ClientCommand(cmd, connection_id) => self.handle_cmd(cmd, connection_id).await,
for (_, handle) in &self.my_rooms { ActorCommand::Stop => break,
response.push(handle.get_room_info().await);
}
let _ = promise.send(response);
}
PlayerCommand::Update(update) => {
log::info!(
"Player received an update, broadcasting to {} connections",
self.connections.len()
);
match update {
Updates::BannedFrom(ref room_id) => {
self.banned_from.insert(room_id.clone());
self.my_rooms.remove(room_id);
}
_ => {}
}
for (_, connection) in &self.connections {
let _ = connection.send(update.clone()).await;
}
}
PlayerCommand::Cmd(cmd, connection_id) => self.handle_cmd(cmd, connection_id).await,
PlayerCommand::Stop => break,
} }
} }
log::debug!("Shutting down player actor #{:?}", self.player_id); log::debug!("Shutting down player actor #{:?}", self.player_id);
self self
} }
/// Handle an incoming update by changing the internal state and broadcasting it to all connections if necessary.
async fn handle_update(&mut self, update: Updates) {
log::info!(
"Player received an update, broadcasting to {} connections",
self.connections.len()
);
match update {
Updates::BannedFrom(ref room_id) => {
self.banned_from.insert(room_id.clone());
self.my_rooms.remove(room_id);
}
_ => {}
}
for (_, connection) in &self.connections {
let _ = connection.send(update.clone()).await;
}
}
fn terminate_connection(&mut self, connection_id: ConnectionId) { fn terminate_connection(&mut self, connection_id: ConnectionId) {
if let None = self.connections.pop(connection_id.0) { if let None = self.connections.pop(connection_id.0) {
log::warn!("Connection {connection_id:?} already terminated"); log::warn!("Connection {connection_id:?} already terminated");
} }
} }
async fn handle_cmd(&mut self, cmd: Cmd, connection_id: ConnectionId) { /// Dispatches a client command to the appropriate handler.
async fn handle_cmd(&mut self, cmd: ClientCommand, connection_id: ConnectionId) {
match cmd { match cmd {
Cmd::JoinRoom { room_id, promise } => { ClientCommand::JoinRoom { room_id, promise } => {
if self.banned_from.contains(&room_id) { let result = self.join_room(connection_id, room_id).await;
let _ = promise.send(JoinResult::Banned); let _ = promise.send(result);
return;
}
let room = match self.rooms.get_or_create_room(room_id.clone()).await {
Ok(room) => room,
Err(e) => {
log::error!("Failed to get or create room: {e}");
return;
}
};
room.subscribe(self.player_id.clone(), self.handle.clone()).await;
self.my_rooms.insert(room_id.clone(), room.clone());
let room_info = room.get_room_info().await;
let _ = promise.send(JoinResult::Success(room_info));
let update = Updates::RoomJoined {
room_id,
new_member_id: self.player_id.clone(),
};
self.broadcast_update(update, connection_id).await;
} }
Cmd::LeaveRoom { room_id, promise } => { ClientCommand::LeaveRoom { room_id, promise } => {
let room = self.my_rooms.remove(&room_id); self.leave_room(connection_id, room_id).await;
if let Some(room) = room {
room.unsubscribe(&self.player_id).await;
let room_info = room.get_room_info().await;
}
let _ = promise.send(()); let _ = promise.send(());
let update = Updates::RoomLeft {
room_id,
former_member_id: self.player_id.clone(),
};
self.broadcast_update(update, connection_id).await;
} }
Cmd::SendMessage { room_id, body, promise } => { ClientCommand::SendMessage { room_id, body, promise } => {
let room = self.rooms.get_room(&room_id).await; self.send_message(connection_id, room_id, body).await;
if let Some(room) = room {
room.send_message(self.player_id.clone(), body.clone()).await;
} else {
tracing::info!("no room found");
}
let _ = promise.send(()); let _ = promise.send(());
let update = Updates::NewMessage {
room_id,
author_id: self.player_id.clone(),
body,
};
self.broadcast_update(update, connection_id).await;
} }
Cmd::ChangeTopic { ClientCommand::ChangeTopic {
room_id, room_id,
new_topic, new_topic,
promise, promise,
} => { } => {
let room = self.rooms.get_room(&room_id).await; self.change_topic(connection_id, room_id, new_topic).await;
if let Some(mut room) = room {
room.set_topic(self.player_id.clone(), new_topic.clone()).await;
} else {
tracing::info!("no room found");
}
let _ = promise.send(()); let _ = promise.send(());
let update = Updates::RoomTopicChanged { room_id, new_topic }; }
self.broadcast_update(update, connection_id).await; ClientCommand::GetRooms { promise } => {
let result = self.get_rooms().await;
let _ = promise.send(result);
} }
} }
} }
async fn join_room(&mut self, connection_id: ConnectionId, room_id: RoomId) -> JoinResult {
if self.banned_from.contains(&room_id) {
return JoinResult::Banned;
}
let room = match self.rooms.get_or_create_room(room_id.clone()).await {
Ok(room) => room,
Err(e) => {
log::error!("Failed to get or create room: {e}");
todo!();
}
};
room.subscribe(self.player_id.clone(), self.handle.clone()).await;
self.my_rooms.insert(room_id.clone(), room.clone());
let room_info = room.get_room_info().await;
let update = Updates::RoomJoined {
room_id,
new_member_id: self.player_id.clone(),
};
self.broadcast_update(update, connection_id).await;
JoinResult::Success(room_info)
}
async fn leave_room(&mut self, connection_id: ConnectionId, room_id: RoomId) {
let room = self.my_rooms.remove(&room_id);
if let Some(room) = room {
room.unsubscribe(&self.player_id).await;
}
let update = Updates::RoomLeft {
room_id,
former_member_id: self.player_id.clone(),
};
self.broadcast_update(update, connection_id).await;
}
async fn send_message(&mut self, connection_id: ConnectionId, room_id: RoomId, body: Str) {
let room = self.rooms.get_room(&room_id).await;
if let Some(room) = room {
room.send_message(self.player_id.clone(), body.clone()).await;
} else {
tracing::info!("no room found");
}
let update = Updates::NewMessage {
room_id,
author_id: self.player_id.clone(),
body,
};
self.broadcast_update(update, connection_id).await;
}
async fn change_topic(&mut self, connection_id: ConnectionId, room_id: RoomId, new_topic: Str) {
let room = self.rooms.get_room(&room_id).await;
if let Some(mut room) = room {
room.set_topic(self.player_id.clone(), new_topic.clone()).await;
} else {
tracing::info!("no room found");
}
let update = Updates::RoomTopicChanged { room_id, new_topic };
self.broadcast_update(update, connection_id).await;
}
async fn get_rooms(&self) -> Vec<RoomInfo> {
let mut response = vec![];
for (_, handle) in &self.my_rooms {
response.push(handle.get_room_info().await);
}
response
}
/// Broadcasts an update to all connections except the one with the given id.
///
/// This is called after handling a client command.
/// Sending the update to the connection which sent the command is handled by the connection itself.
async fn broadcast_update(&self, update: Updates, except: ConnectionId) { async fn broadcast_update(&self, update: Updates, except: ConnectionId) {
for (a, b) in &self.connections { for (a, b) in &self.connections {
if ConnectionId(a) == except { if ConnectionId(a) == except {

View File

@ -31,7 +31,7 @@ impl RoomId {
} }
} }
/// Shared datastructure for storing metadata about rooms. /// Shared data structure for storing metadata about rooms.
#[derive(Clone)] #[derive(Clone)]
pub struct RoomRegistry(Arc<AsyncRwLock<RoomRegistryInner>>); pub struct RoomRegistry(Arc<AsyncRwLock<RoomRegistryInner>>);
impl RoomRegistry { impl RoomRegistry {
@ -160,9 +160,13 @@ impl RoomHandle {
} }
struct Room { struct Room {
/// The numeric node-local id of the room as it is stored in the database.
storage_id: u32, storage_id: u32,
/// The cluster-global id of the room.
room_id: RoomId, room_id: RoomId,
/// Player actors on the local node which are subscribed to this room's updates.
subscriptions: HashMap<PlayerId, PlayerHandle>, subscriptions: HashMap<PlayerId, PlayerHandle>,
/// The total number of messages. Used to calculate the id of the new message.
message_count: u32, message_count: u32,
topic: Str, topic: Str,
storage: Storage, storage: Storage,
@ -180,9 +184,7 @@ impl Room {
async fn send_message(&mut self, author_id: PlayerId, body: Str) -> Result<()> { async fn send_message(&mut self, author_id: PlayerId, body: Str) -> Result<()> {
tracing::info!("Adding a message to room"); tracing::info!("Adding a message to room");
self.storage self.storage.insert_message(self.storage_id, self.message_count, &body, &*author_id.as_inner()).await?;
.insert_message(self.storage_id, self.message_count, &body, &*author_id.as_inner())
.await?;
self.message_count += 1; self.message_count += 1;
let update = Updates::NewMessage { let update = Updates::NewMessage {
room_id: self.room_id.clone(), room_id: self.room_id.clone(),
@ -193,6 +195,10 @@ impl Room {
Ok(()) Ok(())
} }
/// Broadcasts an update to all players except the one who caused the update.
///
/// This is called after handling a client command.
/// Sending the update to the player who sent the command is handled by the player actor.
async fn broadcast_update(&self, update: Updates, except: &PlayerId) { async fn broadcast_update(&self, update: Updates, except: &PlayerId) {
tracing::debug!("Broadcasting an update to {} subs", self.subscriptions.len()); tracing::debug!("Broadcasting an update to {} subs", self.subscriptions.len());
for (player_id, sub) in &self.subscriptions { for (player_id, sub) in &self.subscriptions {

View File

@ -1,2 +0,0 @@
max_width = 120
chain_width = 120

View File

@ -30,6 +30,8 @@ mod cap;
use crate::cap::Capabilities; use crate::cap::Capabilities;
pub const APP_VERSION: &str = concat!("lavina", "_", env!("CARGO_PKG_VERSION"));
#[derive(Deserialize, Debug, Clone)] #[derive(Deserialize, Debug, Clone)]
pub struct ServerConfig { pub struct ServerConfig {
pub listen_on: SocketAddr, pub listen_on: SocketAddr,
@ -42,7 +44,7 @@ struct RegisteredUser {
/** /**
* Username is mostly unused in modern IRC. * Username is mostly unused in modern IRC.
* *
* [https://stackoverflow.com/questions/31666247/what-is-the-difference-between-the-nick-username-and-real-name-in-irc-and-wha] * <https://stackoverflow.com/questions/31666247/what-is-the-difference-between-the-nick-username-and-real-name-in-irc-and-wha>
*/ */
username: Str, username: Str,
realname: Str, realname: Str,
@ -62,19 +64,24 @@ async fn handle_socket(
let mut reader: BufReader<ReadHalf> = BufReader::new(reader); let mut reader: BufReader<ReadHalf> = BufReader::new(reader);
let mut writer = BufWriter::new(writer); let mut writer = BufWriter::new(writer);
let registered_user: Result<RegisteredUser> = pin!(termination);
handle_registration(&mut reader, &mut writer, &mut storage, &config).await; select! {
biased;
match registered_user { _ = &mut termination =>{
Ok(user) => { log::info!("Socket handling was terminated");
log::debug!("User registered"); return Ok(())
handle_registered_socket(config, players, rooms, &mut reader, &mut writer, user).await?; },
} registered_user = handle_registration(&mut reader, &mut writer, &mut storage, &config) =>
Err(err) => { match registered_user {
log::debug!("Registration failed: {err}"); Ok(user) => {
} log::debug!("User registered");
handle_registered_socket(config, players, rooms, &mut reader, &mut writer, user).await?;
}
Err(err) => {
log::debug!("Registration failed: {err}");
}
}
} }
stream.shutdown().await?; stream.shutdown().await?;
Ok(()) Ok(())
} }
@ -180,14 +187,16 @@ async fn handle_registration<'a>(
writer.flush().await?; writer.flush().await?;
} }
CapabilitySubcommand::End => { CapabilitySubcommand::End => {
let Some((username, realname)) = future_username else { let Some((ref username, ref realname)) = future_username else {
todo!() todo!();
}; };
let Some(nickname) = future_nickname.clone() else { let Some(nickname) = future_nickname.clone() else {
todo!() todo!();
}; };
let username = username.clone();
let realname = realname.clone();
let candidate_user = RegisteredUser { let candidate_user = RegisteredUser {
nickname, nickname: nickname.clone(),
username, username,
realname, realname,
}; };
@ -197,7 +206,15 @@ async fn handle_registration<'a>(
break Ok(candidate_user); break Ok(candidate_user);
} else { } else {
let Some(candidate_password) = pass else { let Some(candidate_password) = pass else {
todo!(); sasl_fail_message(
config.server_name.clone(),
nickname.clone(),
"User credentials was not provided".into(),
)
.write_async(writer)
.await?;
writer.flush().await?;
continue;
}; };
auth_user(storage, &*candidate_user.nickname, &*candidate_password).await?; auth_user(storage, &*candidate_user.nickname, &*candidate_password).await?;
break Ok(candidate_user); break Ok(candidate_user);
@ -209,12 +226,20 @@ async fn handle_registration<'a>(
future_nickname = Some(nickname); future_nickname = Some(nickname);
} else if let Some((username, realname)) = future_username.clone() { } else if let Some((username, realname)) = future_username.clone() {
let candidate_user = RegisteredUser { let candidate_user = RegisteredUser {
nickname, nickname: nickname.clone(),
username, username,
realname, realname,
}; };
let Some(candidate_password) = pass else { let Some(candidate_password) = pass else {
todo!(); sasl_fail_message(
config.server_name.clone(),
nickname.clone(),
"User credentials was not provided".into(),
)
.write_async(writer)
.await?;
writer.flush().await?;
continue;
}; };
auth_user(storage, &*candidate_user.nickname, &*candidate_password).await?; auth_user(storage, &*candidate_user.nickname, &*candidate_password).await?;
break Ok(candidate_user); break Ok(candidate_user);
@ -227,12 +252,20 @@ async fn handle_registration<'a>(
future_username = Some((username, realname)); future_username = Some((username, realname));
} else if let Some(nickname) = future_nickname.clone() { } else if let Some(nickname) = future_nickname.clone() {
let candidate_user = RegisteredUser { let candidate_user = RegisteredUser {
nickname, nickname: nickname.clone(),
username, username,
realname, realname,
}; };
let Some(candidate_password) = pass else { let Some(candidate_password) = pass else {
todo!(); sasl_fail_message(
config.server_name.clone(),
nickname.clone(),
"User credentials was not provided".into(),
)
.write_async(writer)
.await?;
writer.flush().await?;
continue;
}; };
auth_user(storage, &*candidate_user.nickname, &*candidate_password).await?; auth_user(storage, &*candidate_user.nickname, &*candidate_password).await?;
break Ok(candidate_user); break Ok(candidate_user);
@ -255,38 +288,59 @@ async fn handle_registration<'a>(
.await?; .await?;
writer.flush().await?; writer.flush().await?;
} else { } else {
// TODO respond with 904 if let Some(nickname) = future_nickname.clone() {
todo!(); sasl_fail_message(
config.server_name.clone(),
nickname.clone(),
"Unsupported mechanism".into(),
)
.write_async(writer)
.await?;
writer.flush().await?;
} else {
break Err(anyhow::Error::msg("Wrong authentication sequence"));
}
} }
} else { } else {
let body = AuthBody::from_str(body.as_bytes())?; let body = AuthBody::from_str(body.as_bytes())?;
auth_user(storage, &body.login, &body.password).await?; if let Err(e) = auth_user(storage, &body.login, &body.password).await {
let login: Str = body.login.into(); tracing::warn!("Authentication failed: {:?}", e);
validated_user = Some(login.clone()); if let Some(nickname) = future_nickname.clone() {
ServerMessage { sasl_fail_message(config.server_name.clone(), nickname.clone(), "Bad credentials".into())
tags: vec![], .write_async(writer)
sender: Some(config.server_name.clone().into()), .await?;
body: ServerMessageBody::N900LoggedIn { writer.flush().await?;
nick: login.clone(), } else {
address: login.clone(), }
account: login.clone(), } else {
message: format!("You are now logged in as {}", login).into(), let login: Str = body.login.into();
}, validated_user = Some(login.clone());
ServerMessage {
tags: vec![],
sender: Some(config.server_name.clone().into()),
body: ServerMessageBody::N900LoggedIn {
nick: login.clone(),
address: login.clone(),
account: login.clone(),
message: format!("You are now logged in as {}", login).into(),
},
}
.write_async(writer)
.await?;
ServerMessage {
tags: vec![],
sender: Some(config.server_name.clone().into()),
body: ServerMessageBody::N903SaslSuccess {
nick: login.clone(),
message: "SASL authentication successful".into(),
},
}
.write_async(writer)
.await?;
writer.flush().await?;
} }
.write_async(writer)
.await?;
ServerMessage {
tags: vec![],
sender: Some(config.server_name.clone().into()),
body: ServerMessageBody::N903SaslSuccess {
nick: login.clone(),
message: "SASL authentication successful".into(),
},
}
.write_async(writer)
.await?;
writer.flush().await?;
} }
// TODO handle abortion of authentication // TODO handle abortion of authentication
} }
_ => {} _ => {}
@ -297,6 +351,14 @@ async fn handle_registration<'a>(
Ok(user) Ok(user)
} }
fn sasl_fail_message(sender: Str, nick: Str, text: Str) -> ServerMessage {
ServerMessage {
tags: vec![],
sender: Some(sender),
body: ServerMessageBody::N904SaslFail { nick, text },
}
}
async fn auth_user(storage: &mut Storage, login: &str, plain_password: &str) -> Result<()> { async fn auth_user(storage: &mut Storage, login: &str, plain_password: &str) -> Result<()> {
let stored_user = storage.retrieve_user_by_name(login).await?; let stored_user = storage.retrieve_user_by_name(login).await?;
@ -331,13 +393,14 @@ async fn handle_registered_socket<'a>(
let player_id = PlayerId::from(user.nickname.clone())?; let player_id = PlayerId::from(user.nickname.clone())?;
let mut connection = players.connect_to_player(player_id.clone()).await; let mut connection = players.connect_to_player(player_id.clone()).await;
let text: Str = format!("Welcome to {} Server", &config.server_name).into();
ServerMessage { ServerMessage {
tags: vec![], tags: vec![],
sender: Some(config.server_name.clone()), sender: Some(config.server_name.clone()),
body: ServerMessageBody::N001Welcome { body: ServerMessageBody::N001Welcome {
client: user.nickname.clone(), client: user.nickname.clone(),
text: "Welcome to Kek Server".into(), text: text.clone(),
}, },
} }
.write_async(writer) .write_async(writer)
@ -347,7 +410,7 @@ async fn handle_registered_socket<'a>(
sender: Some(config.server_name.clone()), sender: Some(config.server_name.clone()),
body: ServerMessageBody::N002YourHost { body: ServerMessageBody::N002YourHost {
client: user.nickname.clone(), client: user.nickname.clone(),
text: "Welcome to Kek Server".into(), text: text.clone(),
}, },
} }
.write_async(writer) .write_async(writer)
@ -357,7 +420,7 @@ async fn handle_registered_socket<'a>(
sender: Some(config.server_name.clone()), sender: Some(config.server_name.clone()),
body: ServerMessageBody::N003Created { body: ServerMessageBody::N003Created {
client: user.nickname.clone(), client: user.nickname.clone(),
text: "Welcome to Kek Server".into(), text: text.clone(),
}, },
} }
.write_async(writer) .write_async(writer)
@ -368,7 +431,7 @@ async fn handle_registered_socket<'a>(
body: ServerMessageBody::N004MyInfo { body: ServerMessageBody::N004MyInfo {
client: user.nickname.clone(), client: user.nickname.clone(),
hostname: config.server_name.clone(), hostname: config.server_name.clone(),
softname: "kek-0.1.alpha.3".into(), softname: APP_VERSION.into(),
}, },
} }
.write_async(writer) .write_async(writer)

View File

@ -1,3 +1,5 @@
use std::io::ErrorKind;
use std::net::SocketAddr;
use std::time::Duration; use std::time::Duration;
use anyhow::{anyhow, Result}; use anyhow::{anyhow, Result};
@ -8,8 +10,8 @@ use tokio::net::TcpStream;
use lavina_core::repo::{Storage, StorageConfig}; use lavina_core::repo::{Storage, StorageConfig};
use lavina_core::{player::PlayerRegistry, room::RoomRegistry}; use lavina_core::{player::PlayerRegistry, room::RoomRegistry};
use projection_irc::APP_VERSION;
use projection_irc::{launch, read_irc_message, RunningServer, ServerConfig}; use projection_irc::{launch, read_irc_message, RunningServer, ServerConfig};
struct TestScope<'a> { struct TestScope<'a> {
reader: BufReader<ReadHalf<'a>>, reader: BufReader<ReadHalf<'a>>,
writer: WriteHalf<'a>, writer: WriteHalf<'a>,
@ -111,10 +113,17 @@ async fn scenario_basic() -> Result<()> {
s.send("PASS password").await?; s.send("PASS password").await?;
s.send("NICK tester").await?; s.send("NICK tester").await?;
s.send("USER UserName 0 * :Real Name").await?; s.send("USER UserName 0 * :Real Name").await?;
s.expect(":testserver 001 tester :Welcome to Kek Server").await?; s.expect(":testserver 001 tester :Welcome to testserver Server").await?;
s.expect(":testserver 002 tester :Welcome to Kek Server").await?; s.expect(":testserver 002 tester :Welcome to testserver Server").await?;
s.expect(":testserver 003 tester :Welcome to Kek Server").await?; s.expect(":testserver 003 tester :Welcome to testserver Server").await?;
s.expect(":testserver 004 tester testserver kek-0.1.alpha.3 r CFILPQbcefgijklmnopqrstvz").await?; s.expect(
format!(
":testserver 004 tester testserver {} r CFILPQbcefgijklmnopqrstvz",
&APP_VERSION
)
.as_str(),
)
.await?;
s.expect(":testserver 005 tester CHANTYPES=# :are supported by this server").await?; s.expect(":testserver 005 tester CHANTYPES=# :are supported by this server").await?;
s.expect_nothing().await?; s.expect_nothing().await?;
s.send("QUIT :Leaving").await?; s.send("QUIT :Leaving").await?;
@ -159,10 +168,17 @@ async fn scenario_cap_full_negotiation() -> Result<()> {
s.send("CAP END").await?; s.send("CAP END").await?;
s.expect(":testserver 001 tester :Welcome to Kek Server").await?; s.expect(":testserver 001 tester :Welcome to testserver Server").await?;
s.expect(":testserver 002 tester :Welcome to Kek Server").await?; s.expect(":testserver 002 tester :Welcome to testserver Server").await?;
s.expect(":testserver 003 tester :Welcome to Kek Server").await?; s.expect(":testserver 003 tester :Welcome to testserver Server").await?;
s.expect(":testserver 004 tester testserver kek-0.1.alpha.3 r CFILPQbcefgijklmnopqrstvz").await?; s.expect(
format!(
":testserver 004 tester testserver {} r CFILPQbcefgijklmnopqrstvz",
&APP_VERSION
)
.as_str(),
)
.await?;
s.expect(":testserver 005 tester CHANTYPES=# :are supported by this server").await?; s.expect(":testserver 005 tester CHANTYPES=# :are supported by this server").await?;
s.expect_nothing().await?; s.expect_nothing().await?;
s.send("QUIT :Leaving").await?; s.send("QUIT :Leaving").await?;
@ -201,10 +217,17 @@ async fn scenario_cap_short_negotiation() -> Result<()> {
s.send("CAP END").await?; s.send("CAP END").await?;
s.expect(":testserver 001 tester :Welcome to Kek Server").await?; s.expect(":testserver 001 tester :Welcome to testserver Server").await?;
s.expect(":testserver 002 tester :Welcome to Kek Server").await?; s.expect(":testserver 002 tester :Welcome to testserver Server").await?;
s.expect(":testserver 003 tester :Welcome to Kek Server").await?; s.expect(":testserver 003 tester :Welcome to testserver Server").await?;
s.expect(":testserver 004 tester testserver kek-0.1.alpha.3 r CFILPQbcefgijklmnopqrstvz").await?; s.expect(
format!(
":testserver 004 tester testserver {} r CFILPQbcefgijklmnopqrstvz",
&APP_VERSION
)
.as_str(),
)
.await?;
s.expect(":testserver 005 tester CHANTYPES=# :are supported by this server").await?; s.expect(":testserver 005 tester CHANTYPES=# :are supported by this server").await?;
s.expect_nothing().await?; s.expect_nothing().await?;
s.send("QUIT :Leaving").await?; s.send("QUIT :Leaving").await?;
@ -218,3 +241,84 @@ async fn scenario_cap_short_negotiation() -> Result<()> {
server.server.terminate().await?; server.server.terminate().await?;
Ok(()) Ok(())
} }
#[tokio::test]
async fn scenario_cap_sasl_fail() -> Result<()> {
let mut server = TestServer::start().await?;
// test scenario
server.storage.create_user("tester").await?;
server.storage.set_password("tester", "password").await?;
let mut stream = TcpStream::connect(server.server.addr).await?;
let mut s = TestScope::new(&mut stream);
s.send("CAP LS 302").await?;
s.send("NICK tester").await?;
s.send("USER UserName 0 * :Real Name").await?;
s.expect(":testserver CAP * LS :sasl=PLAIN").await?;
s.send("CAP REQ :sasl").await?;
s.expect(":testserver CAP tester ACK :sasl").await?;
s.send("AUTHENTICATE SHA256").await?;
s.expect(":testserver 904 tester :Unsupported mechanism").await?;
s.send("AUTHENTICATE PLAIN").await?;
s.expect(":testserver AUTHENTICATE +").await?;
s.send("AUTHENTICATE dGVzdGVyAHRlc3RlcgBwYXNzd29yZDE=").await?;
s.expect(":testserver 904 tester :Bad credentials").await?;
s.send("AUTHENTICATE dGVzdGVyAHRlc3RlcgBwYXNzd29yZA==").await?; // base64-encoded 'tester\x00tester\x00password'
s.expect(":testserver 900 tester tester tester :You are now logged in as tester").await?;
s.expect(":testserver 903 tester :SASL authentication successful").await?;
s.send("CAP END").await?;
s.expect(":testserver 001 tester :Welcome to testserver Server").await?;
s.expect(":testserver 002 tester :Welcome to testserver Server").await?;
s.expect(":testserver 003 tester :Welcome to testserver Server").await?;
s.expect(
format!(
":testserver 004 tester testserver {} r CFILPQbcefgijklmnopqrstvz",
&APP_VERSION
)
.as_str(),
)
.await?;
s.expect(":testserver 005 tester CHANTYPES=# :are supported by this server").await?;
s.expect_nothing().await?;
s.send("QUIT :Leaving").await?;
s.expect(":testserver ERROR :Leaving the server").await?;
s.expect_eof().await?;
stream.shutdown().await?;
// wrap up
server.server.terminate().await?;
Ok(())
}
#[tokio::test]
async fn terminate_socket_scenario() -> Result<()> {
let mut server = TestServer::start().await?;
let address: SocketAddr = ("127.0.0.1:0".parse().unwrap());
// test scenario
server.storage.create_user("tester").await?;
server.storage.set_password("tester", "password").await?;
let mut stream = TcpStream::connect(server.server.addr).await?;
let mut s = TestScope::new(&mut stream);
s.send("NICK tester").await?;
s.send("CAP REQ :sasl").await?;
s.send("USER UserName 0 * :Real Name").await?;
s.expect(":testserver CAP tester ACK :sasl").await?;
s.send("AUTHENTICATE PLAIN").await?;
s.expect(":testserver AUTHENTICATE +").await?;
server.server.terminate().await?;
assert_eq!(stream.read_u8().await.unwrap_err().kind(), ErrorKind::UnexpectedEof);
Ok(())
}

View File

@ -1,2 +0,0 @@
max_width = 120
chain_width = 120

View File

@ -83,7 +83,7 @@ pub async fn launch(
let key = match read_one(&mut SyncBufReader::new(File::open(config.key)?))? { let key = match read_one(&mut SyncBufReader::new(File::open(config.key)?))? {
Some(PemItem::ECKey(k) | PemItem::PKCS8Key(k) | PemItem::RSAKey(k)) => PrivateKey(k), Some(PemItem::ECKey(k) | PemItem::PKCS8Key(k) | PemItem::RSAKey(k)) => PrivateKey(k),
_ => panic!("no keys in file"), _ => return Err(fail("no keys in file")),
}; };
let loaded_config = Arc::new(LoadedConfig { let loaded_config = Arc::new(LoadedConfig {
@ -187,18 +187,33 @@ async fn handle_socket(
let mut xml_reader = NsReader::from_reader(BufReader::new(a)); let mut xml_reader = NsReader::from_reader(BufReader::new(a));
let mut xml_writer = Writer::new(b); let mut xml_writer = Writer::new(b);
let authenticated = socket_auth(&mut xml_reader, &mut xml_writer, &mut reader_buf, &mut storage).await?; pin!(termination);
log::debug!("User authenticated"); select! {
let mut connection = players.connect_to_player(authenticated.player_id.clone()).await; biased;
socket_final( _ = &mut termination =>{
&mut xml_reader, log::info!("Socket handling was terminated");
&mut xml_writer, return Ok(())
&mut reader_buf, },
&authenticated, authenticated = socket_auth(&mut xml_reader, &mut xml_writer, &mut reader_buf, &mut storage) => {
&mut connection, match authenticated {
&rooms, Ok(authenticated) => {
) let mut connection = players.connect_to_player(authenticated.player_id.clone()).await;
.await?; socket_final(
&mut xml_reader,
&mut xml_writer,
&mut reader_buf,
&authenticated,
&mut connection,
&rooms,
)
.await?;
},
Err(err) => {
log::error!("Authentication error: {:?}", err);
}
}
},
}
let a = xml_reader.into_inner().into_inner(); let a = xml_reader.into_inner().into_inner();
let b = xml_writer.into_inner(); let b = xml_writer.into_inner();
@ -214,7 +229,6 @@ async fn socket_force_tls(
use proto_xmpp::tls::*; use proto_xmpp::tls::*;
let xml_reader = &mut NsReader::from_reader(reader); let xml_reader = &mut NsReader::from_reader(reader);
let xml_writer = &mut Writer::new(writer); let xml_writer = &mut Writer::new(writer);
read_xml_header(xml_reader, reader_buf).await?;
let _ = ClientStreamStart::parse(xml_reader, reader_buf).await?; let _ = ClientStreamStart::parse(xml_reader, reader_buf).await?;
let event = Event::Decl(BytesDecl::new("1.0", None, None)); let event = Event::Decl(BytesDecl::new("1.0", None, None));
@ -246,7 +260,6 @@ async fn socket_auth(
reader_buf: &mut Vec<u8>, reader_buf: &mut Vec<u8>,
storage: &mut Storage, storage: &mut Storage,
) -> Result<Authenticated> { ) -> Result<Authenticated> {
read_xml_header(xml_reader, reader_buf).await?;
let _ = ClientStreamStart::parse(xml_reader, reader_buf).await?; let _ = ClientStreamStart::parse(xml_reader, reader_buf).await?;
xml_writer.write_event_async(Event::Decl(BytesDecl::new("1.0", None, None))).await?; xml_writer.write_event_async(Event::Decl(BytesDecl::new("1.0", None, None))).await?;
@ -312,7 +325,6 @@ async fn socket_final(
user_handle: &mut PlayerConnection, user_handle: &mut PlayerConnection,
rooms: &RoomRegistry, rooms: &RoomRegistry,
) -> Result<()> { ) -> Result<()> {
read_xml_header(xml_reader, reader_buf).await?;
let _ = ClientStreamStart::parse(xml_reader, reader_buf).await?; let _ = ClientStreamStart::parse(xml_reader, reader_buf).await?;
xml_writer.write_event_async(Event::Decl(BytesDecl::new("1.0", None, None))).await?; xml_writer.write_event_async(Event::Decl(BytesDecl::new("1.0", None, None))).await?;
@ -404,7 +416,7 @@ struct XmppConnection<'a> {
impl<'a> XmppConnection<'a> { impl<'a> XmppConnection<'a> {
async fn handle_packet(&mut self, output: &mut Vec<Event<'static>>, packet: ClientPacket) -> Result<bool> { async fn handle_packet(&mut self, output: &mut Vec<Event<'static>>, packet: ClientPacket) -> Result<bool> {
let res = match packet { let res = match packet {
proto::ClientPacket::Iq(iq) => { ClientPacket::Iq(iq) => {
self.handle_iq(output, iq).await; self.handle_iq(output, iq).await;
false false
} }
@ -412,11 +424,11 @@ impl<'a> XmppConnection<'a> {
self.handle_message(output, m).await?; self.handle_message(output, m).await?;
false false
} }
proto::ClientPacket::Presence(p) => { ClientPacket::Presence(p) => {
self.handle_presence(output, p).await?; self.handle_presence(output, p).await?;
false false
} }
proto::ClientPacket::StreamEnd => { ClientPacket::StreamEnd => {
ServerStreamEnd.serialize(output); ServerStreamEnd.serialize(output);
true true
} }
@ -424,25 +436,3 @@ impl<'a> XmppConnection<'a> {
Ok(res) Ok(res)
} }
} }
async fn read_xml_header(
xml_reader: &mut NsReader<(impl AsyncBufRead + Unpin)>,
reader_buf: &mut Vec<u8>,
) -> Result<()> {
if let Event::Decl(bytes) = xml_reader.read_event_into_async(reader_buf).await? {
// this is <?xml ...> header
if let Some(encoding) = bytes.encoding() {
let encoding = encoding?;
if &*encoding == b"UTF-8" {
Ok(())
} else {
Err(anyhow!("Unsupported encoding: {encoding:?}"))
}
} else {
// Err(fail("No XML encoding provided"))
Ok(())
}
} else {
Err(anyhow!("Expected XML header"))
}
}

View File

@ -1,3 +1,5 @@
use std::io::ErrorKind;
use std::net::SocketAddr;
use std::sync::Arc; use std::sync::Arc;
use std::time::Duration; use std::time::Duration;
@ -6,7 +8,7 @@ use assert_matches::*;
use prometheus::Registry as MetricsRegistry; use prometheus::Registry as MetricsRegistry;
use quick_xml::events::Event; use quick_xml::events::Event;
use quick_xml::NsReader; use quick_xml::NsReader;
use tokio::io::{AsyncBufReadExt, AsyncWriteExt, BufReader}; use tokio::io::{AsyncBufReadExt, AsyncReadExt, AsyncWriteExt, BufReader};
use tokio::io::{ReadHalf as GenericReadHalf, WriteHalf as GenericWriteHalf}; use tokio::io::{ReadHalf as GenericReadHalf, WriteHalf as GenericWriteHalf};
use tokio::net::tcp::{ReadHalf, WriteHalf}; use tokio::net::tcp::{ReadHalf, WriteHalf};
use tokio::net::TcpStream; use tokio::net::TcpStream;
@ -122,7 +124,7 @@ impl ServerCertVerifier for IgnoreCertVerification {
#[tokio::test] #[tokio::test]
async fn scenario_basic() -> Result<()> { async fn scenario_basic() -> Result<()> {
tracing_subscriber::fmt::init(); tracing_subscriber::fmt::try_init();
let config = ServerConfig { let config = ServerConfig {
listen_on: "127.0.0.1:0".parse().unwrap(), listen_on: "127.0.0.1:0".parse().unwrap(),
cert: "tests/certs/xmpp.pem".parse().unwrap(), cert: "tests/certs/xmpp.pem".parse().unwrap(),
@ -184,3 +186,61 @@ async fn scenario_basic() -> Result<()> {
server.terminate().await?; server.terminate().await?;
Ok(()) Ok(())
} }
#[tokio::test]
async fn terminate_socket() -> Result<()> {
tracing_subscriber::fmt::try_init();
let config = ServerConfig {
listen_on: "127.0.0.1:0".parse().unwrap(),
cert: "tests/certs/xmpp.pem".parse().unwrap(),
key: "tests/certs/xmpp.key".parse().unwrap(),
};
let mut metrics = MetricsRegistry::new();
let mut storage = Storage::open(StorageConfig {
db_path: ":memory:".into(),
})
.await?;
let rooms = RoomRegistry::new(&mut metrics, storage.clone()).unwrap();
let players = PlayerRegistry::empty(rooms.clone(), &mut metrics).unwrap();
let server = launch(config, players, rooms, metrics, storage.clone()).await.unwrap();
let address: SocketAddr = ("127.0.0.1:0".parse().unwrap());
// test scenario
storage.create_user("tester").await?;
storage.set_password("tester", "password").await?;
let mut stream = TcpStream::connect(server.addr).await?;
let mut s = TestScope::new(&mut stream);
tracing::info!("TCP connection established");
s.send(r#"<?xml version="1.0"?>"#).await?;
s.send(r#"<stream:stream xmlns:stream="http://etherx.jabber.org/streams" to="127.0.0.1" xml:lang="en" xmlns:xml="http://www.w3.org/XML/1998/namespace" xmlns="jabber:client" version="1.0">"#).await?;
assert_matches!(s.next_xml_event().await?, Event::Decl(_) => {});
assert_matches!(s.next_xml_event().await?, Event::Start(b) => assert_eq!(b.local_name().into_inner(), b"stream"));
assert_matches!(s.next_xml_event().await?, Event::Start(b) => assert_eq!(b.local_name().into_inner(), b"features"));
assert_matches!(s.next_xml_event().await?, Event::Start(b) => assert_eq!(b.local_name().into_inner(), b"starttls"));
assert_matches!(s.next_xml_event().await?, Event::Empty(b) => assert_eq!(b.local_name().into_inner(), b"required"));
assert_matches!(s.next_xml_event().await?, Event::End(b) => assert_eq!(b.local_name().into_inner(), b"starttls"));
assert_matches!(s.next_xml_event().await?, Event::End(b) => assert_eq!(b.local_name().into_inner(), b"features"));
s.send(r#"<starttls/>"#).await?;
assert_matches!(s.next_xml_event().await?, Event::Empty(b) => assert_eq!(b.local_name().into_inner(), b"proceed"));
let buffer = s.buffer;
let connector = TlsConnector::from(Arc::new(
ClientConfig::builder()
.with_safe_defaults()
.with_custom_certificate_verifier(Arc::new(IgnoreCertVerification))
.with_no_client_auth(),
));
tracing::info!("Initiating TLS connection...");
let mut stream = connector.connect(ServerName::IpAddress(server.addr.ip()), stream).await?;
tracing::info!("TLS connection established");
server.terminate().await?;
assert_eq!(stream.read_u8().await.unwrap_err().kind(), ErrorKind::UnexpectedEof);
Ok(())
}

View File

@ -7,42 +7,42 @@ use nonempty::NonEmpty;
/// Client-to-server command. /// Client-to-server command.
#[derive(Clone, Debug, PartialEq, Eq)] #[derive(Clone, Debug, PartialEq, Eq)]
pub enum ClientMessage { pub enum ClientMessage {
/// CAP. Capability-related commands. /// `CAP`. Capability-related commands.
Capability { Capability {
subcommand: CapabilitySubcommand, subcommand: CapabilitySubcommand,
}, },
/// PING <token> /// `PING <token>`
Ping { Ping {
token: Str, token: Str,
}, },
/// PONG <token> /// `PONG <token>`
Pong { Pong {
token: Str, token: Str,
}, },
/// NICK <nickname> /// `NICK <nickname>`
Nick { Nick {
nickname: Str, nickname: Str,
}, },
/// PASS <password> /// `PASS <password>`
Pass { Pass {
password: Str, password: Str,
}, },
/// USER <username> 0 * :<realname> /// `USER <username> 0 * :<realname>`
User { User {
username: Str, username: Str,
realname: Str, realname: Str,
}, },
/// JOIN <chan> /// `JOIN <chan>`
Join(Chan), Join(Chan),
/// MODE <target> /// `MODE <target>`
Mode { Mode {
target: Recipient, target: Recipient,
}, },
/// WHO <target> /// `WHO <target>`
Who { Who {
target: Recipient, // aka mask target: Recipient, // aka mask
}, },
/// TOPIC <chan> :<topic> /// `TOPIC <chan> :<topic>`
Topic { Topic {
chan: Chan, chan: Chan,
topic: Str, topic: Str,
@ -51,12 +51,12 @@ pub enum ClientMessage {
chan: Chan, chan: Chan,
message: Str, message: Str,
}, },
/// PRIVMSG <target> :<msg> /// `PRIVMSG <target> :<msg>`
PrivateMessage { PrivateMessage {
recipient: Recipient, recipient: Recipient,
body: Str, body: Str,
}, },
/// QUIT :<reason> /// `QUIT :<reason>`
Quit { Quit {
reason: Str, reason: Str,
}, },

View File

@ -30,11 +30,15 @@ fn token(input: &str) -> IResult<&str, &str> {
take_while(|i| i != '\n' && i != '\r')(input) take_while(|i| i != '\n' && i != '\r')(input)
} }
fn params(input: &str) -> IResult<&str, &str> {
take_while(|i| i != '\n' && i != '\r' && i != ':')(input)
}
#[derive(Clone, Debug, PartialEq, Eq)] #[derive(Clone, Debug, PartialEq, Eq)]
pub enum Chan { pub enum Chan {
/// #<name> — network-global channel, available from any server in the network. /// `#<name>` — network-global channel, available from any server in the network.
Global(Str), Global(Str),
/// &<name> — server-local channel, available only to connections to the same server. Rarely used in practice. /// `&<name>` — server-local channel, available only to connections to the same server. Rarely used in practice.
Local(Str), Local(Str),
} }
impl Chan { impl Chan {
@ -114,9 +118,7 @@ mod test {
assert_matches!(result, Ok((_, result)) => assert_eq!(expected, result)); assert_matches!(result, Ok((_, result)) => assert_eq!(expected, result));
let mut bytes = vec![]; let mut bytes = vec![];
sync_future(expected.write_async(&mut bytes)) sync_future(expected.write_async(&mut bytes)).unwrap().unwrap();
.unwrap()
.unwrap();
assert_eq!(bytes.as_slice(), input.as_bytes()); assert_eq!(bytes.as_slice(), input.as_bytes());
} }
@ -130,9 +132,7 @@ mod test {
assert_matches!(result, Ok((_, result)) => assert_eq!(expected, result)); assert_matches!(result, Ok((_, result)) => assert_eq!(expected, result));
let mut bytes = vec![]; let mut bytes = vec![];
sync_future(expected.write_async(&mut bytes)) sync_future(expected.write_async(&mut bytes)).unwrap().unwrap();
.unwrap()
.unwrap();
assert_eq!(bytes.as_slice(), input.as_bytes()); assert_eq!(bytes.as_slice(), input.as_bytes());
} }
@ -146,9 +146,7 @@ mod test {
assert_matches!(result, Ok((_, result)) => assert_eq!(expected, result)); assert_matches!(result, Ok((_, result)) => assert_eq!(expected, result));
let mut bytes = vec![]; let mut bytes = vec![];
sync_future(expected.write_async(&mut bytes)) sync_future(expected.write_async(&mut bytes)).unwrap().unwrap();
.unwrap()
.unwrap();
assert_eq!(bytes.as_slice(), input.as_bytes()); assert_eq!(bytes.as_slice(), input.as_bytes());
} }

View File

@ -1,3 +1,5 @@
use std::sync::Arc;
use nonempty::NonEmpty; use nonempty::NonEmpty;
use tokio::io::AsyncWrite; use tokio::io::AsyncWrite;
use tokio::io::AsyncWriteExt; use tokio::io::AsyncWriteExt;
@ -152,7 +154,11 @@ pub enum ServerMessageBody {
N903SaslSuccess { N903SaslSuccess {
nick: Str, nick: Str,
message: Str, message: Str,
} },
N904SaslFail {
nick: Str,
text: Str,
},
} }
impl ServerMessageBody { impl ServerMessageBody {
@ -267,11 +273,7 @@ impl ServerMessageBody {
writer.write_all(b" :").await?; writer.write_all(b" :").await?;
writer.write_all(msg.as_bytes()).await?; writer.write_all(msg.as_bytes()).await?;
} }
ServerMessageBody::N332Topic { ServerMessageBody::N332Topic { client, chat, topic } => {
client,
chat,
topic,
} => {
writer.write_all(b"332 ").await?; writer.write_all(b"332 ").await?;
writer.write_all(client.as_bytes()).await?; writer.write_all(client.as_bytes()).await?;
writer.write_all(b" ").await?; writer.write_all(b" ").await?;
@ -309,20 +311,14 @@ impl ServerMessageBody {
writer.write_all(b" ").await?; writer.write_all(b" ").await?;
writer.write_all(realname.as_bytes()).await?; writer.write_all(realname.as_bytes()).await?;
} }
ServerMessageBody::N353NamesReply { ServerMessageBody::N353NamesReply { client, chan, members } => {
client,
chan,
members,
} => {
writer.write_all(b"353 ").await?; writer.write_all(b"353 ").await?;
writer.write_all(client.as_bytes()).await?; writer.write_all(client.as_bytes()).await?;
writer.write_all(b" = ").await?; writer.write_all(b" = ").await?;
chan.write_async(writer).await?; chan.write_async(writer).await?;
writer.write_all(b" :").await?; writer.write_all(b" :").await?;
for member in members { for member in members {
writer writer.write_all(member.prefix.to_string().as_bytes()).await?;
.write_all(member.prefix.to_string().as_bytes())
.await?;
writer.write_all(member.nick.as_bytes()).await?; writer.write_all(member.nick.as_bytes()).await?;
writer.write_all(b" ").await?; writer.write_all(b" ").await?;
} }
@ -334,11 +330,7 @@ impl ServerMessageBody {
chan.write_async(writer).await?; chan.write_async(writer).await?;
writer.write_all(b" :End of /NAMES list").await?; writer.write_all(b" :End of /NAMES list").await?;
} }
ServerMessageBody::N474BannedFromChan { ServerMessageBody::N474BannedFromChan { client, chan, message } => {
client,
chan,
message,
} => {
writer.write_all(b"474 ").await?; writer.write_all(b"474 ").await?;
writer.write_all(client.as_bytes()).await?; writer.write_all(client.as_bytes()).await?;
writer.write_all(b" ").await?; writer.write_all(b" ").await?;
@ -353,7 +345,12 @@ impl ServerMessageBody {
writer.write_all(b" :").await?; writer.write_all(b" :").await?;
writer.write_all(message.as_bytes()).await?; writer.write_all(message.as_bytes()).await?;
} }
ServerMessageBody::N900LoggedIn { nick, address, account, message } => { ServerMessageBody::N900LoggedIn {
nick,
address,
account,
message,
} => {
writer.write_all(b"900 ").await?; writer.write_all(b"900 ").await?;
writer.write_all(nick.as_bytes()).await?; writer.write_all(nick.as_bytes()).await?;
writer.write_all(b" ").await?; writer.write_all(b" ").await?;
@ -369,6 +366,13 @@ impl ServerMessageBody {
writer.write_all(b" :").await?; writer.write_all(b" :").await?;
writer.write_all(message.as_bytes()).await?; writer.write_all(message.as_bytes()).await?;
} }
ServerMessageBody::N904SaslFail { nick, text } => {
writer.write_all(b"904").await?;
writer.write_all(b" ").await?;
writer.write_all(nick.as_bytes()).await?;
writer.write_all(b" :").await?;
writer.write_all(text.as_bytes()).await?;
}
} }
Ok(()) Ok(())
} }

View File

@ -42,27 +42,19 @@ impl Display for Jid {
impl Jid { impl Jid {
pub fn from_string(i: &str) -> Result<Jid> { pub fn from_string(i: &str) -> Result<Jid> {
use regex::Regex;
use lazy_static::lazy_static; use lazy_static::lazy_static;
use regex::Regex;
lazy_static! { lazy_static! {
static ref RE: Regex = Regex::new(r"^(([a-zA-Z]+)@)?([a-zA-Z.]+)(/([a-zA-Z\-]+))?$").unwrap(); static ref RE: Regex = Regex::new(r"^(([a-zA-Z]+)@)?([a-zA-Z.]+)(/([a-zA-Z\-]+))?$").unwrap();
} }
let m = RE let m = RE.captures(i).ok_or(anyhow!("Incorrectly format jid: {i}"))?;
.captures(i)
.ok_or(anyhow!("Incorrectly format jid: {i}"))?;
let name = m.get(2).map(|name| Name(name.as_str().into())); let name = m.get(2).map(|name| Name(name.as_str().into()));
let server = m.get(3).unwrap(); let server = m.get(3).unwrap();
let server = Server(server.as_str().into()); let server = Server(server.as_str().into());
let resource = m let resource = m.get(5).map(|resource| Resource(resource.as_str().into()));
.get(5)
.map(|resource| Resource(resource.as_str().into()));
Ok(Jid { Ok(Jid { name, server, resource })
name,
server,
resource,
})
} }
} }
@ -137,9 +129,7 @@ pub struct BindResponse(pub Jid);
impl ToXml for BindResponse { impl ToXml for BindResponse {
fn serialize(&self, events: &mut Vec<Event<'static>>) { fn serialize(&self, events: &mut Vec<Event<'static>>) {
events.extend_from_slice(&[ events.extend_from_slice(&[
Event::Start(BytesStart::new( Event::Start(BytesStart::new(r#"bind xmlns="urn:ietf:params:xml:ns:xmpp-bind""#)),
r#"bind xmlns="urn:ietf:params:xml:ns:xmpp-bind""#,
)),
Event::Start(BytesStart::new(r#"jid"#)), Event::Start(BytesStart::new(r#"jid"#)),
Event::Text(BytesText::new(self.0.to_string().as_str()).into_owned()), Event::Text(BytesText::new(self.0.to_string().as_str()).into_owned()),
Event::End(BytesEnd::new("jid")), Event::End(BytesEnd::new("jid")),
@ -156,23 +146,16 @@ mod tests {
#[tokio::test] #[tokio::test]
async fn parse_message() { async fn parse_message() {
let input = let input = r#"<bind xmlns="urn:ietf:params:xml:ns:xmpp-bind"><resource>mobile</resource></bind>"#;
r#"<bind xmlns="urn:ietf:params:xml:ns:xmpp-bind"><resource>mobile</resource></bind>"#;
let mut reader = NsReader::from_reader(input.as_bytes()); let mut reader = NsReader::from_reader(input.as_bytes());
let mut buf = vec![]; let mut buf = vec![];
let (ns, event) = reader let (ns, event) = reader.read_resolved_event_into_async(&mut buf).await.unwrap();
.read_resolved_event_into_async(&mut buf)
.await
.unwrap();
let mut parser = BindRequest::parse().consume(ns, &event); let mut parser = BindRequest::parse().consume(ns, &event);
let result = loop { let result = loop {
match parser { match parser {
Continuation::Final(res) => break res, Continuation::Final(res) => break res,
Continuation::Continue(next) => { Continuation::Continue(next) => {
let (ns, event) = reader let (ns, event) = reader.read_resolved_event_into_async(&mut buf).await.unwrap();
.read_resolved_event_into_async(&mut buf)
.await
.unwrap();
parser = next.consume(ns, &event); parser = next.consume(ns, &event);
} }
} }

View File

@ -2,8 +2,8 @@ use quick_xml::events::attributes::Attribute;
use quick_xml::events::{BytesEnd, BytesStart, Event}; use quick_xml::events::{BytesEnd, BytesStart, Event};
use quick_xml::name::{QName, ResolveResult}; use quick_xml::name::{QName, ResolveResult};
use anyhow::{Result, anyhow as ffail};
use crate::xml::*; use crate::xml::*;
use anyhow::{anyhow as ffail, Result};
use super::bind::Jid; use super::bind::Jid;
@ -174,11 +174,7 @@ impl FromXml for Identity {
let Some(r#type) = r#type else { let Some(r#type) = r#type else {
return Err(ffail!("No type provided")); return Err(ffail!("No type provided"));
}; };
let item = Identity { let item = Identity { category, name, r#type };
category,
name,
r#type,
};
if end { if end {
return Ok(item); return Ok(item);
} }

View File

@ -1,21 +1,16 @@
#![feature( #![feature(coroutines, coroutine_trait, type_alias_impl_trait, impl_trait_in_assoc_type)]
coroutines,
coroutine_trait,
type_alias_impl_trait,
impl_trait_in_assoc_type
)]
pub mod bind; pub mod bind;
pub mod client; pub mod client;
pub mod disco; pub mod disco;
pub mod muc; pub mod muc;
mod prelude;
pub mod roster; pub mod roster;
pub mod sasl; pub mod sasl;
pub mod session; pub mod session;
pub mod stanzaerror; pub mod stanzaerror;
pub mod stream; pub mod stream;
pub mod tls; pub mod tls;
mod prelude;
pub mod xml; pub mod xml;
// Implemented as a macro instead of a fn due to borrowck limitations // Implemented as a macro instead of a fn due to borrowck limitations

View File

@ -52,9 +52,6 @@ impl FromXmlTag for RosterQuery {
impl ToXml for RosterQuery { impl ToXml for RosterQuery {
fn serialize(&self, events: &mut Vec<Event<'static>>) { fn serialize(&self, events: &mut Vec<Event<'static>>) {
events.push(Event::Empty(BytesStart::new(format!( events.push(Event::Empty(BytesStart::new(format!(r#"query xmlns="{}""#, XMLNS))));
r#"query xmlns="{}""#,
XMLNS
))));
} }
} }

View File

@ -25,9 +25,7 @@ impl Parser for SessionParser {
) -> Continuation<Self, Self::Output> { ) -> Continuation<Self, Self::Output> {
match self.0 { match self.0 {
SessionParserInner::Initial => match event { SessionParserInner::Initial => match event {
Event::Start(_) => { Event::Start(_) => Continuation::Continue(SessionParser(SessionParserInner::InSession)),
Continuation::Continue(SessionParser(SessionParserInner::InSession))
}
Event::Empty(_) => Continuation::Final(Ok(Session)), Event::Empty(_) => Continuation::Final(Ok(Session)),
_ => Continuation::Final(Err(anyhow!("Unexpected XML event: {event:?}"))), _ => Continuation::Final(Err(anyhow!("Unexpected XML event: {event:?}"))),
}, },
@ -54,9 +52,6 @@ impl FromXmlTag for Session {
impl ToXml for Session { impl ToXml for Session {
fn serialize(&self, events: &mut Vec<Event<'static>>) { fn serialize(&self, events: &mut Vec<Event<'static>>) {
events.push(Event::Empty(BytesStart::new(format!( events.push(Event::Empty(BytesStart::new(format!(r#"session xmlns="{}""#, XMLNS))));
r#"session xmlns="{}""#,
XMLNS
))));
} }
} }

View File

@ -6,8 +6,8 @@ use tokio::io::{AsyncBufRead, AsyncWrite};
use super::skip_text; use super::skip_text;
use anyhow::{anyhow, Result};
use crate::xml::ToXml; use crate::xml::ToXml;
use anyhow::{anyhow, Result};
pub static XMLNS: &'static str = "http://etherx.jabber.org/streams"; pub static XMLNS: &'static str = "http://etherx.jabber.org/streams";
pub static PREFIX: &'static str = "stream"; pub static PREFIX: &'static str = "stream";
@ -24,14 +24,24 @@ impl ClientStreamStart {
reader: &mut NsReader<impl AsyncBufRead + Unpin>, reader: &mut NsReader<impl AsyncBufRead + Unpin>,
buf: &mut Vec<u8>, buf: &mut Vec<u8>,
) -> Result<ClientStreamStart> { ) -> Result<ClientStreamStart> {
let incoming = skip_text!(reader, buf); let mut incoming = skip_text!(reader, buf);
if let Event::Decl(bytes) = incoming {
// this is <?xml ...> header
if let Some(encoding) = bytes.encoding() {
let encoding = encoding?;
if &*encoding != b"UTF-8" {
return Err(anyhow!("Unsupported encoding: {encoding:?}"));
}
}
incoming = skip_text!(reader, buf);
}
if let Event::Start(e) = incoming { if let Event::Start(e) = incoming {
let (ns, local) = reader.resolve_element(e.name()); let (ns, local) = reader.resolve_element(e.name());
if ns != ResolveResult::Bound(Namespace(XMLNS.as_bytes())) { if ns != ResolveResult::Bound(Namespace(XMLNS.as_bytes())) {
return Err(panic!()); return Err(anyhow!("Invalid namespace for stream element"));
} }
if local.into_inner() != b"stream" { if local.into_inner() != b"stream" {
return Err(panic!()); return Err(anyhow!("Invalid local name for stream element"));
} }
let mut to = None; let mut to = None;
let mut lang = None; let mut lang = None;
@ -44,10 +54,7 @@ impl ClientStreamStart {
let value = attr.unescape_value()?; let value = attr.unescape_value()?;
to = Some(value.to_string()); to = Some(value.to_string());
} }
( (ResolveResult::Bound(Namespace(b"http://www.w3.org/XML/1998/namespace")), b"lang") => {
ResolveResult::Bound(Namespace(b"http://www.w3.org/XML/1998/namespace")),
b"lang",
) => {
let value = attr.unescape_value()?; let value = attr.unescape_value()?;
lang = Some(value.to_string()); lang = Some(value.to_string());
} }
@ -64,7 +71,7 @@ impl ClientStreamStart {
version: version.unwrap(), version: version.unwrap(),
}) })
} else { } else {
Err(panic!()) Err(anyhow!("Incoming message does not belong XML Start Event"))
} }
} }
} }
@ -124,21 +131,15 @@ pub struct Features {
} }
impl Features { impl Features {
pub async fn write_xml(&self, writer: &mut Writer<impl AsyncWrite + Unpin>) -> Result<()> { pub async fn write_xml(&self, writer: &mut Writer<impl AsyncWrite + Unpin>) -> Result<()> {
writer writer.write_event_async(Event::Start(BytesStart::new("stream:features"))).await?;
.write_event_async(Event::Start(BytesStart::new("stream:features")))
.await?;
if self.start_tls { if self.start_tls {
writer writer
.write_event_async(Event::Start(BytesStart::new( .write_event_async(Event::Start(BytesStart::new(
r#"starttls xmlns="urn:ietf:params:xml:ns:xmpp-tls""#, r#"starttls xmlns="urn:ietf:params:xml:ns:xmpp-tls""#,
))) )))
.await?; .await?;
writer writer.write_event_async(Event::Empty(BytesStart::new("required"))).await?;
.write_event_async(Event::Empty(BytesStart::new("required"))) writer.write_event_async(Event::End(BytesEnd::new("starttls"))).await?;
.await?;
writer
.write_event_async(Event::End(BytesEnd::new("starttls")))
.await?;
} }
if self.mechanisms { if self.mechanisms {
writer writer
@ -146,18 +147,10 @@ impl Features {
r#"mechanisms xmlns="urn:ietf:params:xml:ns:xmpp-sasl""#, r#"mechanisms xmlns="urn:ietf:params:xml:ns:xmpp-sasl""#,
))) )))
.await?; .await?;
writer writer.write_event_async(Event::Start(BytesStart::new(r#"mechanism"#))).await?;
.write_event_async(Event::Start(BytesStart::new(r#"mechanism"#))) writer.write_event_async(Event::Text(BytesText::new("PLAIN"))).await?;
.await?; writer.write_event_async(Event::End(BytesEnd::new("mechanism"))).await?;
writer writer.write_event_async(Event::End(BytesEnd::new("mechanisms"))).await?;
.write_event_async(Event::Text(BytesText::new("PLAIN")))
.await?;
writer
.write_event_async(Event::End(BytesEnd::new("mechanism")))
.await?;
writer
.write_event_async(Event::End(BytesEnd::new("mechanisms")))
.await?;
} }
if self.bind { if self.bind {
writer writer
@ -166,9 +159,7 @@ impl Features {
))) )))
.await?; .await?;
} }
writer writer.write_event_async(Event::End(BytesEnd::new("stream:features"))).await?;
.write_event_async(Event::End(BytesEnd::new("stream:features")))
.await?;
Ok(()) Ok(())
} }
} }
@ -182,9 +173,7 @@ mod test {
let input = r###"<stream:stream xmlns:stream="http://etherx.jabber.org/streams" to="vlnv.dev" version="1.0" xmlns="jabber:client" xml:lang="en" xmlns:xml="http://www.w3.org/XML/1998/namespace">"###; let input = r###"<stream:stream xmlns:stream="http://etherx.jabber.org/streams" to="vlnv.dev" version="1.0" xmlns="jabber:client" xml:lang="en" xmlns:xml="http://www.w3.org/XML/1998/namespace">"###;
let mut reader = NsReader::from_reader(input.as_bytes()); let mut reader = NsReader::from_reader(input.as_bytes());
let mut buf = vec![]; let mut buf = vec![];
let res = ClientStreamStart::parse(&mut reader, &mut buf) let res = ClientStreamStart::parse(&mut reader, &mut buf).await.unwrap();
.await
.unwrap();
assert_eq!( assert_eq!(
res, res,
ClientStreamStart { ClientStreamStart {

View File

@ -12,10 +12,7 @@ pub static XMLNS: &'static str = "urn:ietf:params:xml:ns:xmpp-tls";
pub struct StartTLS; pub struct StartTLS;
impl StartTLS { impl StartTLS {
pub async fn parse( pub async fn parse(reader: &mut NsReader<impl AsyncBufRead + Unpin>, buf: &mut Vec<u8>) -> Result<StartTLS> {
reader: &mut NsReader<impl AsyncBufRead + Unpin>,
buf: &mut Vec<u8>,
) -> Result<StartTLS> {
let incoming = skip_text!(reader, buf); let incoming = skip_text!(reader, buf);
if let Event::Empty(ref e) = incoming { if let Event::Empty(ref e) = incoming {
if e.name().0 == b"starttls" { if e.name().0 == b"starttls" {

View File

@ -15,11 +15,7 @@ enum IgnoreParserInner {
impl Parser for IgnoreParser { impl Parser for IgnoreParser {
type Output = Result<Ignore>; type Output = Result<Ignore>;
fn consume<'a>( fn consume<'a>(self: Self, _: ResolveResult, event: &Event<'a>) -> Continuation<Self, Self::Output> {
self: Self,
_: ResolveResult,
event: &Event<'a>,
) -> Continuation<Self, Self::Output> {
match self.0 { match self.0 {
IgnoreParserInner::Initial => match event { IgnoreParserInner::Initial => match event {
Event::Start(bytes) => { Event::Start(bytes) => {
@ -34,13 +30,7 @@ impl Parser for IgnoreParser {
if depth == 0 { if depth == 0 {
Continuation::Final(Ok(Ignore)) Continuation::Final(Ok(Ignore))
} else { } else {
Continuation::Continue( Continuation::Continue(IgnoreParserInner::InTag { name, depth: depth - 1 }.into())
IgnoreParserInner::InTag {
name,
depth: depth - 1,
}
.into(),
)
} }
} }
_ => Continuation::Continue(IgnoreParserInner::InTag { name, depth }.into()), _ => Continuation::Continue(IgnoreParserInner::InTag { name, depth }.into()),

View File

@ -1,9 +1,9 @@
use std::ops::Coroutine; use std::ops::Coroutine;
use std::pin::Pin; use std::pin::Pin;
use quick_xml::NsReader;
use quick_xml::events::Event; use quick_xml::events::Event;
use quick_xml::name::ResolveResult; use quick_xml::name::ResolveResult;
use quick_xml::NsReader;
use anyhow::Result; use anyhow::Result;
@ -28,25 +28,16 @@ pub trait FromXmlTag: FromXml {
pub trait Parser: Sized { pub trait Parser: Sized {
type Output; type Output;
fn consume<'a>( fn consume<'a>(self: Self, namespace: ResolveResult, event: &Event<'a>) -> Continuation<Self, Self::Output>;
self: Self,
namespace: ResolveResult,
event: &Event<'a>,
) -> Continuation<Self, Self::Output>;
} }
impl<T, Out> Parser for T impl<T, Out> Parser for T
where where
T: Coroutine<(ResolveResult<'static>, &'static Event<'static>), Yield = (), Return = Out> T: Coroutine<(ResolveResult<'static>, &'static Event<'static>), Yield = (), Return = Out> + Unpin,
+ Unpin,
{ {
type Output = Out; type Output = Out;
fn consume<'a>( fn consume<'a>(mut self: Self, namespace: ResolveResult, event: &Event<'a>) -> Continuation<Self, Self::Output> {
mut self: Self,
namespace: ResolveResult,
event: &Event<'a>,
) -> Continuation<Self, Self::Output> {
let s = Pin::new(&mut self); let s = Pin::new(&mut self);
// this is a very rude workaround fixing the fact that rust coroutines // this is a very rude workaround fixing the fact that rust coroutines
// 1. don't support higher-kinded lifetimes (i.e. no `impl for <'a> Coroutine<Event<'a>>) // 1. don't support higher-kinded lifetimes (i.e. no `impl for <'a> Coroutine<Event<'a>>)

View File

@ -1 +1 @@
nightly-2023-12-07 nightly-2024-03-20

View File

@ -1 +1,2 @@
max_width = 120 max_width = 120
chain_width = 120

View File

@ -62,7 +62,14 @@ async fn main() -> Result<()> {
storage.clone(), storage.clone(),
) )
.await?; .await?;
let xmpp = projection_xmpp::launch(xmpp_config, players.clone(), rooms.clone(), metrics.clone(), storage.clone()).await?; let xmpp = projection_xmpp::launch(
xmpp_config,
players.clone(),
rooms.clone(),
metrics.clone(),
storage.clone(),
)
.await?;
tracing::info!("Started"); tracing::info!("Started");
sleep.await; sleep.await;