forked from lavina/lavina
1
0
Fork 0

Compare commits

...

10 Commits

26 changed files with 828 additions and 163 deletions

53
Cargo.lock generated
View File

@ -204,7 +204,7 @@ dependencies = [
"http-body 0.4.6", "http-body 0.4.6",
"hyper 0.14.28", "hyper 0.14.28",
"itoa", "itoa",
"matchit", "matchit 0.7.3",
"memchr", "memchr",
"mime", "mime",
"percent-encoding", "percent-encoding",
@ -709,8 +709,10 @@ source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "94b22e06ecb0110981051723910cbf0b5f5e09a2062dd7663334ee79a9d1286c" checksum = "94b22e06ecb0110981051723910cbf0b5f5e09a2062dd7663334ee79a9d1286c"
dependencies = [ dependencies = [
"cfg-if", "cfg-if",
"js-sys",
"libc", "libc",
"wasi", "wasi",
"wasm-bindgen",
] ]
[[package]] [[package]]
@ -1062,6 +1064,7 @@ version = "0.0.3-dev"
dependencies = [ dependencies = [
"anyhow", "anyhow",
"assert_matches", "assert_matches",
"chrono",
"clap", "clap",
"derive_more", "derive_more",
"figment", "figment",
@ -1096,8 +1099,13 @@ dependencies = [
"anyhow", "anyhow",
"argon2", "argon2",
"chrono", "chrono",
"mgmt-api",
"opentelemetry",
"prometheus", "prometheus",
"rand_core", "rand_core",
"reqwest",
"reqwest-middleware",
"reqwest-tracing",
"serde", "serde",
"sqlx", "sqlx",
"tokio", "tokio",
@ -1164,6 +1172,12 @@ version = "0.7.3"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "0e7465ac9959cc2b1404e8e2367b43684a6d13790fe23056cc8c6c5a6b7bcb94" checksum = "0e7465ac9959cc2b1404e8e2367b43684a6d13790fe23056cc8c6c5a6b7bcb94"
[[package]]
name = "matchit"
version = "0.8.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "540f1c43aed89909c0cc0cc604e3bb2f7e7a341a3728a9e6cfe760e733cd11ed"
[[package]] [[package]]
name = "md-5" name = "md-5"
version = "0.10.6" version = "0.10.6"
@ -1763,9 +1777,9 @@ checksum = "adad44e29e4c806119491a7f06f03de4d1af22c3a680dd47f1e6e179439d1f56"
[[package]] [[package]]
name = "reqwest" name = "reqwest"
version = "0.12.3" version = "0.12.4"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "3e6cc1e89e689536eb5aeede61520e874df5a4707df811cd5da4aa5fbb2aae19" checksum = "566cafdd92868e0939d3fb961bd0dc25fcfaaed179291093b3d43e6b3150ea10"
dependencies = [ dependencies = [
"base64 0.22.0", "base64 0.22.0",
"bytes", "bytes",
@ -1796,6 +1810,39 @@ dependencies = [
"winreg", "winreg",
] ]
[[package]]
name = "reqwest-middleware"
version = "0.3.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "0209efb52486ad88136190094ee214759ef7507068b27992256ed6610eb71a01"
dependencies = [
"anyhow",
"async-trait",
"http 1.1.0",
"reqwest",
"serde",
"thiserror",
"tower-service",
]
[[package]]
name = "reqwest-tracing"
version = "0.5.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "b253954a1979e02eabccd7e9c3d61d8f86576108baa160775e7f160bb4e800a3"
dependencies = [
"anyhow",
"async-trait",
"getrandom",
"http 1.1.0",
"matchit 0.8.2",
"opentelemetry",
"reqwest",
"reqwest-middleware",
"tracing",
"tracing-opentelemetry",
]
[[package]] [[package]]
name = "ring" name = "ring"
version = "0.17.8" version = "0.17.8"

View File

@ -32,6 +32,7 @@ lavina-core = { path = "crates/lavina-core" }
tracing-subscriber = "0.3.16" tracing-subscriber = "0.3.16"
sasl = { path = "crates/sasl" } sasl = { path = "crates/sasl" }
chrono = "0.4.37" chrono = "0.4.37"
reqwest = { version = "0.12.0", default-features = false, features = ["json"] }
[package] [package]
name = "lavina" name = "lavina"
@ -64,8 +65,9 @@ opentelemetry-semantic-conventions = "0.14.0"
opentelemetry_sdk = { version = "0.22.1", features = ["rt-tokio"] } opentelemetry_sdk = { version = "0.22.1", features = ["rt-tokio"] }
opentelemetry-otlp = "0.15.0" opentelemetry-otlp = "0.15.0"
tracing-opentelemetry = "0.23.0" tracing-opentelemetry = "0.23.0"
chrono.workspace = true
[dev-dependencies] [dev-dependencies]
assert_matches.workspace = true assert_matches.workspace = true
regex = "1.7.1" regex = "1.7.1"
reqwest = { version = "0.12.0", default-features = false } reqwest.workspace = true

31
config.0.toml Normal file
View File

@ -0,0 +1,31 @@
[telemetry]
listen_on = "127.0.0.1:8080"
[irc]
listen_on = "127.0.0.1:6667"
server_name = "irc.localhost"
[xmpp]
listen_on = "127.0.0.1:5222"
cert = "./certs/xmpp.pem"
key = "./certs/xmpp.key"
hostname = "localhost"
[storage]
db_path = "db.0.sqlite"
[cluster]
addresses = [
"127.0.0.1:8080",
"127.0.0.1:8081",
]
[cluster.metadata]
node_id = 0
main_owner = 0
test_owner = 1
test2_owner = 0
[tracing]
endpoint = "http://localhost:4317"
service_name = "lavina-0"

31
config.1.toml Normal file
View File

@ -0,0 +1,31 @@
[telemetry]
listen_on = "127.0.0.1:8081"
[irc]
listen_on = "127.0.0.1:6668"
server_name = "irc.localhost"
[xmpp]
listen_on = "127.0.0.1:5223"
cert = "./certs/xmpp.pem"
key = "./certs/xmpp.key"
hostname = "localhost"
[storage]
db_path = "db.1.sqlite"
[cluster]
addresses = [
"127.0.0.1:8080",
"127.0.0.1:8081",
]
[cluster.metadata]
node_id = 1
main_owner = 0
test_owner = 1
test2_owner = 0
[tracing]
endpoint = "http://localhost:4317"
service_name = "lavina-1"

View File

@ -1,15 +0,0 @@
[telemetry]
listen_on = "127.0.0.1:8080"
[irc]
listen_on = "127.0.0.1:6667"
server_name = "irc.localhost"
[xmpp]
listen_on = "127.0.0.1:5222"
cert = "./certs/xmpp.pem"
key = "./certs/xmpp.key"
hostname = "localhost"
[storage]
db_path = "db.sqlite"

View File

@ -13,3 +13,8 @@ prometheus.workspace = true
chrono.workspace = true chrono.workspace = true
argon2 = { version = "0.5.3" } argon2 = { version = "0.5.3" }
rand_core = { version = "0.6.4", features = ["getrandom"] } rand_core = { version = "0.6.4", features = ["getrandom"] }
reqwest.workspace = true
reqwest-middleware = { version = "0.3", features = ["json"] }
opentelemetry = "0.22.0"
mgmt-api = { path = "../mgmt-api" }
reqwest-tracing = { version = "0.5", features = ["opentelemetry_0_22"] }

View File

@ -0,0 +1,69 @@
use anyhow::{anyhow, Result};
use reqwest::Client;
use reqwest_middleware::{ClientBuilder, ClientWithMiddleware};
use reqwest_tracing::{DefaultSpanBackend, TracingMiddleware};
use serde::{Deserialize, Serialize};
use std::net::SocketAddr;
use std::sync::Arc;
type Addresses = Vec<SocketAddr>;
#[derive(Deserialize, Debug, Clone)]
pub struct ClusterConfig {
pub metadata: ClusterMetadata,
pub addresses: Addresses,
}
#[derive(Deserialize, Debug, Clone)]
pub struct ClusterMetadata {
pub node_id: u32,
/// Owns all rooms and players in the cluster.
pub main_owner: u32,
/// Owns the room `test`.
pub test_owner: u32,
/// Owns the room `test2`.
pub test2_owner: u32,
}
#[derive(Clone)]
pub struct LavinaClient {
addresses: Arc<Addresses>,
client: ClientWithMiddleware,
}
#[derive(Serialize, Deserialize, Debug)]
pub struct SendMessageReq<'a> {
pub room_id: &'a str,
pub player_id: &'a str,
pub message: &'a str,
pub created_at: &'a str,
}
impl LavinaClient {
pub fn new(addresses: Addresses) -> Self {
let client = ClientBuilder::new(Client::new()).with(TracingMiddleware::<DefaultSpanBackend>::new()).build();
Self {
addresses: Arc::new(addresses),
client,
}
}
#[tracing::instrument(skip(self, req), name = "LavinaClient::send_room_message")]
pub async fn send_room_message(&self, node_id: u32, req: SendMessageReq<'_>) -> Result<()> {
tracing::info!("Sending a message to a room on a remote node");
let Some(address) = self.addresses.get(node_id as usize) else {
tracing::error!("Failed");
return Err(anyhow!("Unknown node"));
};
match self.client.post(format!("http://{}/cluster/rooms/add_message", address)).json(&req).send().await {
Ok(_) => {
tracing::info!("Message sent");
Ok(())
}
Err(e) => {
tracing::error!("Failed to send message: {e:?}");
Err(e.into())
}
}
}
}

View File

@ -1,6 +1,8 @@
//! Domain definitions and implementation of common chat logic. //! Domain definitions and implementation of common chat logic.
use crate::clustering::{ClusterConfig, LavinaClient};
use anyhow::Result; use anyhow::Result;
use prometheus::Registry as MetricsRegistry; use prometheus::Registry as MetricsRegistry;
use std::sync::Arc;
use crate::dialog::DialogRegistry; use crate::dialog::DialogRegistry;
use crate::player::PlayerRegistry; use crate::player::PlayerRegistry;
@ -8,6 +10,7 @@ use crate::repo::Storage;
use crate::room::RoomRegistry; use crate::room::RoomRegistry;
pub mod auth; pub mod auth;
pub mod clustering;
pub mod dialog; pub mod dialog;
pub mod player; pub mod player;
pub mod prelude; pub mod prelude;
@ -25,11 +28,23 @@ pub struct LavinaCore {
} }
impl LavinaCore { impl LavinaCore {
pub async fn new(mut metrics: MetricsRegistry, storage: Storage) -> Result<LavinaCore> { pub async fn new(
mut metrics: MetricsRegistry,
cluster_config: ClusterConfig,
storage: Storage,
) -> Result<LavinaCore> {
// TODO shutdown all services in reverse order on error // TODO shutdown all services in reverse order on error
let client = LavinaClient::new(cluster_config.addresses.clone());
let rooms = RoomRegistry::new(&mut metrics, storage.clone())?; let rooms = RoomRegistry::new(&mut metrics, storage.clone())?;
let dialogs = DialogRegistry::new(storage.clone()); let dialogs = DialogRegistry::new(storage.clone());
let players = PlayerRegistry::empty(rooms.clone(), dialogs.clone(), storage.clone(), &mut metrics)?; let players = PlayerRegistry::empty(
rooms.clone(),
dialogs.clone(),
storage.clone(),
&mut metrics,
Arc::new(cluster_config.metadata),
client,
)?;
dialogs.set_players(players.clone()).await; dialogs.set_players(players.clone()).await;
Ok(LavinaCore { Ok(LavinaCore {
players, players,

View File

@ -17,6 +17,7 @@ use tokio::sync::mpsc::{channel, Receiver, Sender};
use tokio::sync::RwLock; use tokio::sync::RwLock;
use tracing::{Instrument, Span}; use tracing::{Instrument, Span};
use crate::clustering::{ClusterMetadata, LavinaClient, SendMessageReq};
use crate::dialog::DialogRegistry; use crate::dialog::DialogRegistry;
use crate::prelude::*; use crate::prelude::*;
use crate::repo::Storage; use crate::repo::Storage;
@ -55,7 +56,7 @@ pub struct ConnectionId(pub AnonKey);
/// The connection is used to send commands to the player actor and to receive updates that might be sent to the client. /// The connection is used to send commands to the player actor and to receive updates that might be sent to the client.
pub struct PlayerConnection { pub struct PlayerConnection {
pub connection_id: ConnectionId, pub connection_id: ConnectionId,
pub receiver: Receiver<Updates>, pub receiver: Receiver<ConnectionMessage>,
player_handle: PlayerHandle, player_handle: PlayerHandle,
} }
impl PlayerConnection { impl PlayerConnection {
@ -160,7 +161,7 @@ impl PlayerHandle {
enum ActorCommand { enum ActorCommand {
/// Establish a new connection. /// Establish a new connection.
AddConnection { AddConnection {
sender: Sender<Updates>, sender: Sender<ConnectionMessage>,
promise: Promise<ConnectionId>, promise: Promise<ConnectionId>,
}, },
/// Terminate an existing connection. /// Terminate an existing connection.
@ -253,6 +254,8 @@ impl PlayerRegistry {
dialogs: DialogRegistry, dialogs: DialogRegistry,
storage: Storage, storage: Storage,
metrics: &mut MetricsRegistry, metrics: &mut MetricsRegistry,
cluster_metadata: Arc<ClusterMetadata>,
cluster_client: LavinaClient,
) -> Result<PlayerRegistry> { ) -> Result<PlayerRegistry> {
let metric_active_players = IntGauge::new("chat_players_active", "Number of alive player actors")?; let metric_active_players = IntGauge::new("chat_players_active", "Number of alive player actors")?;
metrics.register(Box::new(metric_active_players.clone()))?; metrics.register(Box::new(metric_active_players.clone()))?;
@ -260,6 +263,8 @@ impl PlayerRegistry {
room_registry, room_registry,
dialogs, dialogs,
storage, storage,
cluster_metadata,
cluster_client,
players: HashMap::new(), players: HashMap::new(),
metric_active_players, metric_active_players,
}; };
@ -276,12 +281,28 @@ impl PlayerRegistry {
Ok(()) Ok(())
} }
#[tracing::instrument(skip(self), name = "PlayerRegistry::get_player")]
pub async fn get_player(&self, id: &PlayerId) -> Option<PlayerHandle> { pub async fn get_player(&self, id: &PlayerId) -> Option<PlayerHandle> {
let inner = self.0.read().await; let inner = self.0.read().await;
inner.players.get(id).map(|(handle, _)| handle.clone()) inner.players.get(id).map(|(handle, _)| handle.clone())
} }
pub async fn get_or_launch_player(&mut self, id: &PlayerId) -> PlayerHandle { #[tracing::instrument(skip(self), name = "PlayerRegistry::stop_player")]
pub async fn stop_player(&self, id: &PlayerId) -> Result<Option<()>> {
let mut inner = self.0.write().await;
if let Some((handle, fiber)) = inner.players.remove(id) {
handle.send(ActorCommand::Stop).await;
drop(handle);
fiber.await?;
inner.metric_active_players.dec();
Ok(Some(()))
} else {
Ok(None)
}
}
#[tracing::instrument(skip(self), name = "PlayerRegistry::get_or_launch_player")]
pub async fn get_or_launch_player(&self, id: &PlayerId) -> PlayerHandle {
let inner = self.0.read().await; let inner = self.0.read().await;
if let Some((handle, _)) = inner.players.get(id) { if let Some((handle, _)) = inner.players.get(id) {
handle.clone() handle.clone()
@ -295,6 +316,8 @@ impl PlayerRegistry {
id.clone(), id.clone(),
inner.room_registry.clone(), inner.room_registry.clone(),
inner.dialogs.clone(), inner.dialogs.clone(),
inner.cluster_metadata.clone(),
inner.cluster_client.clone(),
inner.storage.clone(), inner.storage.clone(),
) )
.await; .await;
@ -305,7 +328,8 @@ impl PlayerRegistry {
} }
} }
pub async fn connect_to_player(&mut self, id: &PlayerId) -> PlayerConnection { #[tracing::instrument(skip(self), name = "PlayerRegistry::connect_to_player")]
pub async fn connect_to_player(&self, id: &PlayerId) -> PlayerConnection {
let player_handle = self.get_or_launch_player(id).await; let player_handle = self.get_or_launch_player(id).await;
player_handle.subscribe().await player_handle.subscribe().await
} }
@ -328,29 +352,40 @@ struct PlayerRegistryInner {
room_registry: RoomRegistry, room_registry: RoomRegistry,
dialogs: DialogRegistry, dialogs: DialogRegistry,
storage: Storage, storage: Storage,
cluster_metadata: Arc<ClusterMetadata>,
cluster_client: LavinaClient,
/// Active player actors. /// Active player actors.
players: HashMap<PlayerId, (PlayerHandle, JoinHandle<Player>)>, players: HashMap<PlayerId, (PlayerHandle, JoinHandle<Player>)>,
metric_active_players: IntGauge, metric_active_players: IntGauge,
} }
enum RoomRef {
Local(RoomHandle),
Remote { node_id: u32 },
}
/// Player actor inner state representation. /// Player actor inner state representation.
struct Player { struct Player {
player_id: PlayerId, player_id: PlayerId,
storage_id: u32, storage_id: u32,
connections: AnonTable<Sender<Updates>>, connections: AnonTable<Sender<ConnectionMessage>>,
my_rooms: HashMap<RoomId, RoomHandle>, my_rooms: HashMap<RoomId, RoomRef>,
banned_from: HashSet<RoomId>, banned_from: HashSet<RoomId>,
rx: Receiver<(ActorCommand, Span)>, rx: Receiver<(ActorCommand, Span)>,
handle: PlayerHandle, handle: PlayerHandle,
rooms: RoomRegistry, rooms: RoomRegistry,
dialogs: DialogRegistry, dialogs: DialogRegistry,
storage: Storage, storage: Storage,
cluster_metadata: Arc<ClusterMetadata>,
cluster_client: LavinaClient,
} }
impl Player { impl Player {
async fn launch( async fn launch(
player_id: PlayerId, player_id: PlayerId,
rooms: RoomRegistry, rooms: RoomRegistry,
dialogs: DialogRegistry, dialogs: DialogRegistry,
cluster_metadata: Arc<ClusterMetadata>,
cluster_client: LavinaClient,
storage: Storage, storage: Storage,
) -> (PlayerHandle, JoinHandle<Player>) { ) -> (PlayerHandle, JoinHandle<Player>) {
let (tx, rx) = channel(32); let (tx, rx) = channel(32);
@ -371,6 +406,8 @@ impl Player {
rooms, rooms,
dialogs, dialogs,
storage, storage,
cluster_metadata,
cluster_client,
}; };
let fiber = tokio::task::spawn(player.main_loop()); let fiber = tokio::task::spawn(player.main_loop());
(handle_clone, fiber) (handle_clone, fiber)
@ -379,11 +416,20 @@ impl Player {
async fn main_loop(mut self) -> Self { async fn main_loop(mut self) -> Self {
let rooms = self.storage.get_rooms_of_a_user(self.storage_id).await.unwrap(); let rooms = self.storage.get_rooms_of_a_user(self.storage_id).await.unwrap();
for room_id in rooms { for room_id in rooms {
let room = self.rooms.get_room(&room_id).await; let node = match &**room_id.as_inner() {
if let Some(room) = room { "aaaaa" => self.cluster_metadata.test_owner,
self.my_rooms.insert(room_id, room); "test" => self.cluster_metadata.test2_owner,
_ => self.cluster_metadata.main_owner,
};
if node == self.cluster_metadata.node_id {
let room = self.rooms.get_room(&room_id).await;
if let Some(room) = room {
self.my_rooms.insert(room_id, RoomRef::Local(room));
} else {
tracing::error!("Room #{room_id:?} not found");
}
} else { } else {
tracing::error!("Room #{room_id:?} not found"); self.my_rooms.insert(room_id, RoomRef::Remote { node_id: node });
} }
} }
while let Some(cmd) = self.rx.recv().await { while let Some(cmd) = self.rx.recv().await {
@ -438,7 +484,7 @@ impl Player {
_ => {} _ => {}
} }
for (_, connection) in &self.connections { for (_, connection) in &self.connections {
let _ = connection.send(update.clone()).await; let _ = connection.send(ConnectionMessage::Update(update.clone())).await;
} }
} }
@ -504,7 +550,8 @@ impl Player {
}; };
room.add_member(&self.player_id, self.storage_id).await; room.add_member(&self.player_id, self.storage_id).await;
room.subscribe(&self.player_id, self.handle.clone()).await; room.subscribe(&self.player_id, self.handle.clone()).await;
self.my_rooms.insert(room_id.clone(), room.clone()); // self.my_rooms.insert(room_id.clone(), room.clone());
panic!();
let room_info = room.get_room_info().await; let room_info = room.get_room_info().await;
let update = Updates::RoomJoined { let update = Updates::RoomJoined {
room_id, room_id,
@ -518,8 +565,9 @@ impl Player {
async fn leave_room(&mut self, connection_id: ConnectionId, room_id: RoomId) { async fn leave_room(&mut self, connection_id: ConnectionId, room_id: RoomId) {
let room = self.my_rooms.remove(&room_id); let room = self.my_rooms.remove(&room_id);
if let Some(room) = room { if let Some(room) = room {
room.unsubscribe(&self.player_id).await; panic!();
room.remove_member(&self.player_id, self.storage_id).await; // room.unsubscribe(&self.player_id).await;
// room.remove_member(&self.player_id, self.storage_id).await;
} }
let update = Updates::RoomLeft { let update = Updates::RoomLeft {
room_id, room_id,
@ -535,7 +583,20 @@ impl Player {
return SendMessageResult::NoSuchRoom; return SendMessageResult::NoSuchRoom;
}; };
let created_at = chrono::Utc::now(); let created_at = chrono::Utc::now();
room.send_message(&self.player_id, body.clone(), created_at.clone()).await; match room {
RoomRef::Local(room) => {
room.send_message(&self.player_id, body.clone(), created_at.clone()).await;
}
RoomRef::Remote { node_id } => {
let req = SendMessageReq {
room_id: room_id.as_inner(),
player_id: self.player_id.as_inner(),
message: &*body,
created_at: &*created_at.to_rfc3339(),
};
self.cluster_client.send_room_message(*node_id, req).await.unwrap();
}
}
let update = Updates::NewMessage { let update = Updates::NewMessage {
room_id, room_id,
author_id: self.player_id.clone(), author_id: self.player_id.clone(),
@ -552,7 +613,8 @@ impl Player {
tracing::info!("no room found"); tracing::info!("no room found");
return; return;
}; };
room.set_topic(&self.player_id, new_topic.clone()).await; // room.set_topic(&self.player_id, new_topic.clone()).await;
todo!();
let update = Updates::RoomTopicChanged { room_id, new_topic }; let update = Updates::RoomTopicChanged { room_id, new_topic };
self.broadcast_update(update, connection_id).await; self.broadcast_update(update, connection_id).await;
} }
@ -560,8 +622,18 @@ impl Player {
#[tracing::instrument(skip(self), name = "Player::get_rooms")] #[tracing::instrument(skip(self), name = "Player::get_rooms")]
async fn get_rooms(&self) -> Vec<RoomInfo> { async fn get_rooms(&self) -> Vec<RoomInfo> {
let mut response = vec![]; let mut response = vec![];
for (_, handle) in &self.my_rooms { for (room_id, handle) in &self.my_rooms {
response.push(handle.get_room_info().await); if let RoomRef::Local(handle) = handle {
response.push(handle.get_room_info().await);
} else {
let room_info = RoomInfo {
id: room_id.clone(),
topic: "unknown".into(),
members: vec![],
};
response.push(room_info);
// TODO
}
} }
response response
} }
@ -589,7 +661,18 @@ impl Player {
if ConnectionId(a) == except { if ConnectionId(a) == except {
continue; continue;
} }
let _ = b.send(update.clone()).await; let _ = b.send(ConnectionMessage::Update(update.clone())).await;
} }
} }
} }
pub enum ConnectionMessage {
Update(Updates),
Stop(StopReason),
}
#[derive(Debug)]
pub enum StopReason {
ServerShutdown,
InternalError,
}

View File

@ -1,5 +1,7 @@
use serde::{Deserialize, Serialize}; use serde::{Deserialize, Serialize};
pub mod rooms;
#[derive(Serialize, Deserialize)] #[derive(Serialize, Deserialize)]
pub struct ErrorResponse<'a> { pub struct ErrorResponse<'a> {
pub code: &'a str, pub code: &'a str,
@ -11,6 +13,11 @@ pub struct CreatePlayerRequest<'a> {
pub name: &'a str, pub name: &'a str,
} }
#[derive(Serialize, Deserialize)]
pub struct StopPlayerRequest<'a> {
pub name: &'a str,
}
#[derive(Serialize, Deserialize)] #[derive(Serialize, Deserialize)]
pub struct ChangePasswordRequest<'a> { pub struct ChangePasswordRequest<'a> {
pub player_name: &'a str, pub player_name: &'a str,
@ -19,6 +26,7 @@ pub struct ChangePasswordRequest<'a> {
pub mod paths { pub mod paths {
pub const CREATE_PLAYER: &'static str = "/mgmt/create_player"; pub const CREATE_PLAYER: &'static str = "/mgmt/create_player";
pub const STOP_PLAYER: &'static str = "/mgmt/stop_player";
pub const SET_PASSWORD: &'static str = "/mgmt/set_password"; pub const SET_PASSWORD: &'static str = "/mgmt/set_password";
} }

View File

@ -0,0 +1,24 @@
use serde::{Deserialize, Serialize};
#[derive(Serialize, Deserialize)]
pub struct SendMessageReq<'a> {
pub room_id: &'a str,
pub author_id: &'a str,
pub message: &'a str,
}
#[derive(Serialize, Deserialize)]
pub struct SetTopicReq<'a> {
pub room_id: &'a str,
pub author_id: &'a str,
pub topic: &'a str,
}
pub mod paths {
pub const SEND_MESSAGE: &'static str = "/mgmt/rooms/send_message";
pub const SET_TOPIC: &'static str = "/mgmt/rooms/set_topic";
}
pub mod errors {
pub const ROOM_NOT_FOUND: &'static str = "room_not_found";
}

View File

@ -507,11 +507,18 @@ async fn handle_registered_socket<'a>(
buffer.clear(); buffer.clear();
}, },
update = connection.receiver.recv() => { update = connection.receiver.recv() => {
if let Some(update) = update { match update {
handle_update(&config, &user, &player_id, writer, &rooms, update).await?; Some(ConnectionMessage::Update(update)) => {
} else { handle_update(&config, &user, &player_id, writer, &rooms, update).await?;
log::warn!("Player is terminated, must terminate the connection"); }
break; Some(ConnectionMessage::Stop(_)) => {
tracing::debug!("Connection is being terminated");
break;
}
None => {
log::warn!("Player is terminated, must terminate the connection");
break;
}
} }
} }
} }

View File

@ -23,7 +23,7 @@ use tokio_rustls::rustls::{Certificate, PrivateKey};
use tokio_rustls::TlsAcceptor; use tokio_rustls::TlsAcceptor;
use lavina_core::auth::{Authenticator, Verdict}; use lavina_core::auth::{Authenticator, Verdict};
use lavina_core::player::{PlayerConnection, PlayerId, PlayerRegistry}; use lavina_core::player::{ConnectionMessage, PlayerConnection, PlayerId, PlayerRegistry, StopReason};
use lavina_core::prelude::*; use lavina_core::prelude::*;
use lavina_core::repo::Storage; use lavina_core::repo::Storage;
use lavina_core::room::RoomRegistry; use lavina_core::room::RoomRegistry;
@ -31,6 +31,7 @@ use lavina_core::terminator::Terminator;
use lavina_core::LavinaCore; use lavina_core::LavinaCore;
use proto_xmpp::bind::{Name, Resource}; use proto_xmpp::bind::{Name, Resource};
use proto_xmpp::stream::*; use proto_xmpp::stream::*;
use proto_xmpp::streamerror::{StreamError, StreamErrorKind};
use proto_xmpp::xml::{Continuation, FromXml, Parser, ToXml}; use proto_xmpp::xml::{Continuation, FromXml, Parser, ToXml};
use sasl::AuthBody; use sasl::AuthBody;
@ -296,20 +297,19 @@ async fn socket_auth(
xml_writer.get_mut().flush().await?; xml_writer.get_mut().flush().await?;
let auth: proto_xmpp::sasl::Auth = proto_xmpp::sasl::Auth::parse(xml_reader, reader_buf).await?; let auth: proto_xmpp::sasl::Auth = proto_xmpp::sasl::Auth::parse(xml_reader, reader_buf).await?;
proto_xmpp::sasl::Success.write_xml(xml_writer).await?;
xml_writer.get_mut().flush().await?;
match AuthBody::from_str(&auth.body) { match AuthBody::from_str(&auth.body) {
Ok(logopass) => { Ok(logopass) => {
let name = &logopass.login; let name = &logopass.login;
let verdict = Authenticator::new(storage).authenticate(name, &logopass.password).await?; let verdict = Authenticator::new(storage).authenticate(name, &logopass.password).await?;
// TODO return proper XML errors to the client
match verdict { match verdict {
Verdict::Authenticated => {} Verdict::Authenticated => {
Verdict::UserNotFound => { proto_xmpp::sasl::Success.write_xml(xml_writer).await?;
return Err(anyhow!("no user found")); xml_writer.get_mut().flush().await?;
} }
Verdict::InvalidPassword => { Verdict::UserNotFound | Verdict::InvalidPassword => {
proto_xmpp::sasl::Failure.write_xml(xml_writer).await?;
xml_writer.get_mut().flush().await?;
return Err(anyhow!("incorrect credentials")); return Err(anyhow!("incorrect credentials"));
} }
} }
@ -396,16 +396,41 @@ async fn socket_final(
true true
}, },
update = conn.user_handle.receiver.recv() => { update = conn.user_handle.receiver.recv() => {
if let Some(update) = update { match update {
conn.handle_update(&mut events, update).await?; Some(ConnectionMessage::Update(update)) => {
for i in &events { conn.handle_update(&mut events, update).await?;
xml_writer.write_event_async(i).await?; for i in &events {
xml_writer.write_event_async(i).await?;
}
events.clear();
xml_writer.get_mut().flush().await?;
}
Some(ConnectionMessage::Stop(reason)) => {
tracing::debug!("Connection is being terminated: {reason:?}");
let kind = match reason {
StopReason::ServerShutdown => StreamErrorKind::SystemShutdown,
StopReason::InternalError => StreamErrorKind::InternalServerError,
};
StreamError { kind }.serialize(&mut events);
ServerStreamEnd.serialize(&mut events);
for i in &events {
xml_writer.write_event_async(i).await?;
}
events.clear();
xml_writer.get_mut().flush().await?;
break;
}
None => {
log::error!("Player is terminated, must terminate the connection");
StreamError { kind: StreamErrorKind::SystemShutdown }.serialize(&mut events);
ServerStreamEnd.serialize(&mut events);
for i in &events {
xml_writer.write_event_async(i).await?;
}
events.clear();
xml_writer.get_mut().flush().await?;
break;
} }
events.clear();
xml_writer.get_mut().flush().await?;
} else {
log::warn!("Player is terminated, must terminate the connection");
break;
} }
false false
} }
@ -447,6 +472,7 @@ impl<'a> XmppConnection<'a> {
ServerStreamEnd.serialize(output); ServerStreamEnd.serialize(output);
true true
} }
ClientPacket::Eos => true,
}; };
Ok(res) Ok(res)
} }

View File

@ -14,7 +14,7 @@ impl<'a> XmppConnection<'a> {
pub async fn handle_presence(&mut self, output: &mut Vec<Event<'static>>, p: Presence<Ignore>) -> Result<()> { pub async fn handle_presence(&mut self, output: &mut Vec<Event<'static>>, p: Presence<Ignore>) -> Result<()> {
match p.to { match p.to {
None => { None => {
self.self_presence(output).await; self.self_presence(output, p.r#type.as_deref()).await;
} }
Some(Jid { Some(Jid {
name: Some(name), name: Some(name),
@ -33,21 +33,29 @@ impl<'a> XmppConnection<'a> {
Ok(()) Ok(())
} }
async fn self_presence(&mut self, output: &mut Vec<Event<'static>>) { async fn self_presence(&mut self, output: &mut Vec<Event<'static>>, r#type: Option<&str>) {
let response = Presence::<()> { match r#type {
to: Some(Jid { Some("unavailable") => {
name: Some(self.user.xmpp_name.clone()), // do not print anything
server: Server(self.hostname.clone()), }
resource: Some(self.user.xmpp_resource.clone()), None => {
}), let response = Presence::<()> {
from: Some(Jid { to: Some(Jid {
name: Some(self.user.xmpp_name.clone()), name: Some(self.user.xmpp_name.clone()),
server: Server(self.hostname.clone()), server: Server(self.hostname.clone()),
resource: Some(self.user.xmpp_resource.clone()), resource: Some(self.user.xmpp_resource.clone()),
}), }),
..Default::default() from: Some(Jid {
}; name: Some(self.user.xmpp_name.clone()),
response.serialize(output); server: Server(self.hostname.clone()),
resource: Some(self.user.xmpp_resource.clone()),
}),
..Default::default()
};
response.serialize(output);
}
_ => todo!(),
}
} }
async fn muc_presence(&mut self, name: Name, output: &mut Vec<Event<'static>>) -> Result<()> { async fn muc_presence(&mut self, name: Name, output: &mut Vec<Event<'static>>) -> Result<()> {

View File

@ -25,7 +25,7 @@ impl FromXml for IqClientBody {
type P = impl Parser<Output = Result<Self>>; type P = impl Parser<Output = Result<Self>>;
fn parse() -> Self::P { fn parse() -> Self::P {
|(namespace, event): (ResolveResult<'static>, &'static Event<'static>)| -> Result<Self> { |(mut namespace, mut event): (ResolveResult<'static>, &'static Event<'static>)| -> Result<Self> {
let bytes = match event { let bytes = match event {
Event::Start(bytes) => bytes, Event::Start(bytes) => bytes,
Event::Empty(bytes) => bytes, Event::Empty(bytes) => bytes,
@ -52,13 +52,14 @@ pub enum ClientPacket {
Message(Message<Ignore>), Message(Message<Ignore>),
Presence(Presence<Ignore>), Presence(Presence<Ignore>),
StreamEnd, StreamEnd,
Eos,
} }
impl FromXml for ClientPacket { impl FromXml for ClientPacket {
type P = impl Parser<Output = Result<Self>>; type P = impl Parser<Output = Result<Self>>;
fn parse() -> Self::P { fn parse() -> Self::P {
|(namespace, event): (ResolveResult<'static>, &'static Event<'static>)| -> Result<Self> { |(mut namespace, mut event): (ResolveResult<'static>, &'static Event<'static>)| -> Result<Self> {
match event { match event {
Event::Start(bytes) | Event::Empty(bytes) => { Event::Start(bytes) | Event::Empty(bytes) => {
let name = bytes.name(); let name = bytes.name();
@ -83,6 +84,7 @@ impl FromXml for ClientPacket {
return Err(anyhow!("Unexpected XML event: {event:?}")); return Err(anyhow!("Unexpected XML event: {event:?}"));
} }
} }
Event::Eof => Ok(ClientPacket::Eos),
_ => { _ => {
return Err(anyhow!("Unexpected XML event: {event:?}")); return Err(anyhow!("Unexpected XML event: {event:?}"));
} }

View File

@ -35,6 +35,10 @@ struct TestScope<'a> {
buffer: Vec<u8>, buffer: Vec<u8>,
} }
fn element_name<'a>(event: &quick_xml::events::BytesStart<'a>) -> &'a str {
std::str::from_utf8(event.local_name().into_inner()).unwrap()
}
impl<'a> TestScope<'a> { impl<'a> TestScope<'a> {
fn new(stream: &mut TcpStream) -> TestScope<'_> { fn new(stream: &mut TcpStream) -> TestScope<'_> {
let (reader, writer) = stream.split(); let (reader, writer) = stream.split();
@ -55,19 +59,13 @@ impl<'a> TestScope<'a> {
Ok(event) Ok(event)
} }
async fn read<T: FromXml>(&mut self) -> Result<T> { async fn expect_starttls_required(&mut self) -> Result<()> {
self.buffer.clear(); assert_matches!(self.next_xml_event().await?, Event::Start(b) => assert_eq!(element_name(&b), "features"));
let (ns, event) = self.reader.read_resolved_event_into_async(&mut self.buffer).await?; assert_matches!(self.next_xml_event().await?, Event::Start(b) => assert_eq!(b.local_name().into_inner(), b"starttls"));
let mut parser: Continuation<_, std::result::Result<T, anyhow::Error>> = T::parse().consume(ns, &event); assert_matches!(self.next_xml_event().await?, Event::Empty(b) => assert_eq!(b.local_name().into_inner(), b"required"));
loop { assert_matches!(self.next_xml_event().await?, Event::End(b) => assert_eq!(b.local_name().into_inner(), b"starttls"));
match parser { assert_matches!(self.next_xml_event().await?, Event::End(b) => assert_eq!(b.local_name().into_inner(), b"features"));
Continuation::Final(res) => return Ok(res?), Ok(())
Continuation::Continue(next) => {
let (ns, event) = self.reader.read_resolved_event_into_async(&mut self.buffer).await?;
parser = next.consume(ns, &event);
}
}
}
} }
} }
@ -82,7 +80,7 @@ impl<'a> TestScopeTls<'a> {
fn new(stream: &'a mut TlsStream<TcpStream>, buffer: Vec<u8>) -> TestScopeTls<'a> { fn new(stream: &'a mut TlsStream<TcpStream>, buffer: Vec<u8>) -> TestScopeTls<'a> {
let (reader, writer) = tokio::io::split(stream); let (reader, writer) = tokio::io::split(stream);
let reader = NsReader::from_reader(BufReader::new(reader)); let reader = NsReader::from_reader(BufReader::new(reader));
let timeout = Duration::from_millis(100); let timeout = Duration::from_millis(500);
TestScopeTls { TestScopeTls {
reader, reader,
@ -98,6 +96,24 @@ impl<'a> TestScopeTls<'a> {
Ok(()) Ok(())
} }
async fn expect_auth_mechanisms(&mut self) -> Result<()> {
assert_matches!(self.next_xml_event().await?, Event::Start(b) => assert_eq!(b.local_name().into_inner(), b"features"));
assert_matches!(self.next_xml_event().await?, Event::Start(b) => assert_eq!(b.local_name().into_inner(), b"mechanisms"));
assert_matches!(self.next_xml_event().await?, Event::Start(b) => assert_eq!(b.local_name().into_inner(), b"mechanism"));
assert_matches!(self.next_xml_event().await?, Event::Text(b) => assert_eq!(&*b, b"PLAIN"));
assert_matches!(self.next_xml_event().await?, Event::End(b) => assert_eq!(b.local_name().into_inner(), b"mechanism"));
assert_matches!(self.next_xml_event().await?, Event::End(b) => assert_eq!(b.local_name().into_inner(), b"mechanisms"));
assert_matches!(self.next_xml_event().await?, Event::End(b) => assert_eq!(b.local_name().into_inner(), b"features"));
Ok(())
}
async fn expect_bind_feature(&mut self) -> Result<()> {
assert_matches!(self.next_xml_event().await?, Event::Start(b) => assert_eq!(b.local_name().into_inner(), b"features"));
assert_matches!(self.next_xml_event().await?, Event::Empty(b) => assert_eq!(b.local_name().into_inner(), b"bind"));
assert_matches!(self.next_xml_event().await?, Event::End(b) => assert_eq!(b.local_name().into_inner(), b"features"));
Ok(())
}
async fn next_xml_event(&mut self) -> Result<Event<'_>> { async fn next_xml_event(&mut self) -> Result<Event<'_>> {
self.buffer.clear(); self.buffer.clear();
let event = self.reader.read_event_into_async(&mut self.buffer); let event = self.reader.read_event_into_async(&mut self.buffer);
@ -176,11 +192,7 @@ async fn scenario_basic() -> Result<()> {
s.send(r#"<stream:stream xmlns:stream="http://etherx.jabber.org/streams" to="127.0.0.1" xml:lang="en" xmlns:xml="http://www.w3.org/XML/1998/namespace" xmlns="jabber:client" version="1.0">"#).await?; s.send(r#"<stream:stream xmlns:stream="http://etherx.jabber.org/streams" to="127.0.0.1" xml:lang="en" xmlns:xml="http://www.w3.org/XML/1998/namespace" xmlns="jabber:client" version="1.0">"#).await?;
assert_matches!(s.next_xml_event().await?, Event::Decl(_) => {}); assert_matches!(s.next_xml_event().await?, Event::Decl(_) => {});
assert_matches!(s.next_xml_event().await?, Event::Start(b) => assert_eq!(b.local_name().into_inner(), b"stream")); assert_matches!(s.next_xml_event().await?, Event::Start(b) => assert_eq!(b.local_name().into_inner(), b"stream"));
assert_matches!(s.next_xml_event().await?, Event::Start(b) => assert_eq!(b.local_name().into_inner(), b"features")); s.expect_starttls_required().await?;
assert_matches!(s.next_xml_event().await?, Event::Start(b) => assert_eq!(b.local_name().into_inner(), b"starttls"));
assert_matches!(s.next_xml_event().await?, Event::Empty(b) => assert_eq!(b.local_name().into_inner(), b"required"));
assert_matches!(s.next_xml_event().await?, Event::End(b) => assert_eq!(b.local_name().into_inner(), b"starttls"));
assert_matches!(s.next_xml_event().await?, Event::End(b) => assert_eq!(b.local_name().into_inner(), b"features"));
s.send(r#"<starttls/>"#).await?; s.send(r#"<starttls/>"#).await?;
assert_matches!(s.next_xml_event().await?, Event::Empty(b) => assert_eq!(b.local_name().into_inner(), b"proceed")); assert_matches!(s.next_xml_event().await?, Event::Empty(b) => assert_eq!(b.local_name().into_inner(), b"proceed"));
let buffer = s.buffer; let buffer = s.buffer;
@ -202,6 +214,26 @@ async fn scenario_basic() -> Result<()> {
s.send(r#"<stream:stream xmlns:stream="http://etherx.jabber.org/streams" to="127.0.0.1" xml:lang="en" xmlns:xml="http://www.w3.org/XML/1998/namespace" xmlns="jabber:client" version="1.0">"#).await?; s.send(r#"<stream:stream xmlns:stream="http://etherx.jabber.org/streams" to="127.0.0.1" xml:lang="en" xmlns:xml="http://www.w3.org/XML/1998/namespace" xmlns="jabber:client" version="1.0">"#).await?;
assert_matches!(s.next_xml_event().await?, Event::Decl(_) => {}); assert_matches!(s.next_xml_event().await?, Event::Decl(_) => {});
assert_matches!(s.next_xml_event().await?, Event::Start(b) => assert_eq!(b.local_name().into_inner(), b"stream")); assert_matches!(s.next_xml_event().await?, Event::Start(b) => assert_eq!(b.local_name().into_inner(), b"stream"));
s.expect_auth_mechanisms().await?;
// base64-encoded b"\x00tester\x00password"
s.send(r#"<auth xmlns="urn:ietf:params:xml:ns:xmpp-sasl" mechanism="PLAIN">AHRlc3RlcgBwYXNzd29yZA==</auth>"#)
.await?;
assert_matches!(s.next_xml_event().await?, Event::Empty(b) => assert_eq!(b.local_name().into_inner(), b"success"));
s.send(r#"<?xml version="1.0"?>"#).await?;
s.send(r#"<stream:stream xmlns:stream="http://etherx.jabber.org/streams" to="127.0.0.1" xml:lang="en" xmlns:xml="http://www.w3.org/XML/1998/namespace" xmlns="jabber:client" version="1.0">"#).await?;
assert_matches!(s.next_xml_event().await?, Event::Decl(_) => {});
assert_matches!(s.next_xml_event().await?, Event::Start(b) => assert_eq!(b.local_name().into_inner(), b"stream"));
s.expect_bind_feature().await?;
s.send(r#"<iq id="bind_1" type="set"><bind xmlns="urn:ietf:params:xml:ns:xmpp-bind"><resource>kek</resource></bind></iq>"#).await?;
assert_matches!(s.next_xml_event().await?, Event::Start(b) => assert_eq!(b.local_name().into_inner(), b"iq"));
assert_matches!(s.next_xml_event().await?, Event::Start(b) => assert_eq!(b.local_name().into_inner(), b"bind"));
assert_matches!(s.next_xml_event().await?, Event::Start(b) => assert_eq!(b.local_name().into_inner(), b"jid"));
assert_matches!(s.next_xml_event().await?, Event::Text(b) => assert_eq!(&*b, b"tester@localhost/tester"));
assert_matches!(s.next_xml_event().await?, Event::End(b) => assert_eq!(b.local_name().into_inner(), b"jid"));
assert_matches!(s.next_xml_event().await?, Event::End(b) => assert_eq!(b.local_name().into_inner(), b"bind"));
assert_matches!(s.next_xml_event().await?, Event::End(b) => assert_eq!(b.local_name().into_inner(), b"iq"));
s.send(r#"<presence xmlns="jabber:client" type="unavailable"><status>Logged out</status></presence>"#).await?;
stream.shutdown().await?; stream.shutdown().await?;
@ -211,6 +243,61 @@ async fn scenario_basic() -> Result<()> {
Ok(()) Ok(())
} }
#[tokio::test]
async fn scenario_wrong_password() -> Result<()> {
let mut server = TestServer::start().await?;
// test scenario
server.storage.create_user("tester").await?;
Authenticator::new(&server.storage).set_password("tester", "password").await?;
let mut stream = TcpStream::connect(server.server.addr).await?;
let mut s = TestScope::new(&mut stream);
tracing::info!("TCP connection established");
s.send(r#"<?xml version="1.0"?>"#).await?;
s.send(r#"<stream:stream xmlns:stream="http://etherx.jabber.org/streams" to="127.0.0.1" xml:lang="en" xmlns:xml="http://www.w3.org/XML/1998/namespace" xmlns="jabber:client" version="1.0">"#).await?;
assert_matches!(s.next_xml_event().await?, Event::Decl(_) => {});
assert_matches!(s.next_xml_event().await?, Event::Start(b) => assert_eq!(b.local_name().into_inner(), b"stream"));
s.expect_starttls_required().await?;
s.send(r#"<starttls/>"#).await?;
assert_matches!(s.next_xml_event().await?, Event::Empty(b) => assert_eq!(b.local_name().into_inner(), b"proceed"));
let buffer = s.buffer;
tracing::info!("TLS feature negotiation complete");
let connector = TlsConnector::from(Arc::new(
ClientConfig::builder()
.with_safe_defaults()
.with_custom_certificate_verifier(Arc::new(IgnoreCertVerification))
.with_no_client_auth(),
));
tracing::info!("Initiating TLS connection...");
let mut stream = connector.connect(ServerName::IpAddress(server.server.addr.ip()), stream).await?;
tracing::info!("TLS connection established");
let mut s = TestScopeTls::new(&mut stream, buffer);
s.send(r#"<?xml version="1.0"?>"#).await?;
s.send(r#"<stream:stream xmlns:stream="http://etherx.jabber.org/streams" to="127.0.0.1" xml:lang="en" xmlns:xml="http://www.w3.org/XML/1998/namespace" xmlns="jabber:client" version="1.0">"#).await?;
assert_matches!(s.next_xml_event().await?, Event::Decl(_) => {});
assert_matches!(s.next_xml_event().await?, Event::Start(b) => assert_eq!(b.local_name().into_inner(), b"stream"));
s.expect_auth_mechanisms().await?;
// base64-encoded b"\x00tester\x00password2"
s.send(r#"<auth xmlns="urn:ietf:params:xml:ns:xmpp-sasl" mechanism="PLAIN">AHRlc3RlcgBwYXNzd29yZDI=</auth>"#)
.await?;
assert_matches!(s.next_xml_event().await?, Event::Start(b) => assert_eq!(b.local_name().into_inner(), b"failure"));
assert_matches!(s.next_xml_event().await?, Event::Empty(b) => assert_eq!(b.local_name().into_inner(), b"not-authorized"));
assert_matches!(s.next_xml_event().await?, Event::End(b) => assert_eq!(b.local_name().into_inner(), b"failure"));
let _ = stream.shutdown().await;
// wrap up
server.shutdown().await?;
Ok(())
}
#[tokio::test] #[tokio::test]
async fn scenario_basic_without_headers() -> Result<()> { async fn scenario_basic_without_headers() -> Result<()> {
let mut server = TestServer::start().await?; let mut server = TestServer::start().await?;
@ -227,11 +314,7 @@ async fn scenario_basic_without_headers() -> Result<()> {
s.send(r#"<stream:stream xmlns:stream="http://etherx.jabber.org/streams" to="127.0.0.1" xml:lang="en" xmlns:xml="http://www.w3.org/XML/1998/namespace" xmlns="jabber:client" version="1.0">"#).await?; s.send(r#"<stream:stream xmlns:stream="http://etherx.jabber.org/streams" to="127.0.0.1" xml:lang="en" xmlns:xml="http://www.w3.org/XML/1998/namespace" xmlns="jabber:client" version="1.0">"#).await?;
assert_matches!(s.next_xml_event().await?, Event::Decl(_) => {}); assert_matches!(s.next_xml_event().await?, Event::Decl(_) => {});
assert_matches!(s.next_xml_event().await?, Event::Start(b) => assert_eq!(b.local_name().into_inner(), b"stream")); assert_matches!(s.next_xml_event().await?, Event::Start(b) => assert_eq!(b.local_name().into_inner(), b"stream"));
assert_matches!(s.next_xml_event().await?, Event::Start(b) => assert_eq!(b.local_name().into_inner(), b"features")); s.expect_starttls_required().await?;
assert_matches!(s.next_xml_event().await?, Event::Start(b) => assert_eq!(b.local_name().into_inner(), b"starttls"));
assert_matches!(s.next_xml_event().await?, Event::Empty(b) => assert_eq!(b.local_name().into_inner(), b"required"));
assert_matches!(s.next_xml_event().await?, Event::End(b) => assert_eq!(b.local_name().into_inner(), b"starttls"));
assert_matches!(s.next_xml_event().await?, Event::End(b) => assert_eq!(b.local_name().into_inner(), b"features"));
s.send(r#"<starttls/>"#).await?; s.send(r#"<starttls/>"#).await?;
assert_matches!(s.next_xml_event().await?, Event::Empty(b) => assert_eq!(b.local_name().into_inner(), b"proceed")); assert_matches!(s.next_xml_event().await?, Event::Empty(b) => assert_eq!(b.local_name().into_inner(), b"proceed"));
let buffer = s.buffer; let buffer = s.buffer;
@ -279,11 +362,7 @@ async fn terminate_socket() -> Result<()> {
s.send(r#"<stream:stream xmlns:stream="http://etherx.jabber.org/streams" to="127.0.0.1" xml:lang="en" xmlns:xml="http://www.w3.org/XML/1998/namespace" xmlns="jabber:client" version="1.0">"#).await?; s.send(r#"<stream:stream xmlns:stream="http://etherx.jabber.org/streams" to="127.0.0.1" xml:lang="en" xmlns:xml="http://www.w3.org/XML/1998/namespace" xmlns="jabber:client" version="1.0">"#).await?;
assert_matches!(s.next_xml_event().await?, Event::Decl(_) => {}); assert_matches!(s.next_xml_event().await?, Event::Decl(_) => {});
assert_matches!(s.next_xml_event().await?, Event::Start(b) => assert_eq!(b.local_name().into_inner(), b"stream")); assert_matches!(s.next_xml_event().await?, Event::Start(b) => assert_eq!(b.local_name().into_inner(), b"stream"));
assert_matches!(s.next_xml_event().await?, Event::Start(b) => assert_eq!(b.local_name().into_inner(), b"features")); s.expect_starttls_required().await?;
assert_matches!(s.next_xml_event().await?, Event::Start(b) => assert_eq!(b.local_name().into_inner(), b"starttls"));
assert_matches!(s.next_xml_event().await?, Event::Empty(b) => assert_eq!(b.local_name().into_inner(), b"required"));
assert_matches!(s.next_xml_event().await?, Event::End(b) => assert_eq!(b.local_name().into_inner(), b"starttls"));
assert_matches!(s.next_xml_event().await?, Event::End(b) => assert_eq!(b.local_name().into_inner(), b"features"));
s.send(r#"<starttls/>"#).await?; s.send(r#"<starttls/>"#).await?;
assert_matches!(s.next_xml_event().await?, Event::Empty(b) => assert_eq!(b.local_name().into_inner(), b"proceed")); assert_matches!(s.next_xml_event().await?, Event::Empty(b) => assert_eq!(b.local_name().into_inner(), b"proceed"));

View File

@ -82,7 +82,7 @@ impl FromXml for BindRequest {
type P = impl Parser<Output = Result<Self>>; type P = impl Parser<Output = Result<Self>>;
fn parse() -> Self::P { fn parse() -> Self::P {
|(namespace, event): (ResolveResult<'static>, &'static Event<'static>)| -> Result<Self> { |(mut namespace, mut event): (ResolveResult<'static>, &'static Event<'static>)| -> Result<Self> {
let mut resource: Option<Str> = None; let mut resource: Option<Str> = None;
let Event::Start(bytes) = event else { let Event::Start(bytes) = event else {
return Err(anyhow!("Unexpected XML event: {event:?}")); return Err(anyhow!("Unexpected XML event: {event:?}"));
@ -97,15 +97,15 @@ impl FromXml for BindRequest {
return Err(anyhow!("Incorrect namespace")); return Err(anyhow!("Incorrect namespace"));
} }
loop { loop {
let (namespace, event) = yield; (namespace, event) = yield;
match event { match event {
Event::Start(bytes) if bytes.name().0 == b"resource" => { Event::Start(bytes) if bytes.name().0 == b"resource" => {
let (namespace, event) = yield; (namespace, event) = yield;
let Event::Text(text) = event else { let Event::Text(text) = event else {
return Err(anyhow!("Unexpected XML event: {event:?}")); return Err(anyhow!("Unexpected XML event: {event:?}"));
}; };
resource = Some(std::str::from_utf8(&*text)?.into()); resource = Some(std::str::from_utf8(&*text)?.into());
let (namespace, event) = yield; (namespace, event) = yield;
let Event::End(bytes) = event else { let Event::End(bytes) = event else {
return Err(anyhow!("Unexpected XML event: {event:?}")); return Err(anyhow!("Unexpected XML event: {event:?}"));
}; };

View File

@ -378,7 +378,7 @@ impl<T: FromXml> Parser for IqParser<T> {
} }
}, },
IqParserInner::Final(state) => { IqParserInner::Final(state) => {
if let Event::End(ref bytes) = event { if let Event::End(_) = event {
let id = fail_fast!(state.id.ok_or_else(|| ffail!("No id provided"))); let id = fail_fast!(state.id.ok_or_else(|| ffail!("No id provided")));
let r#type = fail_fast!(state.r#type.ok_or_else(|| ffail!("No type provided"))); let r#type = fail_fast!(state.r#type.ok_or_else(|| ffail!("No type provided")));
let body = fail_fast!(state.body.ok_or_else(|| ffail!("No body provided"))); let body = fail_fast!(state.body.ok_or_else(|| ffail!("No body provided")));
@ -528,7 +528,7 @@ impl<T: FromXml> FromXml for Presence<T> {
type P = impl Parser<Output = Result<Presence<T>>>; type P = impl Parser<Output = Result<Presence<T>>>;
fn parse() -> Self::P { fn parse() -> Self::P {
|(namespace, event): (ResolveResult<'static>, &'static Event<'static>)| -> Result<Self> { |(mut namespace, mut event): (ResolveResult<'static>, &'static Event<'static>)| -> Result<Self> {
let (bytes, end) = match event { let (bytes, end) = match event {
Event::Start(bytes) => (bytes, false), Event::Start(bytes) => (bytes, false),
Event::Empty(bytes) => (bytes, true), Event::Empty(bytes) => (bytes, true),
@ -557,37 +557,37 @@ impl<T: FromXml> FromXml for Presence<T> {
return Ok(p); return Ok(p);
} }
loop { loop {
let (namespace, event) = yield; (namespace, event) = yield;
match event { match event {
Event::Start(bytes) => match bytes.name().0 { Event::Start(bytes) => match bytes.name().0 {
b"show" => { b"show" => {
let (_, event) = yield; (namespace, event) = yield;
let Event::Text(bytes) = event else { let Event::Text(bytes) = event else {
return Err(ffail!("Unexpected XML event: {event:?}")); return Err(ffail!("Unexpected XML event: {event:?}"));
}; };
let i = PresenceShow::from_str(bytes)?; let i = PresenceShow::from_str(bytes)?;
p.show = Some(i); p.show = Some(i);
let (_, event) = yield; (namespace, event) = yield;
let Event::End(_) = event else { let Event::End(_) = event else {
return Err(ffail!("Unexpected XML event: {event:?}")); return Err(ffail!("Unexpected XML event: {event:?}"));
}; };
} }
b"status" => { b"status" => {
let (_, event) = yield; (namespace, event) = yield;
let Event::Text(bytes) = event else { let Event::Text(bytes) = event else {
return Err(ffail!("Unexpected XML event: {event:?}")); return Err(ffail!("Unexpected XML event: {event:?}"));
}; };
let s = std::str::from_utf8(bytes)?; let s = std::str::from_utf8(bytes)?;
p.status.push(s.to_string()); p.status.push(s.to_string());
let (_, event) = yield; (namespace, event) = yield;
let Event::End(_) = event else { let Event::End(_) = event else {
return Err(ffail!("Unexpected XML event: {event:?}")); return Err(ffail!("Unexpected XML event: {event:?}"));
}; };
} }
b"priority" => { b"priority" => {
let (_, event) = yield; (namespace, event) = yield;
let Event::Text(bytes) = event else { let Event::Text(bytes) = event else {
return Err(ffail!("Unexpected XML event: {event:?}")); return Err(ffail!("Unexpected XML event: {event:?}"));
}; };
@ -595,7 +595,7 @@ impl<T: FromXml> FromXml for Presence<T> {
let i = s.parse()?; let i = s.parse()?;
p.priority = Some(PresencePriority(i)); p.priority = Some(PresencePriority(i));
let (_, event) = yield; (namespace, event) = yield;
let Event::End(_) = event else { let Event::End(_) = event else {
return Err(ffail!("Unexpected XML event: {event:?}")); return Err(ffail!("Unexpected XML event: {event:?}"));
}; };

View File

@ -21,7 +21,7 @@ impl FromXml for InfoQuery {
type P = impl Parser<Output = Result<Self>>; type P = impl Parser<Output = Result<Self>>;
fn parse() -> Self::P { fn parse() -> Self::P {
|(namespace, event): (ResolveResult<'static>, &'static Event<'static>)| -> Result<Self> { |(mut namespace, mut event): (ResolveResult<'static>, &'static Event<'static>)| -> Result<Self> {
let mut node = None; let mut node = None;
let mut identity = vec![]; let mut identity = vec![];
let mut feature = vec![]; let mut feature = vec![];
@ -48,7 +48,7 @@ impl FromXml for InfoQuery {
}); });
} }
loop { loop {
let (namespace, event) = yield; (namespace, event) = yield;
let bytes = match event { let bytes = match event {
Event::Start(bytes) => bytes, Event::Start(bytes) => bytes,
Event::Empty(bytes) => bytes, Event::Empty(bytes) => bytes,
@ -141,7 +141,7 @@ impl FromXml for Identity {
type P = impl Parser<Output = Result<Self>>; type P = impl Parser<Output = Result<Self>>;
fn parse() -> Self::P { fn parse() -> Self::P {
|(namespace, event): (ResolveResult<'static>, &'static Event<'static>)| -> Result<Self> { |(mut namespace, mut event): (ResolveResult<'static>, &'static Event<'static>)| -> Result<Self> {
let mut category = None; let mut category = None;
let mut name = None; let mut name = None;
let mut r#type = None; let mut r#type = None;
@ -179,8 +179,8 @@ impl FromXml for Identity {
return Ok(item); return Ok(item);
} }
let (namespace, event) = yield; (namespace, event) = yield;
let Event::End(bytes) = event else { let Event::End(_) = event else {
return Err(ffail!("Unexpected XML event: {event:?}")); return Err(ffail!("Unexpected XML event: {event:?}"));
}; };
Ok(item) Ok(item)
@ -209,7 +209,7 @@ impl FromXml for Feature {
type P = impl Parser<Output = Result<Self>>; type P = impl Parser<Output = Result<Self>>;
fn parse() -> Self::P { fn parse() -> Self::P {
|(namespace, event): (ResolveResult<'static>, &'static Event<'static>)| -> Result<Self> { |(mut namespace, mut event): (ResolveResult<'static>, &'static Event<'static>)| -> Result<Self> {
let mut var = None; let mut var = None;
let (bytes, end) = match event { let (bytes, end) = match event {
Event::Start(bytes) => (bytes, false), Event::Start(bytes) => (bytes, false),
@ -234,8 +234,8 @@ impl FromXml for Feature {
return Ok(item); return Ok(item);
} }
let (namespace, event) = yield; (namespace, event) = yield;
let Event::End(bytes) = event else { let Event::End(_) = event else {
return Err(ffail!("Unexpected XML event: {event:?}")); return Err(ffail!("Unexpected XML event: {event:?}"));
}; };
Ok(item) Ok(item)
@ -258,9 +258,9 @@ impl FromXml for ItemQuery {
type P = impl Parser<Output = Result<Self>>; type P = impl Parser<Output = Result<Self>>;
fn parse() -> Self::P { fn parse() -> Self::P {
|(namespace, event): (ResolveResult<'static>, &'static Event<'static>)| -> Result<Self> { |(mut namespace, mut event): (ResolveResult<'static>, &'static Event<'static>)| -> Result<Self> {
let mut item = vec![]; let mut item = vec![];
let (bytes, end) = match event { let (_, end) = match event {
Event::Start(bytes) => (bytes, false), Event::Start(bytes) => (bytes, false),
Event::Empty(bytes) => (bytes, true), Event::Empty(bytes) => (bytes, true),
_ => return Err(ffail!("Unexpected XML event: {event:?}")), _ => return Err(ffail!("Unexpected XML event: {event:?}")),
@ -269,7 +269,7 @@ impl FromXml for ItemQuery {
return Ok(ItemQuery { item }); return Ok(ItemQuery { item });
} }
loop { loop {
let (namespace, event) = yield; (namespace, event) = yield;
let bytes = match event { let bytes = match event {
Event::Start(bytes) => bytes, Event::Start(bytes) => bytes,
Event::Empty(bytes) => bytes, Event::Empty(bytes) => bytes,
@ -296,7 +296,7 @@ impl FromXmlTag for ItemQuery {
impl ToXml for ItemQuery { impl ToXml for ItemQuery {
fn serialize(&self, events: &mut Vec<Event<'static>>) { fn serialize(&self, events: &mut Vec<Event<'static>>) {
let mut bytes = BytesStart::new(format!(r#"query xmlns="{}""#, XMLNS_ITEM)); let bytes = BytesStart::new(format!(r#"query xmlns="{}""#, XMLNS_ITEM));
let empty = self.item.is_empty(); let empty = self.item.is_empty();
if empty { if empty {
events.push(Event::Empty(bytes)); events.push(Event::Empty(bytes));
@ -342,7 +342,7 @@ impl FromXml for Item {
type P = impl Parser<Output = Result<Self>>; type P = impl Parser<Output = Result<Self>>;
fn parse() -> Self::P { fn parse() -> Self::P {
|(namespace, event): (ResolveResult<'static>, &'static Event<'static>)| -> Result<Self> { |(_, mut event): (ResolveResult<'static>, &'static Event<'static>)| -> Result<Self> {
let mut jid = None; let mut jid = None;
let mut name = None; let mut name = None;
let mut node = None; let mut node = None;
@ -378,8 +378,8 @@ impl FromXml for Item {
return Ok(item); return Ok(item);
} }
let (namespace, event) = yield; (_, event) = yield;
let Event::End(bytes) = event else { let Event::End(_) = event else {
return Err(ffail!("Unexpected XML event: {event:?}")); return Err(ffail!("Unexpected XML event: {event:?}"));
}; };
Ok(item) Ok(item)

View File

@ -10,6 +10,7 @@ pub mod sasl;
pub mod session; pub mod session;
pub mod stanzaerror; pub mod stanzaerror;
pub mod stream; pub mod stream;
pub mod streamerror;
pub mod tls; pub mod tls;
pub mod xml; pub mod xml;

View File

@ -19,7 +19,7 @@ impl FromXml for History {
type P = impl Parser<Output = Result<Self>>; type P = impl Parser<Output = Result<Self>>;
fn parse() -> Self::P { fn parse() -> Self::P {
|(namespace, event): (ResolveResult<'static>, &'static Event<'static>)| -> Result<Self> { |(mut namespace, mut event): (ResolveResult<'static>, &'static Event<'static>)| -> Result<Self> {
let mut history = History::default(); let mut history = History::default();
let (bytes, end) = match event { let (bytes, end) = match event {
Event::Start(bytes) if bytes.name().0 == Self::NAME.as_bytes() => (bytes, false), Event::Start(bytes) if bytes.name().0 == Self::NAME.as_bytes() => (bytes, false),
@ -51,7 +51,7 @@ impl FromXml for History {
return Ok(history); return Ok(history);
} }
let (namespace, event) = yield; (namespace, event) = yield;
let Event::End(bytes) = event else { let Event::End(bytes) = event else {
return Err(anyhow!("Unexpected XML event: {event:?}")); return Err(anyhow!("Unexpected XML event: {event:?}"));
}; };
@ -73,17 +73,17 @@ impl FromXml for Password {
type P = impl Parser<Output = Result<Self>>; type P = impl Parser<Output = Result<Self>>;
fn parse() -> Self::P { fn parse() -> Self::P {
|(namespace, event): (ResolveResult<'static>, &'static Event<'static>)| -> Result<Self> { |(mut namespace, mut event): (ResolveResult<'static>, &'static Event<'static>)| -> Result<Self> {
let bytes = match event { let bytes = match event {
Event::Start(bytes) if bytes.name().0 == Self::NAME.as_bytes() => bytes, Event::Start(bytes) if bytes.name().0 == Self::NAME.as_bytes() => bytes,
_ => return Err(anyhow!("Unexpected XML event: {event:?}")), _ => return Err(anyhow!("Unexpected XML event: {event:?}")),
}; };
let (namespace, event) = yield; (namespace, event) = yield;
let Event::Text(bytes) = event else { let Event::Text(bytes) = event else {
return Err(anyhow!("Unexpected XML event: {event:?}")); return Err(anyhow!("Unexpected XML event: {event:?}"));
}; };
let s = std::str::from_utf8(bytes)?.to_string(); let s = std::str::from_utf8(bytes)?.to_string();
let (namespace, event) = yield; (namespace, event) = yield;
let Event::End(bytes) = event else { let Event::End(bytes) = event else {
return Err(anyhow!("Unexpected XML event: {event:?}")); return Err(anyhow!("Unexpected XML event: {event:?}"));
}; };
@ -108,7 +108,7 @@ impl FromXml for X {
type P = impl Parser<Output = Result<Self>>; type P = impl Parser<Output = Result<Self>>;
fn parse() -> Self::P { fn parse() -> Self::P {
|(namespace, event): (ResolveResult<'static>, &'static Event<'static>)| -> Result<Self> { |(mut namespace, mut event): (ResolveResult<'static>, &'static Event<'static>)| -> Result<Self> {
let mut res = X::default(); let mut res = X::default();
let (_, end) = match event { let (_, end) = match event {
Event::Start(bytes) => (bytes, false), Event::Start(bytes) => (bytes, false),
@ -120,7 +120,7 @@ impl FromXml for X {
} }
loop { loop {
let (namespace, event) = yield; (namespace, event) = yield;
let bytes = match event { let bytes = match event {
Event::Start(bytes) => bytes, Event::Start(bytes) => bytes,
Event::Empty(bytes) => bytes, Event::Empty(bytes) => bytes,

View File

@ -1,7 +1,7 @@
use std::borrow::Borrow; use std::borrow::Borrow;
use anyhow::{anyhow, Result}; use anyhow::{anyhow, Result};
use quick_xml::events::{BytesStart, Event}; use quick_xml::events::{BytesEnd, BytesStart, Event};
use quick_xml::{NsReader, Writer}; use quick_xml::{NsReader, Writer};
use tokio::io::{AsyncBufRead, AsyncWrite}; use tokio::io::{AsyncBufRead, AsyncWrite};
@ -74,3 +74,16 @@ impl Success {
Ok(()) Ok(())
} }
} }
pub struct Failure;
impl Failure {
pub async fn write_xml(&self, writer: &mut Writer<impl AsyncWrite + Unpin>) -> Result<()> {
let event = BytesStart::new(r#"failure xmlns="urn:ietf:params:xml:ns:xmpp-sasl""#);
writer.write_event_async(Event::Start(event)).await?;
let event = BytesStart::new(r#"not-authorized"#);
writer.write_event_async(Event::Empty(event)).await?;
let event = BytesEnd::new(r#"failure"#);
writer.write_event_async(Event::End(event)).await?;
Ok(())
}
}

View File

@ -0,0 +1,41 @@
use crate::xml::ToXml;
use quick_xml::events::{BytesEnd, BytesStart, Event};
/// Stream error condition
///
/// [Spec](https://xmpp.org/rfcs/rfc6120.html#streams-error-conditions).
pub enum StreamErrorKind {
/// The server has experienced a misconfiguration or other internal error that prevents it from servicing the stream.
InternalServerError,
/// The server is being shut down and all active streams are being closed.
SystemShutdown,
}
impl StreamErrorKind {
pub fn from_str(s: &str) -> Option<Self> {
match s {
"internal-server-error" => Some(Self::InternalServerError),
"system-shutdown" => Some(Self::SystemShutdown),
_ => None,
}
}
pub fn as_str(&self) -> &'static str {
match self {
Self::InternalServerError => "internal-server-error",
Self::SystemShutdown => "system-shutdown",
}
}
}
pub struct StreamError {
pub kind: StreamErrorKind,
}
impl ToXml for StreamError {
fn serialize(&self, events: &mut Vec<Event<'static>>) {
events.push(Event::Start(BytesStart::new("stream:error")));
events.push(Event::Empty(BytesStart::new(format!(
r#"{} xmlns="urn:ietf:params:xml:ns:xmpp-streams""#,
self.kind.as_str()
))));
events.push(Event::End(BytesEnd::new("stream:error")));
}
}

View File

@ -10,9 +10,38 @@ use anyhow::Result;
mod ignore; mod ignore;
pub use ignore::Ignore; pub use ignore::Ignore;
/// Types which can be parsed from an XML input stream.
///
/// Example:
/// ```
/// #![feature(type_alias_impl_trait)]
/// #![feature(impl_trait_in_assoc_type)]
/// #![feature(coroutines)]
/// # use proto_xmpp::xml::FromXml;
/// # use quick_xml::events::Event;
/// # use quick_xml::name::ResolveResult;
/// # use proto_xmpp::xml::Parser;
/// # use anyhow::Result;
///
/// struct MyStruct;
/// impl FromXml for MyStruct {
/// type P = impl Parser<Output = Result<Self>>;
///
/// fn parse() -> Self::P {
/// |(mut namespace, mut event): (ResolveResult<'static>, &'static Event<'static>)| -> Result<Self> {
/// (namespace, event) = yield;
/// Ok(MyStruct)
/// }
/// }
/// }
/// ```
pub trait FromXml: Sized { pub trait FromXml: Sized {
/// The type of parser instances.
///
/// If the result type of the [parse] is anonymous, this type member can be defined by using `impl Trait`.
type P: Parser<Output = Result<Self>>; type P: Parser<Output = Result<Self>>;
/// Creates a new instance of a parser with an initial state.
fn parse() -> Self::P; fn parse() -> Self::P;
} }
@ -25,9 +54,18 @@ pub trait FromXmlTag: FromXml {
const NS: &'static str; const NS: &'static str;
} }
/// A stateful parser instance which consumes XML events until the parsing is complete.
///
/// Usually implemented with the experimental coroutine syntax, which yields to consume the next XML event,
/// and returns the final result when the parsing is done.
pub trait Parser: Sized { pub trait Parser: Sized {
type Output; type Output;
/// Advance the parsing by one XML event.
///
/// This method consumes `self`, but if the parsing is incomplete,
/// it will return the next state of the parser in the returned result.
/// Otherwise, it will return the final result of parsing.
fn consume<'a>(self: Self, namespace: ResolveResult, event: &Event<'a>) -> Continuation<Self, Self::Output>; fn consume<'a>(self: Self, namespace: ResolveResult, event: &Event<'a>) -> Continuation<Self, Self::Output>;
} }
@ -50,8 +88,11 @@ where
} }
} }
/// The result of a single parser iteration.
pub enum Continuation<Parser, Res> { pub enum Continuation<Parser, Res> {
/// The parsing is complete and the final result is available.
Final(Res), Final(Res),
/// The parsing is not complete and more XML events are required.
Continue(Parser), Continue(Parser),
} }
@ -89,8 +130,8 @@ macro_rules! delegate_parsing {
Continuation::Final(Ok(res)) => break Ok(res.into()), Continuation::Final(Ok(res)) => break Ok(res.into()),
Continuation::Final(Err(err)) => break Err(err), Continuation::Final(Err(err)) => break Err(err),
Continuation::Continue(p) => { Continuation::Continue(p) => {
let (namespace, event) = yield; ($namespace, $event) = yield;
parser = p.consume(namespace, event); parser = p.consume($namespace, $event);
} }
} }
} }

View File

@ -1,3 +1,4 @@
use chrono::Utc;
use std::convert::Infallible; use std::convert::Infallible;
use std::net::SocketAddr; use std::net::SocketAddr;
@ -8,14 +9,20 @@ use hyper::server::conn::http1;
use hyper::service::service_fn; use hyper::service::service_fn;
use hyper::{Method, Request, Response, StatusCode}; use hyper::{Method, Request, Response, StatusCode};
use hyper_util::rt::TokioIo; use hyper_util::rt::TokioIo;
use opentelemetry::propagation::Extractor;
use prometheus::{Encoder, Registry as MetricsRegistry, TextEncoder}; use prometheus::{Encoder, Registry as MetricsRegistry, TextEncoder};
use serde::{Deserialize, Serialize}; use serde::{Deserialize, Serialize};
use tokio::net::TcpListener; use tokio::net::TcpListener;
use tracing::Span;
use tracing_opentelemetry::OpenTelemetrySpanExt;
use lavina_core::auth::UpdatePasswordResult::PasswordUpdated;
use lavina_core::auth::{Authenticator, UpdatePasswordResult}; use lavina_core::auth::{Authenticator, UpdatePasswordResult};
use lavina_core::clustering::SendMessageReq;
use lavina_core::player::{PlayerId, PlayerRegistry, SendMessageResult};
use lavina_core::prelude::*; use lavina_core::prelude::*;
use lavina_core::repo::Storage; use lavina_core::repo::Storage;
use lavina_core::room::RoomRegistry; use lavina_core::room::{RoomId, RoomRegistry};
use lavina_core::terminator::Terminator; use lavina_core::terminator::Terminator;
use lavina_core::LavinaCore; use lavina_core::LavinaCore;
@ -75,18 +82,39 @@ async fn main_loop(
Ok(()) Ok(())
} }
#[tracing::instrument(skip_all, name = "route")]
async fn route( async fn route(
registry: MetricsRegistry, registry: MetricsRegistry,
core: LavinaCore, core: LavinaCore,
storage: Storage, storage: Storage,
request: Request<hyper::body::Incoming>, request: Request<hyper::body::Incoming>,
) -> HttpResult<Response<Full<Bytes>>> { ) -> HttpResult<Response<Full<Bytes>>> {
struct HttpReqExtractor<'a, T> {
req: &'a Request<T>,
}
impl<'a, T> Extractor for HttpReqExtractor<'a, T> {
fn get(&self, key: &str) -> Option<&str> {
self.req.headers().get(key).and_then(|v| v.to_str().ok())
}
fn keys(&self) -> Vec<&str> {
self.req.headers().keys().map(|k| k.as_str()).collect()
}
}
let ctx = opentelemetry::global::get_text_map_propagator(|pp| pp.extract(&HttpReqExtractor { req: &request }));
Span::current().set_parent(ctx);
let res = match (request.method(), request.uri().path()) { let res = match (request.method(), request.uri().path()) {
(&Method::GET, "/metrics") => endpoint_metrics(registry), (&Method::GET, "/metrics") => endpoint_metrics(registry),
(&Method::GET, "/rooms") => endpoint_rooms(core.rooms).await, (&Method::GET, "/rooms") => endpoint_rooms(core.rooms).await,
(&Method::POST, paths::CREATE_PLAYER) => endpoint_create_player(request, storage).await.or5xx(), (&Method::POST, paths::CREATE_PLAYER) => endpoint_create_player(request, storage).await.or5xx(),
(&Method::POST, paths::STOP_PLAYER) => endpoint_stop_player(request, core.players).await.or5xx(),
(&Method::POST, paths::SET_PASSWORD) => endpoint_set_password(request, storage).await.or5xx(), (&Method::POST, paths::SET_PASSWORD) => endpoint_set_password(request, storage).await.or5xx(),
_ => not_found(), (&Method::POST, rooms::paths::SEND_MESSAGE) => endpoint_send_room_message(request, core).await.or5xx(),
(&Method::POST, rooms::paths::SET_TOPIC) => endpoint_set_room_topic(request, core).await.or5xx(),
(&Method::POST, "/cluster/rooms/add_message") => endpoint_cluster_add_message(request, core).await.or5xx(),
_ => endpoint_not_found(),
}; };
Ok(res) Ok(res)
} }
@ -98,6 +126,7 @@ fn endpoint_metrics(registry: MetricsRegistry) -> Response<Full<Bytes>> {
Response::new(Full::new(Bytes::from(buffer))) Response::new(Full::new(Bytes::from(buffer)))
} }
#[tracing::instrument(skip_all)]
async fn endpoint_rooms(rooms: RoomRegistry) -> Response<Full<Bytes>> { async fn endpoint_rooms(rooms: RoomRegistry) -> Response<Full<Bytes>> {
// TODO introduce management API types independent from core-domain types // TODO introduce management API types independent from core-domain types
// TODO remove `Serialize` implementations from all core-domain types // TODO remove `Serialize` implementations from all core-domain types
@ -105,6 +134,7 @@ async fn endpoint_rooms(rooms: RoomRegistry) -> Response<Full<Bytes>> {
Response::new(room_list) Response::new(room_list)
} }
#[tracing::instrument(skip_all)]
async fn endpoint_create_player( async fn endpoint_create_player(
request: Request<hyper::body::Incoming>, request: Request<hyper::body::Incoming>,
mut storage: Storage, mut storage: Storage,
@ -120,6 +150,25 @@ async fn endpoint_create_player(
Ok(response) Ok(response)
} }
#[tracing::instrument(skip_all)]
async fn endpoint_stop_player(
request: Request<hyper::body::Incoming>,
players: PlayerRegistry,
) -> Result<Response<Full<Bytes>>> {
let str = request.collect().await?.to_bytes();
let Ok(res) = serde_json::from_slice::<StopPlayerRequest>(&str[..]) else {
return Ok(malformed_request());
};
let Ok(player_id) = PlayerId::from(res.name) else {
return Ok(player_not_found());
};
let Some(()) = players.stop_player(&player_id).await? else {
return Ok(player_not_found());
};
Ok(empty_204_request())
}
#[tracing::instrument(skip_all)]
async fn endpoint_set_password( async fn endpoint_set_password(
request: Request<hyper::body::Incoming>, request: Request<hyper::body::Incoming>,
storage: Storage, storage: Storage,
@ -132,22 +181,85 @@ async fn endpoint_set_password(
match verdict { match verdict {
UpdatePasswordResult::PasswordUpdated => {} UpdatePasswordResult::PasswordUpdated => {}
UpdatePasswordResult::UserNotFound => { UpdatePasswordResult::UserNotFound => {
let payload = ErrorResponse { return Ok(player_not_found());
code: errors::PLAYER_NOT_FOUND,
message: "No such player exists",
}
.to_body();
let mut response = Response::new(payload);
*response.status_mut() = StatusCode::UNPROCESSABLE_ENTITY;
return Ok(response);
} }
} }
let mut response = Response::new(Full::<Bytes>::default()); Ok(empty_204_request())
*response.status_mut() = StatusCode::NO_CONTENT;
Ok(response)
} }
pub fn not_found() -> Response<Full<Bytes>> { #[tracing::instrument(skip_all, name = "LavinaClient::endpoint_send_room_message")]
async fn endpoint_send_room_message(
request: Request<hyper::body::Incoming>,
mut core: LavinaCore,
) -> Result<Response<Full<Bytes>>> {
let str = request.collect().await?.to_bytes();
let Ok(req) = serde_json::from_slice::<rooms::SendMessageReq>(&str[..]) else {
return Ok(malformed_request());
};
let Ok(room_id) = RoomId::from(req.room_id) else {
return Ok(room_not_found());
};
let Ok(player_id) = PlayerId::from(req.author_id) else {
return Ok(player_not_found());
};
let mut player = core.players.connect_to_player(&player_id).await;
let res = player.send_message(room_id, req.message.into()).await?;
match res {
SendMessageResult::NoSuchRoom => Ok(room_not_found()),
SendMessageResult::Success(_) => Ok(empty_204_request()),
}
}
async fn endpoint_set_room_topic(
request: Request<hyper::body::Incoming>,
core: LavinaCore,
) -> Result<Response<Full<Bytes>>> {
let str = request.collect().await?.to_bytes();
let Ok(req) = serde_json::from_slice::<rooms::SetTopicReq>(&str[..]) else {
return Ok(malformed_request());
};
let Ok(room_id) = RoomId::from(req.room_id) else {
return Ok(room_not_found());
};
let Ok(player_id) = PlayerId::from(req.author_id) else {
return Ok(player_not_found());
};
let mut player = core.players.connect_to_player(&player_id).await;
player.change_topic(room_id, req.topic.into()).await?;
Ok(empty_204_request())
}
#[tracing::instrument(skip_all, name = "endpoint_cluster_add_message")]
async fn endpoint_cluster_add_message(
request: Request<hyper::body::Incoming>,
core: LavinaCore,
) -> Result<Response<Full<Bytes>>> {
let str = request.collect().await?.to_bytes();
let Ok(req) = serde_json::from_slice::<SendMessageReq>(&str[..]) else {
return Ok(malformed_request());
};
tracing::info!("Incoming request: {:?}", &req);
let Ok(created_at) = chrono::DateTime::parse_from_rfc3339(req.created_at) else {
dbg!(&req.created_at);
return Ok(malformed_request());
};
let Ok(room_id) = RoomId::from(req.room_id) else {
dbg!(&req.room_id);
return Ok(room_not_found());
};
let Ok(player_id) = PlayerId::from(req.player_id) else {
dbg!(&req.player_id);
return Ok(player_not_found());
};
let Some(room_handle) = core.rooms.get_room(&room_id).await else {
dbg!(&room_id);
return Ok(room_not_found());
};
room_handle.send_message(&player_id, req.message.into(), created_at.to_utc()).await;
Ok(empty_204_request())
}
fn endpoint_not_found() -> Response<Full<Bytes>> {
let payload = ErrorResponse { let payload = ErrorResponse {
code: errors::INVALID_PATH, code: errors::INVALID_PATH,
message: "The path does not exist", message: "The path does not exist",
@ -159,6 +271,28 @@ pub fn not_found() -> Response<Full<Bytes>> {
response response
} }
fn player_not_found() -> Response<Full<Bytes>> {
let payload = ErrorResponse {
code: errors::PLAYER_NOT_FOUND,
message: "No such player exists",
}
.to_body();
let mut response = Response::new(payload);
*response.status_mut() = StatusCode::UNPROCESSABLE_ENTITY;
response
}
fn room_not_found() -> Response<Full<Bytes>> {
let payload = ErrorResponse {
code: rooms::errors::ROOM_NOT_FOUND,
message: "No such room exists",
}
.to_body();
let mut response = Response::new(payload);
*response.status_mut() = StatusCode::UNPROCESSABLE_ENTITY;
response
}
fn malformed_request() -> Response<Full<Bytes>> { fn malformed_request() -> Response<Full<Bytes>> {
let payload = ErrorResponse { let payload = ErrorResponse {
code: errors::MALFORMED_REQUEST, code: errors::MALFORMED_REQUEST,
@ -171,9 +305,16 @@ fn malformed_request() -> Response<Full<Bytes>> {
return response; return response;
} }
fn empty_204_request() -> Response<Full<Bytes>> {
let mut response = Response::new(Full::<Bytes>::default());
*response.status_mut() = StatusCode::NO_CONTENT;
response
}
trait Or5xx { trait Or5xx {
fn or5xx(self) -> Response<Full<Bytes>>; fn or5xx(self) -> Response<Full<Bytes>>;
} }
impl Or5xx for Result<Response<Full<Bytes>>> { impl Or5xx for Result<Response<Full<Bytes>>> {
fn or5xx(self) -> Response<Full<Bytes>> { fn or5xx(self) -> Response<Full<Bytes>> {
self.unwrap_or_else(|e| { self.unwrap_or_else(|e| {
@ -187,6 +328,7 @@ impl Or5xx for Result<Response<Full<Bytes>>> {
trait ToBody { trait ToBody {
fn to_body(&self) -> Full<Bytes>; fn to_body(&self) -> Full<Bytes>;
} }
impl<T> ToBody for T impl<T> ToBody for T
where where
T: Serialize, T: Serialize,

View File

@ -6,8 +6,10 @@ use std::path::Path;
use clap::Parser; use clap::Parser;
use figment::providers::Format; use figment::providers::Format;
use figment::{providers::Toml, Figment}; use figment::{providers::Toml, Figment};
use opentelemetry::global::set_text_map_propagator;
use opentelemetry::KeyValue; use opentelemetry::KeyValue;
use opentelemetry_otlp::WithExportConfig; use opentelemetry_otlp::WithExportConfig;
use opentelemetry_sdk::propagation::TraceContextPropagator;
use opentelemetry_sdk::trace::{BatchConfig, RandomIdGenerator, Sampler}; use opentelemetry_sdk::trace::{BatchConfig, RandomIdGenerator, Sampler};
use opentelemetry_sdk::{runtime, Resource}; use opentelemetry_sdk::{runtime, Resource};
use opentelemetry_semantic_conventions::resource::SERVICE_NAME; use opentelemetry_semantic_conventions::resource::SERVICE_NAME;
@ -28,6 +30,7 @@ struct ServerConfig {
irc: projection_irc::ServerConfig, irc: projection_irc::ServerConfig,
xmpp: projection_xmpp::ServerConfig, xmpp: projection_xmpp::ServerConfig,
storage: lavina_core::repo::StorageConfig, storage: lavina_core::repo::StorageConfig,
cluster: lavina_core::clustering::ClusterConfig,
tracing: Option<TracingConfig>, tracing: Option<TracingConfig>,
} }
@ -63,11 +66,12 @@ async fn main() -> Result<()> {
irc: irc_config, irc: irc_config,
xmpp: xmpp_config, xmpp: xmpp_config,
storage: storage_config, storage: storage_config,
cluster: cluster_config,
tracing: _, tracing: _,
} = config; } = config;
let metrics = MetricsRegistry::new(); let metrics = MetricsRegistry::new();
let storage = Storage::open(storage_config).await?; let storage = Storage::open(storage_config).await?;
let core = LavinaCore::new(metrics.clone(), storage.clone()).await?; let core = LavinaCore::new(metrics.clone(), cluster_config, storage.clone()).await?;
let telemetry_terminator = http::launch(telemetry_config, metrics.clone(), core.clone(), storage.clone()).await?; let telemetry_terminator = http::launch(telemetry_config, metrics.clone(), core.clone(), storage.clone()).await?;
let irc = projection_irc::launch(irc_config, core.clone(), metrics.clone(), storage.clone()).await?; let irc = projection_irc::launch(irc_config, core.clone(), metrics.clone(), storage.clone()).await?;
let xmpp = projection_xmpp::launch(xmpp_config, core.clone(), metrics.clone(), storage.clone()).await?; let xmpp = projection_xmpp::launch(xmpp_config, core.clone(), metrics.clone(), storage.clone()).await?;
@ -109,7 +113,7 @@ fn set_up_logging(tracing_config: &Option<TracingConfig>) -> Result<()> {
let targets = { let targets = {
use std::{env, str::FromStr}; use std::{env, str::FromStr};
use tracing_subscriber::{filter::Targets, layer::SubscriberExt}; use tracing_subscriber::filter::Targets;
match env::var("RUST_LOG") { match env::var("RUST_LOG") {
Ok(var) => Targets::from_str(&var) Ok(var) => Targets::from_str(&var)
.map_err(|e| { .map_err(|e| {
@ -139,6 +143,7 @@ fn set_up_logging(tracing_config: &Option<TracingConfig>) -> Result<()> {
.with_exporter(trace_exporter) .with_exporter(trace_exporter)
.install_batch(runtime::Tokio)?; .install_batch(runtime::Tokio)?;
let subscriber = subscriber.with(OpenTelemetryLayer::new(tracer)); let subscriber = subscriber.with(OpenTelemetryLayer::new(tracer));
set_text_map_propagator(TraceContextPropagator::new());
targets.with_subscriber(subscriber).try_init()?; targets.with_subscriber(subscriber).try_init()?;
} else { } else {
targets.with_subscriber(subscriber).try_init()?; targets.with_subscriber(subscriber).try_init()?;