forked from lavina/lavina
1
0
Fork 0

Compare commits

..

7 Commits

25 changed files with 904 additions and 126 deletions

40
Cargo.lock generated
View File

@ -114,6 +114,18 @@ version = "1.0.82"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "f538837af36e6f6a9be0faa67f9a314f8119e4e4b5867c6ab40ed60360142519" checksum = "f538837af36e6f6a9be0faa67f9a314f8119e4e4b5867c6ab40ed60360142519"
[[package]]
name = "argon2"
version = "0.5.3"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "3c3610892ee6e0cbce8ae2700349fcf8f98adb0dbfbee85aec3c9179d29cc072"
dependencies = [
"base64ct",
"blake2",
"cpufeatures",
"password-hash",
]
[[package]] [[package]]
name = "assert_matches" name = "assert_matches"
version = "1.5.0" version = "1.5.0"
@ -192,6 +204,15 @@ dependencies = [
"serde", "serde",
] ]
[[package]]
name = "blake2"
version = "0.10.6"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "46502ad458c9a52b69d4d4d32775c788b7a1b85e8bc9d482d92250fc0e3f8efe"
dependencies = [
"digest",
]
[[package]] [[package]]
name = "block-buffer" name = "block-buffer"
version = "0.10.4" version = "0.10.4"
@ -882,8 +903,10 @@ name = "lavina-core"
version = "0.0.2-dev" version = "0.0.2-dev"
dependencies = [ dependencies = [
"anyhow", "anyhow",
"argon2",
"chrono", "chrono",
"prometheus", "prometheus",
"rand_core",
"serde", "serde",
"sqlx", "sqlx",
"tokio", "tokio",
@ -1126,6 +1149,17 @@ dependencies = [
"windows-targets 0.48.5", "windows-targets 0.48.5",
] ]
[[package]]
name = "password-hash"
version = "0.5.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "346f04948ba92c43e8469c1ee6736c7563d71012b17d40745260fe106aac2166"
dependencies = [
"base64ct",
"rand_core",
"subtle",
]
[[package]] [[package]]
name = "paste" name = "paste"
version = "1.0.14" version = "1.0.14"
@ -1263,6 +1297,7 @@ version = "0.0.2-dev"
dependencies = [ dependencies = [
"anyhow", "anyhow",
"bitflags 2.5.0", "bitflags 2.5.0",
"chrono",
"futures-util", "futures-util",
"lavina-core", "lavina-core",
"nonempty", "nonempty",
@ -1774,6 +1809,7 @@ dependencies = [
"atoi", "atoi",
"byteorder", "byteorder",
"bytes", "bytes",
"chrono",
"crc", "crc",
"crossbeam-queue", "crossbeam-queue",
"either", "either",
@ -1832,6 +1868,7 @@ dependencies = [
"sha2", "sha2",
"sqlx-core", "sqlx-core",
"sqlx-mysql", "sqlx-mysql",
"sqlx-postgres",
"sqlx-sqlite", "sqlx-sqlite",
"syn 1.0.109", "syn 1.0.109",
"tempfile", "tempfile",
@ -1849,6 +1886,7 @@ dependencies = [
"bitflags 2.5.0", "bitflags 2.5.0",
"byteorder", "byteorder",
"bytes", "bytes",
"chrono",
"crc", "crc",
"digest", "digest",
"dotenvy", "dotenvy",
@ -1890,6 +1928,7 @@ dependencies = [
"base64 0.21.7", "base64 0.21.7",
"bitflags 2.5.0", "bitflags 2.5.0",
"byteorder", "byteorder",
"chrono",
"crc", "crc",
"dotenvy", "dotenvy",
"etcetera", "etcetera",
@ -1925,6 +1964,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "b244ef0a8414da0bed4bb1910426e890b19e5e9bccc27ada6b797d05c55ae0aa" checksum = "b244ef0a8414da0bed4bb1910426e890b19e5e9bccc27ada6b797d05c55ae0aa"
dependencies = [ dependencies = [
"atoi", "atoi",
"chrono",
"flume", "flume",
"futures-channel", "futures-channel",
"futures-core", "futures-core",

View File

@ -31,6 +31,7 @@ base64 = "0.22.0"
lavina-core = { path = "crates/lavina-core" } lavina-core = { path = "crates/lavina-core" }
tracing-subscriber = "0.3.16" tracing-subscriber = "0.3.16"
sasl = { path = "crates/sasl" } sasl = { path = "crates/sasl" }
chrono = "0.4.37"
[package] [package]
name = "lavina" name = "lavina"

View File

@ -5,9 +5,11 @@ version.workspace = true
[dependencies] [dependencies]
anyhow.workspace = true anyhow.workspace = true
sqlx = { version = "0.7.4", features = ["sqlite", "migrate"] } sqlx = { version = "0.7.4", features = ["sqlite", "migrate", "chrono"] }
serde.workspace = true serde.workspace = true
tokio.workspace = true tokio.workspace = true
tracing.workspace = true tracing.workspace = true
prometheus.workspace = true prometheus.workspace = true
chrono = "0.4.37" chrono.workspace = true
argon2 = { version = "0.5.3" }
rand_core = { version = "0.6.4", features = ["getrandom"] }

View File

@ -0,0 +1,17 @@
create table dialogs(
id integer primary key autoincrement not null,
participant_1 integer not null,
participant_2 integer not null,
created_at timestamp not null,
message_count integer not null default 0,
unique (participant_1, participant_2)
);
create table dialog_messages(
dialog_id integer not null,
id integer not null, -- unique per dialog, sequential in one dialog
author_id integer not null,
content string not null,
created_at timestamp not null,
primary key (dialog_id, id)
);

View File

@ -0,0 +1,4 @@
create table challenges_argon2_password(
user_id integer primary key not null,
hash string not null
);

View File

@ -0,0 +1,64 @@
use anyhow::{anyhow, Result};
use argon2::password_hash::{PasswordHash, PasswordHasher, PasswordVerifier, SaltString};
use argon2::Argon2;
use rand_core::OsRng;
use crate::prelude::log;
use crate::repo::Storage;
pub enum Verdict {
Authenticated,
UserNotFound,
InvalidPassword,
}
pub enum UpdatePasswordResult {
PasswordUpdated,
UserNotFound,
}
pub struct Authenticator<'a> {
storage: &'a Storage,
}
impl<'a> Authenticator<'a> {
pub fn new(storage: &'a Storage) -> Self {
Self { storage }
}
pub async fn authenticate(&self, login: &str, provided_password: &str) -> Result<Verdict> {
let Some(stored_user) = self.storage.retrieve_user_by_name(login).await? else {
return Ok(Verdict::UserNotFound);
};
if let Some(argon2_hash) = stored_user.argon2_hash {
let argon2 = Argon2::default();
let password_hash =
PasswordHash::new(&argon2_hash).map_err(|e| anyhow!("Failed to parse password hash: {e:?}"))?;
let password_verifier = argon2.verify_password(provided_password.as_bytes(), &password_hash);
if password_verifier.is_ok() {
return Ok(Verdict::Authenticated);
}
}
if let Some(expected_password) = stored_user.password {
if expected_password == provided_password {
return Ok(Verdict::Authenticated);
}
}
Ok(Verdict::InvalidPassword)
}
pub async fn set_password(&self, login: &str, provided_password: &str) -> Result<UpdatePasswordResult> {
let Some(u) = self.storage.retrieve_user_by_name(login).await? else {
return Ok(UpdatePasswordResult::UserNotFound);
};
let salt = SaltString::generate(&mut OsRng);
let argon2 = Argon2::default();
let password_hash = argon2
.hash_password(provided_password.as_bytes(), &salt)
.map_err(|e| anyhow!("Failed to hash password: {e:?}"))?;
self.storage.set_argon2_challenge(u.id, password_hash.to_string().as_str()).await?;
log::info!("Password changed for player {login}");
Ok(UpdatePasswordResult::PasswordUpdated)
}
}

View File

@ -0,0 +1,165 @@
//! Domain of dialogs conversations between two participants.
//!
//! Dialogs are different from rooms in that they are always between two participants.
//! There are no admins or other roles in dialogs, both participants have equal rights.
use std::collections::HashMap;
use std::sync::Arc;
use chrono::{DateTime, Utc};
use tokio::sync::RwLock as AsyncRwLock;
use crate::player::{PlayerId, PlayerRegistry, Updates};
use crate::prelude::*;
use crate::repo::Storage;
/// Id of a conversation between two players.
///
/// Dialogs are identified by the pair of participants' ids. The order of ids does not matter.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub struct DialogId(PlayerId, PlayerId);
impl DialogId {
pub fn new(a: PlayerId, b: PlayerId) -> DialogId {
if a.as_inner() < b.as_inner() {
DialogId(a, b)
} else {
DialogId(b, a)
}
}
pub fn as_inner(&self) -> (&PlayerId, &PlayerId) {
(&self.0, &self.1)
}
pub fn into_inner(self) -> (PlayerId, PlayerId) {
(self.0, self.1)
}
}
struct Dialog {
storage_id: u32,
player_storage_id_1: u32,
player_storage_id_2: u32,
message_count: u32,
}
struct DialogRegistryInner {
dialogs: HashMap<DialogId, AsyncRwLock<Dialog>>,
players: Option<PlayerRegistry>,
storage: Storage,
}
#[derive(Clone)]
pub struct DialogRegistry(Arc<AsyncRwLock<DialogRegistryInner>>);
impl DialogRegistry {
pub async fn send_message(
&self,
from: PlayerId,
to: PlayerId,
body: Str,
created_at: &DateTime<Utc>,
) -> Result<()> {
let mut guard = self.0.read().await;
let id = DialogId::new(from.clone(), to.clone());
let dialog = guard.dialogs.get(&id);
if let Some(d) = dialog {
let mut d = d.write().await;
guard.storage.increment_dialog_message_count(d.storage_id).await?;
d.message_count += 1;
} else {
drop(guard);
let mut guard2 = self.0.write().await;
// double check in case concurrent access has loaded this dialog
if let Some(d) = guard2.dialogs.get(&id) {
let mut d = d.write().await;
guard2.storage.increment_dialog_message_count(d.storage_id).await?;
d.message_count += 1;
} else {
let (p1, p2) = id.as_inner();
tracing::info!("Dialog {id:?} not found locally, trying to load from storage");
let stored_dialog = match guard2.storage.retrieve_dialog(p1.as_inner(), p2.as_inner()).await? {
Some(t) => t,
None => {
tracing::info!("Dialog {id:?} does not exist, creating a new one in storage");
guard2.storage.initialize_dialog(p1.as_inner(), p2.as_inner(), created_at).await?
}
};
tracing::info!("Dialog {id:?} loaded");
guard2.storage.increment_dialog_message_count(stored_dialog.id).await?;
let dialog = Dialog {
storage_id: stored_dialog.id,
player_storage_id_1: stored_dialog.participant_1,
player_storage_id_2: stored_dialog.participant_2,
message_count: stored_dialog.message_count + 1,
};
guard2.dialogs.insert(id.clone(), AsyncRwLock::new(dialog));
}
guard = guard2.downgrade();
}
// TODO send message to the other player and persist it
let Some(players) = &guard.players else {
tracing::error!("No player registry present");
return Ok(());
};
let Some(player) = players.get_player(&to).await else {
tracing::debug!("Player {to:?} not active, not sending message");
return Ok(());
};
let update = Updates::NewDialogMessage {
sender: from.clone(),
receiver: to.clone(),
body: body.clone(),
created_at: created_at.clone(),
};
player.update(update).await;
return Ok(());
}
}
impl DialogRegistry {
pub fn new(storage: Storage) -> DialogRegistry {
DialogRegistry(Arc::new(AsyncRwLock::new(DialogRegistryInner {
dialogs: HashMap::new(),
players: None,
storage,
})))
}
pub async fn set_players(&self, players: PlayerRegistry) {
let mut guard = self.0.write().await;
guard.players = Some(players);
}
pub async fn unset_players(&self) {
let mut guard = self.0.write().await;
guard.players = None;
}
pub fn shutdown(self) -> Result<()> {
let res = match Arc::try_unwrap(self.0) {
Ok(e) => e,
Err(_) => return Err(fail("failed to acquire dialogs ownership on shutdown")),
};
let res = res.into_inner();
drop(res);
Ok(())
}
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn test_dialog_id_new() {
let a = PlayerId::from("a").unwrap();
let b = PlayerId::from("b").unwrap();
let id1 = DialogId::new(a.clone(), b.clone());
let id2 = DialogId::new(a.clone(), b.clone());
// Dialog ids are invariant with respect to the order of participants
assert_eq!(id1, id2);
assert_eq!(id1.as_inner(), (&a, &b));
assert_eq!(id2.as_inner(), (&a, &b));
}
}

View File

@ -2,10 +2,13 @@
use anyhow::Result; use anyhow::Result;
use prometheus::Registry as MetricsRegistry; use prometheus::Registry as MetricsRegistry;
use crate::dialog::DialogRegistry;
use crate::player::PlayerRegistry; use crate::player::PlayerRegistry;
use crate::repo::Storage; use crate::repo::Storage;
use crate::room::RoomRegistry; use crate::room::RoomRegistry;
pub mod auth;
pub mod dialog;
pub mod player; pub mod player;
pub mod prelude; pub mod prelude;
pub mod repo; pub mod repo;
@ -18,20 +21,29 @@ mod table;
pub struct LavinaCore { pub struct LavinaCore {
pub players: PlayerRegistry, pub players: PlayerRegistry,
pub rooms: RoomRegistry, pub rooms: RoomRegistry,
pub dialogs: DialogRegistry,
} }
impl LavinaCore { impl LavinaCore {
pub async fn new(mut metrics: MetricsRegistry, storage: Storage) -> Result<LavinaCore> { pub async fn new(mut metrics: MetricsRegistry, storage: Storage) -> Result<LavinaCore> {
// TODO shutdown all services in reverse order on error // TODO shutdown all services in reverse order on error
let rooms = RoomRegistry::new(&mut metrics, storage.clone())?; let rooms = RoomRegistry::new(&mut metrics, storage.clone())?;
let players = PlayerRegistry::empty(rooms.clone(), storage.clone(), &mut metrics)?; let dialogs = DialogRegistry::new(storage.clone());
Ok(LavinaCore { players, rooms }) let players = PlayerRegistry::empty(rooms.clone(), dialogs.clone(), storage.clone(), &mut metrics)?;
dialogs.set_players(players.clone()).await;
Ok(LavinaCore {
players,
rooms,
dialogs,
})
} }
pub async fn shutdown(mut self) -> Result<()> { pub async fn shutdown(mut self) -> Result<()> {
self.players.shutdown_all().await?; self.players.shutdown_all().await?;
drop(self.players); self.dialogs.unset_players().await;
drop(self.rooms); self.players.shutdown()?;
self.dialogs.shutdown()?;
self.rooms.shutdown()?;
Ok(()) Ok(())
} }
} }

View File

@ -10,11 +10,13 @@
use std::collections::{HashMap, HashSet}; use std::collections::{HashMap, HashSet};
use std::sync::Arc; use std::sync::Arc;
use chrono::{DateTime, Utc};
use prometheus::{IntGauge, Registry as MetricsRegistry}; use prometheus::{IntGauge, Registry as MetricsRegistry};
use serde::Serialize; use serde::Serialize;
use tokio::sync::mpsc::{channel, Receiver, Sender}; use tokio::sync::mpsc::{channel, Receiver, Sender};
use tokio::sync::RwLock; use tokio::sync::RwLock;
use crate::dialog::DialogRegistry;
use crate::prelude::*; use crate::prelude::*;
use crate::repo::Storage; use crate::repo::Storage;
use crate::room::{RoomHandle, RoomId, RoomInfo, RoomRegistry}; use crate::room::{RoomHandle, RoomId, RoomInfo, RoomRegistry};
@ -52,12 +54,12 @@ pub struct ConnectionId(pub AnonKey);
/// The connection is used to send commands to the player actor and to receive updates that might be sent to the client. /// The connection is used to send commands to the player actor and to receive updates that might be sent to the client.
pub struct PlayerConnection { pub struct PlayerConnection {
pub connection_id: ConnectionId, pub connection_id: ConnectionId,
pub receiver: Receiver<Updates>, pub receiver: Receiver<ConnectionMessage>,
player_handle: PlayerHandle, player_handle: PlayerHandle,
} }
impl PlayerConnection { impl PlayerConnection {
/// Handled in [Player::send_message]. /// Handled in [Player::send_message].
pub async fn send_message(&mut self, room_id: RoomId, body: Str) -> Result<()> { pub async fn send_message(&mut self, room_id: RoomId, body: Str) -> Result<SendMessageResult> {
let (promise, deferred) = oneshot(); let (promise, deferred) = oneshot();
let cmd = ClientCommand::SendMessage { room_id, body, promise }; let cmd = ClientCommand::SendMessage { room_id, body, promise };
self.player_handle.send(ActorCommand::ClientCommand(cmd, self.connection_id.clone())).await; self.player_handle.send(ActorCommand::ClientCommand(cmd, self.connection_id.clone())).await;
@ -103,6 +105,18 @@ impl PlayerConnection {
self.player_handle.send(ActorCommand::ClientCommand(cmd, self.connection_id.clone())).await; self.player_handle.send(ActorCommand::ClientCommand(cmd, self.connection_id.clone())).await;
Ok(deferred.await?) Ok(deferred.await?)
} }
/// Handler in [Player::send_dialog_message].
pub async fn send_dialog_message(&self, recipient: PlayerId, body: Str) -> Result<()> {
let (promise, deferred) = oneshot();
let cmd = ClientCommand::SendDialogMessage {
recipient,
body,
promise,
};
self.player_handle.send(ActorCommand::ClientCommand(cmd, self.connection_id.clone())).await;
Ok(deferred.await?)
}
} }
/// Handle to a player actor. /// Handle to a player actor.
@ -138,7 +152,7 @@ impl PlayerHandle {
enum ActorCommand { enum ActorCommand {
/// Establish a new connection. /// Establish a new connection.
AddConnection { AddConnection {
sender: Sender<Updates>, sender: Sender<ConnectionMessage>,
promise: Promise<ConnectionId>, promise: Promise<ConnectionId>,
}, },
/// Terminate an existing connection. /// Terminate an existing connection.
@ -163,7 +177,7 @@ pub enum ClientCommand {
SendMessage { SendMessage {
room_id: RoomId, room_id: RoomId,
body: Str, body: Str,
promise: Promise<()>, promise: Promise<SendMessageResult>,
}, },
ChangeTopic { ChangeTopic {
room_id: RoomId, room_id: RoomId,
@ -173,6 +187,11 @@ pub enum ClientCommand {
GetRooms { GetRooms {
promise: Promise<Vec<RoomInfo>>, promise: Promise<Vec<RoomInfo>>,
}, },
SendDialogMessage {
recipient: PlayerId,
body: Str,
promise: Promise<()>,
},
} }
pub enum JoinResult { pub enum JoinResult {
@ -181,6 +200,11 @@ pub enum JoinResult {
Banned, Banned,
} }
pub enum SendMessageResult {
Success(DateTime<Utc>),
NoSuchRoom,
}
/// Player update event type which is sent to a player actor and from there to a connection handler. /// Player update event type which is sent to a player actor and from there to a connection handler.
#[derive(Clone, Debug)] #[derive(Clone, Debug)]
pub enum Updates { pub enum Updates {
@ -192,6 +216,7 @@ pub enum Updates {
room_id: RoomId, room_id: RoomId,
author_id: PlayerId, author_id: PlayerId,
body: Str, body: Str,
created_at: DateTime<Utc>,
}, },
RoomJoined { RoomJoined {
room_id: RoomId, room_id: RoomId,
@ -203,6 +228,12 @@ pub enum Updates {
}, },
/// The player was banned from the room and left it immediately. /// The player was banned from the room and left it immediately.
BannedFrom(RoomId), BannedFrom(RoomId),
NewDialogMessage {
sender: PlayerId,
receiver: PlayerId,
body: Str,
created_at: DateTime<Utc>,
},
} }
/// Handle to a player registry — a shared data structure containing information about players. /// Handle to a player registry — a shared data structure containing information about players.
@ -211,6 +242,7 @@ pub struct PlayerRegistry(Arc<RwLock<PlayerRegistryInner>>);
impl PlayerRegistry { impl PlayerRegistry {
pub fn empty( pub fn empty(
room_registry: RoomRegistry, room_registry: RoomRegistry,
dialogs: DialogRegistry,
storage: Storage, storage: Storage,
metrics: &mut MetricsRegistry, metrics: &mut MetricsRegistry,
) -> Result<PlayerRegistry> { ) -> Result<PlayerRegistry> {
@ -218,6 +250,7 @@ impl PlayerRegistry {
metrics.register(Box::new(metric_active_players.clone()))?; metrics.register(Box::new(metric_active_players.clone()))?;
let inner = PlayerRegistryInner { let inner = PlayerRegistryInner {
room_registry, room_registry,
dialogs,
storage, storage,
players: HashMap::new(), players: HashMap::new(),
metric_active_players, metric_active_players,
@ -225,15 +258,42 @@ impl PlayerRegistry {
Ok(PlayerRegistry(Arc::new(RwLock::new(inner)))) Ok(PlayerRegistry(Arc::new(RwLock::new(inner))))
} }
pub fn shutdown(self) -> Result<()> {
let res = match Arc::try_unwrap(self.0) {
Ok(e) => e,
Err(_) => return Err(fail("failed to acquire players ownership on shutdown")),
};
let res = res.into_inner();
drop(res);
Ok(())
}
pub async fn get_player(&self, id: &PlayerId) -> Option<PlayerHandle> {
let inner = self.0.read().await;
inner.players.get(id).map(|(handle, _)| handle.clone())
}
pub async fn get_or_launch_player(&mut self, id: &PlayerId) -> PlayerHandle { pub async fn get_or_launch_player(&mut self, id: &PlayerId) -> PlayerHandle {
let mut inner = self.0.write().await; let inner = self.0.read().await;
if let Some((handle, _)) = inner.players.get(id) { if let Some((handle, _)) = inner.players.get(id) {
handle.clone() handle.clone()
} else { } else {
let (handle, fiber) = Player::launch(id.clone(), inner.room_registry.clone(), inner.storage.clone()).await; drop(inner);
inner.players.insert(id.clone(), (handle.clone(), fiber)); let mut inner = self.0.write().await;
inner.metric_active_players.inc(); if let Some((handle, _)) = inner.players.get(id) {
handle handle.clone()
} else {
let (handle, fiber) = Player::launch(
id.clone(),
inner.room_registry.clone(),
inner.dialogs.clone(),
inner.storage.clone(),
)
.await;
inner.players.insert(id.clone(), (handle.clone(), fiber));
inner.metric_active_players.inc();
handle
}
} }
} }
@ -258,6 +318,7 @@ impl PlayerRegistry {
/// The player registry state representation. /// The player registry state representation.
struct PlayerRegistryInner { struct PlayerRegistryInner {
room_registry: RoomRegistry, room_registry: RoomRegistry,
dialogs: DialogRegistry,
storage: Storage, storage: Storage,
/// Active player actors. /// Active player actors.
players: HashMap<PlayerId, (PlayerHandle, JoinHandle<Player>)>, players: HashMap<PlayerId, (PlayerHandle, JoinHandle<Player>)>,
@ -268,16 +329,22 @@ struct PlayerRegistryInner {
struct Player { struct Player {
player_id: PlayerId, player_id: PlayerId,
storage_id: u32, storage_id: u32,
connections: AnonTable<Sender<Updates>>, connections: AnonTable<Sender<ConnectionMessage>>,
my_rooms: HashMap<RoomId, RoomHandle>, my_rooms: HashMap<RoomId, RoomHandle>,
banned_from: HashSet<RoomId>, banned_from: HashSet<RoomId>,
rx: Receiver<ActorCommand>, rx: Receiver<ActorCommand>,
handle: PlayerHandle, handle: PlayerHandle,
rooms: RoomRegistry, rooms: RoomRegistry,
dialogs: DialogRegistry,
storage: Storage, storage: Storage,
} }
impl Player { impl Player {
async fn launch(player_id: PlayerId, rooms: RoomRegistry, storage: Storage) -> (PlayerHandle, JoinHandle<Player>) { async fn launch(
player_id: PlayerId,
rooms: RoomRegistry,
dialogs: DialogRegistry,
storage: Storage,
) -> (PlayerHandle, JoinHandle<Player>) {
let (tx, rx) = channel(32); let (tx, rx) = channel(32);
let handle = PlayerHandle { tx }; let handle = PlayerHandle { tx };
let handle_clone = handle.clone(); let handle_clone = handle.clone();
@ -294,6 +361,7 @@ impl Player {
rx, rx,
handle, handle,
rooms, rooms,
dialogs,
storage, storage,
}; };
let fiber = tokio::task::spawn(player.main_loop()); let fiber = tokio::task::spawn(player.main_loop());
@ -333,7 +401,7 @@ impl Player {
/// Handle an incoming update by changing the internal state and broadcasting it to all connections if necessary. /// Handle an incoming update by changing the internal state and broadcasting it to all connections if necessary.
async fn handle_update(&mut self, update: Updates) { async fn handle_update(&mut self, update: Updates) {
log::info!( log::debug!(
"Player received an update, broadcasting to {} connections", "Player received an update, broadcasting to {} connections",
self.connections.len() self.connections.len()
); );
@ -345,7 +413,7 @@ impl Player {
_ => {} _ => {}
} }
for (_, connection) in &self.connections { for (_, connection) in &self.connections {
let _ = connection.send(update.clone()).await; let _ = connection.send(ConnectionMessage::Update(update.clone())).await;
} }
} }
@ -367,8 +435,8 @@ impl Player {
let _ = promise.send(()); let _ = promise.send(());
} }
ClientCommand::SendMessage { room_id, body, promise } => { ClientCommand::SendMessage { room_id, body, promise } => {
self.send_message(connection_id, room_id, body).await; let result = self.send_message(connection_id, room_id, body).await;
let _ = promise.send(()); let _ = promise.send(result);
} }
ClientCommand::ChangeTopic { ClientCommand::ChangeTopic {
room_id, room_id,
@ -382,6 +450,14 @@ impl Player {
let result = self.get_rooms().await; let result = self.get_rooms().await;
let _ = promise.send(result); let _ = promise.send(result);
} }
ClientCommand::SendDialogMessage {
recipient,
body,
promise,
} => {
self.send_dialog_message(connection_id, recipient, body).await;
let _ = promise.send(());
}
} }
} }
@ -425,18 +501,21 @@ impl Player {
self.broadcast_update(update, connection_id).await; self.broadcast_update(update, connection_id).await;
} }
async fn send_message(&mut self, connection_id: ConnectionId, room_id: RoomId, body: Str) { async fn send_message(&mut self, connection_id: ConnectionId, room_id: RoomId, body: Str) -> SendMessageResult {
let Some(room) = self.my_rooms.get(&room_id) else { let Some(room) = self.my_rooms.get(&room_id) else {
tracing::info!("no room found"); tracing::info!("no room found");
return; return SendMessageResult::NoSuchRoom;
}; };
room.send_message(&self.player_id, body.clone()).await; let created_at = chrono::Utc::now();
room.send_message(&self.player_id, body.clone(), created_at.clone()).await;
let update = Updates::NewMessage { let update = Updates::NewMessage {
room_id, room_id,
author_id: self.player_id.clone(), author_id: self.player_id.clone(),
body, body,
created_at,
}; };
self.broadcast_update(update, connection_id).await; self.broadcast_update(update, connection_id).await;
SendMessageResult::Success(created_at)
} }
async fn change_topic(&mut self, connection_id: ConnectionId, room_id: RoomId, new_topic: Str) { async fn change_topic(&mut self, connection_id: ConnectionId, room_id: RoomId, new_topic: Str) {
@ -457,6 +536,18 @@ impl Player {
response response
} }
async fn send_dialog_message(&self, connection_id: ConnectionId, recipient: PlayerId, body: Str) {
let created_at = chrono::Utc::now();
self.dialogs.send_message(self.player_id.clone(), recipient.clone(), body.clone(), &created_at).await.unwrap();
let update = Updates::NewDialogMessage {
sender: self.player_id.clone(),
receiver: recipient.clone(),
body,
created_at,
};
self.broadcast_update(update, connection_id).await;
}
/// Broadcasts an update to all connections except the one with the given id. /// Broadcasts an update to all connections except the one with the given id.
/// ///
/// This is called after handling a client command. /// This is called after handling a client command.
@ -466,7 +557,17 @@ impl Player {
if ConnectionId(a) == except { if ConnectionId(a) == except {
continue; continue;
} }
let _ = b.send(update.clone()).await; let _ = b.send(ConnectionMessage::Update(update.clone())).await;
} }
} }
} }
pub enum ConnectionMessage {
Update(Updates),
Stop(StopReason),
}
pub enum StopReason {
ServerShutdown,
InternalError,
}

View File

@ -0,0 +1,19 @@
use anyhow::Result;
use crate::repo::Storage;
impl Storage {
pub async fn set_argon2_challenge(&self, user_id: u32, hash: &str) -> Result<()> {
let mut executor = self.conn.lock().await;
sqlx::query(
"insert into challenges_argon2_password(user_id, hash)
values (?, ?)
on conflict(user_id) do update set hash = excluded.hash;",
)
.bind(user_id)
.bind(hash)
.execute(&mut *executor)
.await?;
Ok(())
}
}

View File

@ -0,0 +1,68 @@
use anyhow::Result;
use chrono::{DateTime, Utc};
use sqlx::FromRow;
use crate::repo::Storage;
impl Storage {
pub async fn retrieve_dialog(&self, participant_1: &str, participant_2: &str) -> Result<Option<StoredDialog>> {
let mut executor = self.conn.lock().await;
let res = sqlx::query_as(
"select r.id, r.participant_1, r.participant_2, r.message_count
from dialogs r join users u1 on r.participant_1 = u1.id join users u2 on r.participant_2 = u2.id
where u1.name = ? and u2.name = ?;",
)
.bind(participant_1)
.bind(participant_2)
.fetch_optional(&mut *executor)
.await?;
Ok(res)
}
pub async fn increment_dialog_message_count(&self, storage_id: u32) -> Result<()> {
let mut executor = self.conn.lock().await;
sqlx::query(
"update rooms set message_count = message_count + 1
where id = ?;",
)
.bind(storage_id)
.execute(&mut *executor)
.await?;
Ok(())
}
pub async fn initialize_dialog(
&self,
participant_1: &str,
participant_2: &str,
created_at: &DateTime<Utc>,
) -> Result<StoredDialog> {
let mut executor = self.conn.lock().await;
let res: StoredDialog = sqlx::query_as(
"insert into dialogs(participant_1, participant_2, created_at)
values (
(select id from users where name = ?),
(select id from users where name = ?),
?
)
returning id, participant_1, participant_2, message_count;",
)
.bind(participant_1)
.bind(participant_2)
.bind(&created_at)
.fetch_one(&mut *executor)
.await?;
Ok(res)
}
}
#[derive(FromRow)]
pub struct StoredDialog {
pub id: u32,
pub participant_1: u32,
pub participant_2: u32,
pub message_count: u32,
}

View File

@ -4,6 +4,7 @@ use std::str::FromStr;
use std::sync::Arc; use std::sync::Arc;
use anyhow::anyhow; use anyhow::anyhow;
use chrono::{DateTime, Utc};
use serde::Deserialize; use serde::Deserialize;
use sqlx::sqlite::SqliteConnectOptions; use sqlx::sqlite::SqliteConnectOptions;
use sqlx::{ConnectOptions, Connection, FromRow, Sqlite, SqliteConnection, Transaction}; use sqlx::{ConnectOptions, Connection, FromRow, Sqlite, SqliteConnection, Transaction};
@ -11,6 +12,8 @@ use tokio::sync::Mutex;
use crate::prelude::*; use crate::prelude::*;
mod auth;
mod dialog;
mod room; mod room;
mod user; mod user;
@ -37,11 +40,12 @@ impl Storage {
Ok(Storage { conn }) Ok(Storage { conn })
} }
pub async fn retrieve_user_by_name(&mut self, name: &str) -> Result<Option<StoredUser>> { pub async fn retrieve_user_by_name(&self, name: &str) -> Result<Option<StoredUser>> {
let mut executor = self.conn.lock().await; let mut executor = self.conn.lock().await;
let res = sqlx::query_as( let res = sqlx::query_as(
"select u.id, u.name, c.password "select u.id, u.name, c.password, a.hash as argon2_hash
from users u left join challenges_plain_password c on u.id = c.user_id from users u left join challenges_plain_password c on u.id = c.user_id
left join challenges_argon2_password a on u.id = a.user_id
where u.name = ?;", where u.name = ?;",
) )
.bind(name) .bind(name)
@ -80,7 +84,14 @@ impl Storage {
Ok(id) Ok(id)
} }
pub async fn insert_message(&mut self, room_id: u32, id: u32, content: &str, author_id: &str) -> Result<()> { pub async fn insert_message(
&mut self,
room_id: u32,
id: u32,
content: &str,
author_id: &str,
created_at: &DateTime<Utc>,
) -> Result<()> {
let mut executor = self.conn.lock().await; let mut executor = self.conn.lock().await;
let res: Option<(u32,)> = sqlx::query_as("select id from users where name = ?;") let res: Option<(u32,)> = sqlx::query_as("select id from users where name = ?;")
.bind(author_id) .bind(author_id)
@ -98,7 +109,7 @@ impl Storage {
.bind(id) .bind(id)
.bind(content) .bind(content)
.bind(author_id) .bind(author_id)
.bind(chrono::Utc::now().to_string()) .bind(created_at.to_string())
.bind(room_id) .bind(room_id)
.execute(&mut *executor) .execute(&mut *executor)
.await?; .await?;
@ -128,7 +139,7 @@ impl Storage {
Ok(()) Ok(())
} }
pub async fn set_password<'a>(&'a mut self, name: &'a str, pwd: &'a str) -> Result<Option<()>> { pub async fn set_password<'a>(&'a self, name: &'a str, pwd: &'a str) -> Result<Option<()>> {
async fn inner(txn: &mut Transaction<'_, Sqlite>, name: &str, pwd: &str) -> Result<Option<()>> { async fn inner(txn: &mut Transaction<'_, Sqlite>, name: &str, pwd: &str) -> Result<Option<()>> {
let id: Option<(u32,)> = sqlx::query_as("select * from users where name = ? limit 1;") let id: Option<(u32,)> = sqlx::query_as("select * from users where name = ? limit 1;")
.bind(name) .bind(name)
@ -166,6 +177,7 @@ pub struct StoredUser {
pub id: u32, pub id: u32,
pub name: String, pub name: String,
pub password: Option<String>, pub password: Option<String>,
pub argon2_hash: Option<Box<str>>,
} }
#[derive(FromRow)] #[derive(FromRow)]

View File

@ -2,6 +2,7 @@
use std::collections::HashSet; use std::collections::HashSet;
use std::{collections::HashMap, hash::Hash, sync::Arc}; use std::{collections::HashMap, hash::Hash, sync::Arc};
use chrono::{DateTime, Utc};
use prometheus::{IntGauge, Registry as MetricRegistry}; use prometheus::{IntGauge, Registry as MetricRegistry};
use serde::Serialize; use serde::Serialize;
use tokio::sync::RwLock as AsyncRwLock; use tokio::sync::RwLock as AsyncRwLock;
@ -47,6 +48,17 @@ impl RoomRegistry {
Ok(RoomRegistry(Arc::new(AsyncRwLock::new(inner)))) Ok(RoomRegistry(Arc::new(AsyncRwLock::new(inner))))
} }
pub fn shutdown(self) -> Result<()> {
let res = match Arc::try_unwrap(self.0) {
Ok(e) => e,
Err(_) => return Err(fail("failed to acquire rooms ownership on shutdown")),
};
let res = res.into_inner();
// TODO drop all rooms
drop(res);
Ok(())
}
pub async fn get_or_create_room(&mut self, room_id: RoomId) -> Result<RoomHandle> { pub async fn get_or_create_room(&mut self, room_id: RoomId) -> Result<RoomHandle> {
let mut inner = self.0.write().await; let mut inner = self.0.write().await;
if let Some(room_handle) = inner.get_or_load_room(&room_id).await? { if let Some(room_handle) = inner.get_or_load_room(&room_id).await? {
@ -163,9 +175,9 @@ impl RoomHandle {
lock.broadcast_update(update, player_id).await; lock.broadcast_update(update, player_id).await;
} }
pub async fn send_message(&self, player_id: &PlayerId, body: Str) { pub async fn send_message(&self, player_id: &PlayerId, body: Str, created_at: DateTime<Utc>) {
let mut lock = self.0.write().await; let mut lock = self.0.write().await;
let res = lock.send_message(player_id, body).await; let res = lock.send_message(player_id, body, created_at).await;
if let Err(err) = res { if let Err(err) = res {
log::warn!("Failed to send message: {err:?}"); log::warn!("Failed to send message: {err:?}");
} }
@ -208,14 +220,23 @@ struct Room {
storage: Storage, storage: Storage,
} }
impl Room { impl Room {
async fn send_message(&mut self, author_id: &PlayerId, body: Str) -> Result<()> { async fn send_message(&mut self, author_id: &PlayerId, body: Str, created_at: DateTime<Utc>) -> Result<()> {
tracing::info!("Adding a message to room"); tracing::info!("Adding a message to room");
self.storage.insert_message(self.storage_id, self.message_count, &body, &*author_id.as_inner()).await?; self.storage
.insert_message(
self.storage_id,
self.message_count,
&body,
&*author_id.as_inner(),
&created_at,
)
.await?;
self.message_count += 1; self.message_count += 1;
let update = Updates::NewMessage { let update = Updates::NewMessage {
room_id: self.room_id.clone(), room_id: self.room_id.clone(),
author_id: author_id.clone(), author_id: author_id.clone(),
body, body,
created_at,
}; };
self.broadcast_update(update, author_id).await; self.broadcast_update(update, author_id).await;
Ok(()) Ok(())

View File

@ -12,6 +12,7 @@ tokio.workspace = true
prometheus.workspace = true prometheus.workspace = true
futures-util.workspace = true futures-util.workspace = true
nonempty.workspace = true nonempty.workspace = true
chrono.workspace = true
bitflags = "2.4.1" bitflags = "2.4.1"
proto-irc = { path = "../proto-irc" } proto-irc = { path = "../proto-irc" }
sasl = { path = "../sasl" } sasl = { path = "../sasl" }

View File

@ -1,9 +1,10 @@
use bitflags::bitflags; use bitflags::bitflags;
bitflags! { bitflags! {
#[derive(Debug)] #[derive(Debug, Clone, Copy)]
pub struct Capabilities: u32 { pub struct Capabilities: u32 {
const None = 0; const None = 0;
const Sasl = 1 << 0; const Sasl = 1 << 0;
const ServerTime = 1 << 1;
} }
} }

View File

@ -2,6 +2,7 @@ use std::collections::HashMap;
use std::net::SocketAddr; use std::net::SocketAddr;
use anyhow::{anyhow, Result}; use anyhow::{anyhow, Result};
use chrono::SecondsFormat;
use futures_util::future::join_all; use futures_util::future::join_all;
use nonempty::nonempty; use nonempty::nonempty;
use nonempty::NonEmpty; use nonempty::NonEmpty;
@ -13,6 +14,7 @@ use tokio::net::tcp::{ReadHalf, WriteHalf};
use tokio::net::{TcpListener, TcpStream}; use tokio::net::{TcpListener, TcpStream};
use tokio::sync::mpsc::channel; use tokio::sync::mpsc::channel;
use lavina_core::auth::{Authenticator, Verdict};
use lavina_core::player::*; use lavina_core::player::*;
use lavina_core::prelude::*; use lavina_core::prelude::*;
use lavina_core::repo::Storage; use lavina_core::repo::Storage;
@ -24,7 +26,7 @@ use proto_irc::client::{client_message, ClientMessage};
use proto_irc::server::CapSubBody; use proto_irc::server::CapSubBody;
use proto_irc::server::{AwayStatus, ServerMessage, ServerMessageBody}; use proto_irc::server::{AwayStatus, ServerMessage, ServerMessageBody};
use proto_irc::user::PrefixedNick; use proto_irc::user::PrefixedNick;
use proto_irc::{Chan, Recipient}; use proto_irc::{Chan, Recipient, Tag};
use sasl::AuthBody; use sasl::AuthBody;
mod cap; mod cap;
@ -49,6 +51,7 @@ struct RegisteredUser {
*/ */
username: Str, username: Str,
realname: Str, realname: Str,
enabled_capabilities: Capabilities,
} }
async fn handle_socket( async fn handle_socket(
@ -136,7 +139,7 @@ impl RegistrationState {
sender: Some(config.server_name.clone().into()), sender: Some(config.server_name.clone().into()),
body: ServerMessageBody::Cap { body: ServerMessageBody::Cap {
target: self.future_nickname.clone().unwrap_or_else(|| "*".into()), target: self.future_nickname.clone().unwrap_or_else(|| "*".into()),
subcmd: CapSubBody::Ls("sasl=PLAIN".into()), subcmd: CapSubBody::Ls("sasl=PLAIN server-time".into()),
}, },
} }
.write_async(writer) .write_async(writer)
@ -156,16 +159,30 @@ impl RegistrationState {
self.enabled_capabilities |= Capabilities::Sasl; self.enabled_capabilities |= Capabilities::Sasl;
} }
acked.push(cap); acked.push(cap);
} else if &*cap.name == "server-time" {
if cap.to_disable {
self.enabled_capabilities &= !Capabilities::ServerTime;
} else {
self.enabled_capabilities |= Capabilities::ServerTime;
}
acked.push(cap);
} else { } else {
naked.push(cap); naked.push(cap);
} }
} }
let mut ack_body = String::new(); let mut ack_body = String::new();
for cap in acked { if let Some((first, tail)) = acked.split_first() {
if cap.to_disable { if first.to_disable {
ack_body.push('-'); ack_body.push('-');
} }
ack_body += &*cap.name; ack_body += &*first.name;
for cap in tail {
ack_body.push(' ');
if cap.to_disable {
ack_body.push('-');
}
ack_body += &*cap.name;
}
} }
ServerMessage { ServerMessage {
tags: vec![], tags: vec![],
@ -195,6 +212,7 @@ impl RegistrationState {
nickname: nickname.clone(), nickname: nickname.clone(),
username, username,
realname, realname,
enabled_capabilities: self.enabled_capabilities,
}; };
self.finalize_auth(candidate_user, writer, storage, config).await self.finalize_auth(candidate_user, writer, storage, config).await
} }
@ -208,6 +226,7 @@ impl RegistrationState {
nickname: nickname.clone(), nickname: nickname.clone(),
username: username.clone(), username: username.clone(),
realname: realname.clone(), realname: realname.clone(),
enabled_capabilities: self.enabled_capabilities,
}; };
self.finalize_auth(candidate_user, writer, storage, config).await self.finalize_auth(candidate_user, writer, storage, config).await
} else { } else {
@ -224,6 +243,7 @@ impl RegistrationState {
nickname: nickname.clone(), nickname: nickname.clone(),
username, username,
realname, realname,
enabled_capabilities: self.enabled_capabilities,
}; };
self.finalize_auth(candidate_user, writer, storage, config).await self.finalize_auth(candidate_user, writer, storage, config).await
} else { } else {
@ -386,24 +406,13 @@ fn sasl_fail_message(sender: Str, nick: Str, text: Str) -> ServerMessage {
} }
async fn auth_user(storage: &mut Storage, login: &str, plain_password: &str) -> Result<()> { async fn auth_user(storage: &mut Storage, login: &str, plain_password: &str) -> Result<()> {
let stored_user = storage.retrieve_user_by_name(login).await?; let verdict = Authenticator::new(storage).authenticate(login, plain_password).await?;
// TODO properly map these onto protocol messages
let stored_user = match stored_user { match verdict {
Some(u) => u, Verdict::Authenticated => Ok(()),
None => { Verdict::UserNotFound => Err(anyhow!("no user found")),
log::info!("User '{}' not found", login); Verdict::InvalidPassword => Err(anyhow!("incorrect credentials")),
return Err(anyhow!("no user found"));
}
};
let Some(expected_password) = stored_user.password else {
log::info!("Password not defined for user '{}'", login);
return Err(anyhow!("password is not defined"));
};
if expected_password != plain_password {
log::info!("Incorrect password supplied for user '{}'", login);
return Err(anyhow!("passwords do not match"));
} }
Ok(())
} }
async fn handle_registered_socket<'a>( async fn handle_registered_socket<'a>(
@ -587,9 +596,18 @@ async fn handle_update(
author_id, author_id,
room_id, room_id,
body, body,
created_at,
} => { } => {
let mut tags = vec![];
if user.enabled_capabilities.contains(Capabilities::ServerTime) {
let tag = Tag {
key: "time".into(),
value: Some(created_at.to_rfc3339_opts(SecondsFormat::Millis, true).into()),
};
tags.push(tag);
}
ServerMessage { ServerMessage {
tags: vec![], tags,
sender: Some(author_id.as_inner().clone()), sender: Some(author_id.as_inner().clone()),
body: ServerMessageBody::PrivateMessage { body: ServerMessageBody::PrivateMessage {
target: Recipient::Chan(Chan::Global(room_id.as_inner().clone())), target: Recipient::Chan(Chan::Global(room_id.as_inner().clone())),
@ -625,6 +643,32 @@ async fn handle_update(
.await?; .await?;
writer.flush().await? writer.flush().await?
} }
Updates::NewDialogMessage {
sender,
receiver,
body,
created_at,
} => {
let mut tags = vec![];
if user.enabled_capabilities.contains(Capabilities::ServerTime) {
let tag = Tag {
key: "time".into(),
value: Some(created_at.to_rfc3339_opts(SecondsFormat::Millis, true).into()),
};
tags.push(tag);
}
ServerMessage {
tags,
sender: Some(sender.as_inner().clone()),
body: ServerMessageBody::PrivateMessage {
target: Recipient::Nick(receiver.as_inner().clone()),
body: body.clone(),
},
}
.write_async(writer)
.await?;
writer.flush().await?
}
} }
Ok(()) Ok(())
} }
@ -671,6 +715,10 @@ async fn handle_incoming_message(
let room_id = RoomId::from(chan)?; let room_id = RoomId::from(chan)?;
user_handle.send_message(room_id, body).await?; user_handle.send_message(room_id, body).await?;
} }
Recipient::Nick(nick) => {
let receiver = PlayerId::from(nick)?;
user_handle.send_dialog_message(receiver, body).await?;
}
_ => log::warn!("Unsupported target type"), _ => log::warn!("Unsupported target type"),
}, },
ClientMessage::Topic { chan, topic } => { ClientMessage::Topic { chan, topic } => {

View File

@ -1,17 +1,21 @@
use std::io::ErrorKind; use std::io::ErrorKind;
use std::net::SocketAddr;
use std::time::Duration; use std::time::Duration;
use anyhow::{anyhow, Result}; use anyhow::{anyhow, Result};
use chrono::{DateTime, SecondsFormat};
use prometheus::Registry as MetricsRegistry; use prometheus::Registry as MetricsRegistry;
use tokio::io::{AsyncReadExt, AsyncWriteExt, BufReader}; use tokio::io::{AsyncReadExt, AsyncWriteExt, BufReader};
use tokio::net::tcp::{ReadHalf, WriteHalf}; use tokio::net::tcp::{ReadHalf, WriteHalf};
use tokio::net::TcpStream; use tokio::net::TcpStream;
use lavina_core::auth::Authenticator;
use lavina_core::player::{JoinResult, PlayerId, SendMessageResult};
use lavina_core::repo::{Storage, StorageConfig}; use lavina_core::repo::{Storage, StorageConfig};
use lavina_core::room::RoomId;
use lavina_core::LavinaCore; use lavina_core::LavinaCore;
use projection_irc::APP_VERSION; use projection_irc::APP_VERSION;
use projection_irc::{launch, read_irc_message, RunningServer, ServerConfig}; use projection_irc::{launch, read_irc_message, RunningServer, ServerConfig};
struct TestScope<'a> { struct TestScope<'a> {
reader: BufReader<ReadHalf<'a>>, reader: BufReader<ReadHalf<'a>>,
writer: WriteHalf<'a>, writer: WriteHalf<'a>,
@ -24,7 +28,7 @@ impl<'a> TestScope<'a> {
let (reader, writer) = stream.split(); let (reader, writer) = stream.split();
let reader = BufReader::new(reader); let reader = BufReader::new(reader);
let buffer = vec![]; let buffer = vec![];
let timeout = Duration::from_millis(100); let timeout = Duration::from_millis(1000);
TestScope { TestScope {
reader, reader,
writer, writer,
@ -89,6 +93,11 @@ impl<'a> TestScope<'a> {
Err(_) => Ok(()), Err(_) => Ok(()),
} }
} }
async fn expect_cap_ls(&mut self) -> Result<()> {
self.expect(":testserver CAP * LS :sasl=PLAIN server-time").await?;
Ok(())
}
} }
struct TestServer { struct TestServer {
@ -142,6 +151,13 @@ impl TestServer {
server, server,
}) })
} }
async fn shutdown(self) -> Result<()> {
self.server.terminate().await?;
self.core.shutdown().await?;
self.storage.close().await?;
Ok(())
}
} }
#[tokio::test] #[tokio::test]
@ -151,7 +167,7 @@ async fn scenario_basic() -> Result<()> {
// test scenario // test scenario
server.storage.create_user("tester").await?; server.storage.create_user("tester").await?;
server.storage.set_password("tester", "password").await?; Authenticator::new(&server.storage).set_password("tester", "password").await?;
let mut stream = TcpStream::connect(server.server.addr).await?; let mut stream = TcpStream::connect(server.server.addr).await?;
let mut s = TestScope::new(&mut stream); let mut s = TestScope::new(&mut stream);
@ -169,7 +185,7 @@ async fn scenario_basic() -> Result<()> {
// wrap up // wrap up
server.server.terminate().await?; server.shutdown().await?;
Ok(()) Ok(())
} }
@ -180,7 +196,7 @@ async fn scenario_join_and_reboot() -> Result<()> {
// test scenario // test scenario
server.storage.create_user("tester").await?; server.storage.create_user("tester").await?;
server.storage.set_password("tester", "password").await?; Authenticator::new(&server.storage).set_password("tester", "password").await?;
let mut stream = TcpStream::connect(server.server.addr).await?; let mut stream = TcpStream::connect(server.server.addr).await?;
let mut s = TestScope::new(&mut stream); let mut s = TestScope::new(&mut stream);
@ -239,7 +255,7 @@ async fn scenario_join_and_reboot() -> Result<()> {
// wrap up // wrap up
server.server.terminate().await?; server.shutdown().await?;
Ok(()) Ok(())
} }
@ -250,7 +266,7 @@ async fn scenario_force_join_msg() -> Result<()> {
// test scenario // test scenario
server.storage.create_user("tester").await?; server.storage.create_user("tester").await?;
server.storage.set_password("tester", "password").await?; Authenticator::new(&server.storage).set_password("tester", "password").await?;
let mut stream1 = TcpStream::connect(server.server.addr).await?; let mut stream1 = TcpStream::connect(server.server.addr).await?;
let mut s1 = TestScope::new(&mut stream1); let mut s1 = TestScope::new(&mut stream1);
@ -305,7 +321,7 @@ async fn scenario_force_join_msg() -> Result<()> {
// wrap up // wrap up
server.server.terminate().await?; server.shutdown().await?;
Ok(()) Ok(())
} }
@ -316,9 +332,9 @@ async fn scenario_two_users() -> Result<()> {
// test scenario // test scenario
server.storage.create_user("tester1").await?; server.storage.create_user("tester1").await?;
server.storage.set_password("tester1", "password").await?; Authenticator::new(&server.storage).set_password("tester1", "password").await?;
server.storage.create_user("tester2").await?; server.storage.create_user("tester2").await?;
server.storage.set_password("tester2", "password").await?; Authenticator::new(&server.storage).set_password("tester2", "password").await?;
let mut stream1 = TcpStream::connect(server.server.addr).await?; let mut stream1 = TcpStream::connect(server.server.addr).await?;
let mut s1 = TestScope::new(&mut stream1); let mut s1 = TestScope::new(&mut stream1);
@ -366,6 +382,11 @@ async fn scenario_two_users() -> Result<()> {
s1.expect(":tester1 PART #test").await?; s1.expect(":tester1 PART #test").await?;
// The second user should receive the PART message // The second user should receive the PART message
s2.expect(":tester1 PART #test").await?; s2.expect(":tester1 PART #test").await?;
stream1.shutdown().await?;
stream2.shutdown().await?;
server.shutdown().await?;
Ok(()) Ok(())
} }
@ -380,7 +401,7 @@ async fn scenario_cap_full_negotiation() -> Result<()> {
// test scenario // test scenario
server.storage.create_user("tester").await?; server.storage.create_user("tester").await?;
server.storage.set_password("tester", "password").await?; Authenticator::new(&server.storage).set_password("tester", "password").await?;
let mut stream = TcpStream::connect(server.server.addr).await?; let mut stream = TcpStream::connect(server.server.addr).await?;
let mut s = TestScope::new(&mut stream); let mut s = TestScope::new(&mut stream);
@ -388,7 +409,7 @@ async fn scenario_cap_full_negotiation() -> Result<()> {
s.send("CAP LS 302").await?; s.send("CAP LS 302").await?;
s.send("NICK tester").await?; s.send("NICK tester").await?;
s.send("USER UserName 0 * :Real Name").await?; s.send("USER UserName 0 * :Real Name").await?;
s.expect(":testserver CAP * LS :sasl=PLAIN").await?; s.expect_cap_ls().await?;
s.send("CAP REQ :sasl").await?; s.send("CAP REQ :sasl").await?;
s.expect(":testserver CAP tester ACK :sasl").await?; s.expect(":testserver CAP tester ACK :sasl").await?;
s.send("AUTHENTICATE PLAIN").await?; s.send("AUTHENTICATE PLAIN").await?;
@ -409,7 +430,7 @@ async fn scenario_cap_full_negotiation() -> Result<()> {
// wrap up // wrap up
server.server.terminate().await?; server.shutdown().await?;
Ok(()) Ok(())
} }
@ -420,13 +441,13 @@ async fn scenario_cap_full_negotiation_nick_last() -> Result<()> {
// test scenario // test scenario
server.storage.create_user("tester").await?; server.storage.create_user("tester").await?;
server.storage.set_password("tester", "password").await?; Authenticator::new(&server.storage).set_password("tester", "password").await?;
let mut stream = TcpStream::connect(server.server.addr).await?; let mut stream = TcpStream::connect(server.server.addr).await?;
let mut s = TestScope::new(&mut stream); let mut s = TestScope::new(&mut stream);
s.send("CAP LS 302").await?; s.send("CAP LS 302").await?;
s.expect(":testserver CAP * LS :sasl=PLAIN").await?; s.expect_cap_ls().await?;
s.send("CAP REQ :sasl").await?; s.send("CAP REQ :sasl").await?;
s.expect(":testserver CAP * ACK :sasl").await?; s.expect(":testserver CAP * ACK :sasl").await?;
s.send("AUTHENTICATE PLAIN").await?; s.send("AUTHENTICATE PLAIN").await?;
@ -448,7 +469,7 @@ async fn scenario_cap_full_negotiation_nick_last() -> Result<()> {
// wrap up // wrap up
server.server.terminate().await?; server.shutdown().await?;
Ok(()) Ok(())
} }
@ -459,7 +480,7 @@ async fn scenario_cap_short_negotiation() -> Result<()> {
// test scenario // test scenario
server.storage.create_user("tester").await?; server.storage.create_user("tester").await?;
server.storage.set_password("tester", "password").await?; Authenticator::new(&server.storage).set_password("tester", "password").await?;
let mut stream = TcpStream::connect(server.server.addr).await?; let mut stream = TcpStream::connect(server.server.addr).await?;
let mut s = TestScope::new(&mut stream); let mut s = TestScope::new(&mut stream);
@ -486,7 +507,7 @@ async fn scenario_cap_short_negotiation() -> Result<()> {
// wrap up // wrap up
server.server.terminate().await?; server.shutdown().await?;
Ok(()) Ok(())
} }
@ -497,7 +518,7 @@ async fn scenario_cap_sasl_fail() -> Result<()> {
// test scenario // test scenario
server.storage.create_user("tester").await?; server.storage.create_user("tester").await?;
server.storage.set_password("tester", "password").await?; Authenticator::new(&server.storage).set_password("tester", "password").await?;
let mut stream = TcpStream::connect(server.server.addr).await?; let mut stream = TcpStream::connect(server.server.addr).await?;
let mut s = TestScope::new(&mut stream); let mut s = TestScope::new(&mut stream);
@ -505,7 +526,7 @@ async fn scenario_cap_sasl_fail() -> Result<()> {
s.send("CAP LS 302").await?; s.send("CAP LS 302").await?;
s.send("NICK tester").await?; s.send("NICK tester").await?;
s.send("USER UserName 0 * :Real Name").await?; s.send("USER UserName 0 * :Real Name").await?;
s.expect(":testserver CAP * LS :sasl=PLAIN").await?; s.expect_cap_ls().await?;
s.send("CAP REQ :sasl").await?; s.send("CAP REQ :sasl").await?;
s.expect(":testserver CAP tester ACK :sasl").await?; s.expect(":testserver CAP tester ACK :sasl").await?;
s.send("AUTHENTICATE SHA256").await?; s.send("AUTHENTICATE SHA256").await?;
@ -530,7 +551,7 @@ async fn scenario_cap_sasl_fail() -> Result<()> {
// wrap up // wrap up
server.server.terminate().await?; server.shutdown().await?;
Ok(()) Ok(())
} }
@ -541,7 +562,7 @@ async fn terminate_socket_scenario() -> Result<()> {
// test scenario // test scenario
server.storage.create_user("tester").await?; server.storage.create_user("tester").await?;
server.storage.set_password("tester", "password").await?; Authenticator::new(&server.storage).set_password("tester", "password").await?;
let mut stream = TcpStream::connect(server.server.addr).await?; let mut stream = TcpStream::connect(server.server.addr).await?;
let mut s = TestScope::new(&mut stream); let mut s = TestScope::new(&mut stream);
@ -553,8 +574,142 @@ async fn terminate_socket_scenario() -> Result<()> {
s.send("AUTHENTICATE PLAIN").await?; s.send("AUTHENTICATE PLAIN").await?;
s.expect(":testserver AUTHENTICATE +").await?; s.expect(":testserver AUTHENTICATE +").await?;
server.server.terminate().await?; server.shutdown().await?;
assert_eq!(stream.read_u8().await.unwrap_err().kind(), ErrorKind::UnexpectedEof); assert_eq!(stream.read_u8().await.unwrap_err().kind(), ErrorKind::UnexpectedEof);
Ok(()) Ok(())
} }
#[tokio::test]
async fn server_time_capability() -> Result<()> {
let mut server = TestServer::start().await?;
// test scenario
server.storage.create_user("tester").await?;
Authenticator::new(&server.storage).set_password("tester", "password").await?;
let mut stream = TcpStream::connect(server.server.addr).await?;
let mut s = TestScope::new(&mut stream);
s.send("CAP LS 302").await?;
s.send("NICK tester").await?;
s.send("USER UserName 0 * :Real Name").await?;
s.expect_cap_ls().await?;
s.send("CAP REQ :sasl server-time").await?;
s.expect(":testserver CAP tester ACK :sasl server-time").await?;
s.send("AUTHENTICATE PLAIN").await?;
s.expect(":testserver AUTHENTICATE +").await?;
s.send("AUTHENTICATE dGVzdGVyAHRlc3RlcgBwYXNzd29yZA==").await?; // base64-encoded 'tester\x00tester\x00password'
s.expect(":testserver 900 tester tester tester :You are now logged in as tester").await?;
s.expect(":testserver 903 tester :SASL authentication successful").await?;
s.send("CAP END").await?;
s.expect_server_introduction("tester").await?;
s.expect_nothing().await?;
s.send("JOIN #test").await?;
s.expect(":tester JOIN #test").await?;
s.expect(":testserver 332 tester #test :New room").await?;
s.expect(":testserver 353 tester = #test :tester").await?;
s.expect(":testserver 366 tester #test :End of /NAMES list").await?;
server.storage.create_user("some_guy").await?;
let mut conn = server.core.players.connect_to_player(&PlayerId::from("some_guy").unwrap()).await;
let res = conn.join_room(RoomId::from("test").unwrap()).await?;
let JoinResult::Success(_) = res else {
panic!("Failed to join room");
};
s.expect(":some_guy JOIN #test").await?;
let SendMessageResult::Success(res) = conn.send_message(RoomId::from("test").unwrap(), "Hello".into()).await?
else {
panic!("Failed to send message");
};
s.expect(&format!(
"@time={} :some_guy PRIVMSG #test :Hello",
res.to_rfc3339_opts(SecondsFormat::Millis, true)
))
.await?;
// formatting check
assert_eq!(
DateTime::parse_from_rfc3339(&"2024-01-01T10:00:32.123Z").unwrap().to_rfc3339_opts(SecondsFormat::Millis, true),
"2024-01-01T10:00:32.123Z"
);
s.send("QUIT :Leaving").await?;
s.expect(":testserver ERROR :Leaving the server").await?;
s.expect_eof().await?;
stream.shutdown().await?;
// wrap up
server.shutdown().await?;
Ok(())
}
#[tokio::test]
async fn scenario_two_players_dialog() -> Result<()> {
let mut server = TestServer::start().await?;
// test scenario
server.storage.create_user("tester1").await?;
server.storage.set_password("tester1", "password").await?;
server.storage.create_user("tester2").await?;
server.storage.set_password("tester2", "password").await?;
let mut stream1 = TcpStream::connect(server.server.addr).await?;
let mut s1 = TestScope::new(&mut stream1);
let mut stream2 = TcpStream::connect(server.server.addr).await?;
let mut s2 = TestScope::new(&mut stream2);
s1.send("CAP LS 302").await?;
s1.send("NICK tester1").await?;
s1.send("USER UserName 0 * :Real Name").await?;
s1.expect_cap_ls().await?;
s1.send("CAP REQ :sasl").await?;
s1.expect(":testserver CAP tester1 ACK :sasl").await?;
s1.send("AUTHENTICATE PLAIN").await?;
s1.expect(":testserver AUTHENTICATE +").await?;
s1.send("AUTHENTICATE dGVzdGVyMQB0ZXN0ZXIxAHBhc3N3b3Jk").await?; // base64-encoded 'tester1\x00tester1\x00password'
s1.expect(":testserver 900 tester1 tester1 tester1 :You are now logged in as tester1").await?;
s1.expect(":testserver 903 tester1 :SASL authentication successful").await?;
s1.send("CAP END").await?;
s1.expect_server_introduction("tester1").await?;
s1.expect_nothing().await?;
s2.send("CAP LS 302").await?;
s2.send("NICK tester2").await?;
s2.send("USER UserName 0 * :Real Name").await?;
s2.expect_cap_ls().await?;
s2.send("CAP REQ :sasl").await?;
s2.expect(":testserver CAP tester2 ACK :sasl").await?;
s2.send("AUTHENTICATE PLAIN").await?;
s2.expect(":testserver AUTHENTICATE +").await?;
s2.send("AUTHENTICATE dGVzdGVyMgB0ZXN0ZXIyAHBhc3N3b3Jk").await?; // base64-encoded 'tester2\x00tester2\x00password'
s2.expect(":testserver 900 tester2 tester2 tester2 :You are now logged in as tester2").await?;
s2.expect(":testserver 903 tester2 :SASL authentication successful").await?;
s2.send("CAP END").await?;
s2.expect_server_introduction("tester2").await?;
s2.expect_nothing().await?;
s1.send("PRIVMSG tester2 :Henlo! How are you?").await?;
s1.expect_nothing().await?;
s2.expect(":tester1 PRIVMSG tester2 :Henlo! How are you?").await?;
s2.expect_nothing().await?;
s2.send("PRIVMSG tester1 good").await?;
s2.expect_nothing().await?;
s1.expect(":tester2 PRIVMSG tester1 :good").await?;
s1.expect_nothing().await?;
stream1.shutdown().await?;
stream2.shutdown().await?;
server.shutdown().await?;
Ok(())
}

View File

@ -9,6 +9,7 @@ use std::net::SocketAddr;
use std::path::PathBuf; use std::path::PathBuf;
use std::sync::Arc; use std::sync::Arc;
use anyhow::anyhow;
use futures_util::future::join_all; use futures_util::future::join_all;
use prometheus::Registry as MetricsRegistry; use prometheus::Registry as MetricsRegistry;
use quick_xml::events::{BytesDecl, Event}; use quick_xml::events::{BytesDecl, Event};
@ -21,6 +22,7 @@ use tokio::sync::mpsc::channel;
use tokio_rustls::rustls::{Certificate, PrivateKey}; use tokio_rustls::rustls::{Certificate, PrivateKey};
use tokio_rustls::TlsAcceptor; use tokio_rustls::TlsAcceptor;
use lavina_core::auth::{Authenticator, Verdict};
use lavina_core::player::{PlayerConnection, PlayerId, PlayerRegistry}; use lavina_core::player::{PlayerConnection, PlayerId, PlayerRegistry};
use lavina_core::prelude::*; use lavina_core::prelude::*;
use lavina_core::repo::Storage; use lavina_core::repo::Storage;
@ -300,28 +302,18 @@ async fn socket_auth(
match AuthBody::from_str(&auth.body) { match AuthBody::from_str(&auth.body) {
Ok(logopass) => { Ok(logopass) => {
let name = &logopass.login; let name = &logopass.login;
let stored_user = storage.retrieve_user_by_name(name).await?; let verdict = Authenticator::new(storage).authenticate(name, &logopass.password).await?;
let stored_user = match stored_user {
Some(u) => u,
None => {
log::info!("User '{}' not found", name);
return Err(fail("no user found"));
}
};
// TODO return proper XML errors to the client // TODO return proper XML errors to the client
match verdict {
if stored_user.password.is_none() { Verdict::Authenticated => {}
log::info!("Password not defined for user '{}'", name); Verdict::UserNotFound => {
return Err(fail("password is not defined")); return Err(anyhow!("no user found"));
}
Verdict::InvalidPassword => {
return Err(anyhow!("incorrect credentials"));
}
} }
if stored_user.password.as_deref() != Some(&logopass.password) {
log::info!("Incorrect password supplied for user '{}'", name);
return Err(fail("passwords do not match"));
}
let name: Str = name.as_str().into(); let name: Str = name.as_str().into();
Ok(Authenticated { Ok(Authenticated {
player_id: PlayerId::from(name.clone())?, player_id: PlayerId::from(name.clone())?,
xmpp_name: Name(name.clone()), xmpp_name: Name(name.clone()),

View File

@ -1,5 +1,6 @@
//! Handling of all client2server message stanzas //! Handling of all client2server message stanzas
use lavina_core::player::PlayerId;
use quick_xml::events::Event; use quick_xml::events::Event;
use lavina_core::prelude::*; use lavina_core::prelude::*;
@ -40,6 +41,9 @@ impl<'a> XmppConnection<'a> {
} }
.serialize(output); .serialize(output);
Ok(()) Ok(())
} else if server.0.as_ref() == &*self.hostname && m.r#type == MessageType::Chat {
self.user_handle.send_dialog_message(PlayerId::from(name.0.clone())?, m.body.clone()).await?;
Ok(())
} else { } else {
todo!() todo!()
} }

View File

@ -17,6 +17,7 @@ impl<'a> XmppConnection<'a> {
room_id, room_id,
author_id, author_id,
body, body,
created_at: _,
} => { } => {
Message::<()> { Message::<()> {
to: Some(Jid { to: Some(Jid {
@ -38,6 +39,34 @@ impl<'a> XmppConnection<'a> {
} }
.serialize(output); .serialize(output);
} }
Updates::NewDialogMessage {
sender,
receiver,
body,
created_at: _,
} => {
if receiver == self.user.player_id {
Message::<()> {
to: Some(Jid {
name: Some(self.user.xmpp_name.clone()),
server: Server(self.hostname.clone()),
resource: Some(self.user.xmpp_resource.clone()),
}),
from: Some(Jid {
name: Some(Name(sender.as_inner().clone())),
server: Server(self.hostname.clone()),
resource: Some(Resource(sender.into_inner())),
}),
id: None,
r#type: MessageType::Chat,
lang: None,
subject: None,
body: body.into(),
custom: vec![],
}
.serialize(output);
}
}
_ => {} _ => {}
} }
Ok(()) Ok(())

View File

@ -16,6 +16,7 @@ use tokio_rustls::rustls::client::ServerCertVerifier;
use tokio_rustls::rustls::{ClientConfig, ServerName}; use tokio_rustls::rustls::{ClientConfig, ServerName};
use tokio_rustls::TlsConnector; use tokio_rustls::TlsConnector;
use lavina_core::auth::Authenticator;
use lavina_core::repo::{Storage, StorageConfig}; use lavina_core::repo::{Storage, StorageConfig};
use lavina_core::LavinaCore; use lavina_core::LavinaCore;
use projection_xmpp::{launch, RunningServer, ServerConfig}; use projection_xmpp::{launch, RunningServer, ServerConfig};
@ -149,6 +150,13 @@ impl TestServer {
server, server,
}) })
} }
async fn shutdown(self) -> Result<()> {
self.server.terminate().await?;
self.core.shutdown().await?;
self.storage.close().await?;
Ok(())
}
} }
#[tokio::test] #[tokio::test]
@ -158,7 +166,7 @@ async fn scenario_basic() -> Result<()> {
// test scenario // test scenario
server.storage.create_user("tester").await?; server.storage.create_user("tester").await?;
server.storage.set_password("tester", "password").await?; Authenticator::new(&server.storage).set_password("tester", "password").await?;
let mut stream = TcpStream::connect(server.server.addr).await?; let mut stream = TcpStream::connect(server.server.addr).await?;
let mut s = TestScope::new(&mut stream); let mut s = TestScope::new(&mut stream);
@ -199,7 +207,7 @@ async fn scenario_basic() -> Result<()> {
// wrap up // wrap up
server.server.terminate().await?; server.shutdown().await?;
Ok(()) Ok(())
} }
@ -210,7 +218,7 @@ async fn scenario_basic_without_headers() -> Result<()> {
// test scenario // test scenario
server.storage.create_user("tester").await?; server.storage.create_user("tester").await?;
server.storage.set_password("tester", "password").await?; Authenticator::new(&server.storage).set_password("tester", "password").await?;
let mut stream = TcpStream::connect(server.server.addr).await?; let mut stream = TcpStream::connect(server.server.addr).await?;
let mut s = TestScope::new(&mut stream); let mut s = TestScope::new(&mut stream);
@ -249,7 +257,7 @@ async fn scenario_basic_without_headers() -> Result<()> {
// wrap up // wrap up
server.server.terminate().await?; server.shutdown().await?;
Ok(()) Ok(())
} }
@ -260,7 +268,7 @@ async fn terminate_socket() -> Result<()> {
// test scenario // test scenario
server.storage.create_user("tester").await?; server.storage.create_user("tester").await?;
server.storage.set_password("tester", "password").await?; Authenticator::new(&server.storage).set_password("tester", "password").await?;
let mut stream = TcpStream::connect(server.server.addr).await?; let mut stream = TcpStream::connect(server.server.addr).await?;
let mut s = TestScope::new(&mut stream); let mut s = TestScope::new(&mut stream);
@ -290,7 +298,7 @@ async fn terminate_socket() -> Result<()> {
let mut stream = connector.connect(ServerName::IpAddress(server.server.addr.ip()), stream).await?; let mut stream = connector.connect(ServerName::IpAddress(server.server.addr.ip()), stream).await?;
tracing::info!("TLS connection established"); tracing::info!("TLS connection established");
server.server.terminate().await?; server.shutdown().await?;
assert_eq!(stream.read_u8().await.unwrap_err().kind(), ErrorKind::UnexpectedEof); assert_eq!(stream.read_u8().await.unwrap_err().kind(), ErrorKind::UnexpectedEof);

View File

@ -18,8 +18,19 @@ use tokio::io::{AsyncWrite, AsyncWriteExt};
/// Single message tag value. /// Single message tag value.
#[derive(Clone, Debug, PartialEq, Eq)] #[derive(Clone, Debug, PartialEq, Eq)]
pub struct Tag { pub struct Tag {
key: Str, pub key: Str,
value: Option<u8>, pub value: Option<Str>,
}
impl Tag {
pub async fn write_async(&self, writer: &mut (impl AsyncWrite + Unpin)) -> std::io::Result<()> {
writer.write_all(self.key.as_bytes()).await?;
if let Some(value) = &self.value {
writer.write_all(b"=").await?;
writer.write_all(value.as_bytes()).await?;
}
Ok(())
}
} }
fn receiver(input: &str) -> IResult<&str, &str> { fn receiver(input: &str) -> IResult<&str, &str> {

View File

@ -19,6 +19,13 @@ pub struct ServerMessage {
impl ServerMessage { impl ServerMessage {
pub async fn write_async(&self, writer: &mut (impl AsyncWrite + Unpin)) -> std::io::Result<()> { pub async fn write_async(&self, writer: &mut (impl AsyncWrite + Unpin)) -> std::io::Result<()> {
if !self.tags.is_empty() {
for tag in &self.tags {
writer.write_all(b"@").await?;
tag.write_async(writer).await?;
writer.write_all(b" ").await?;
}
}
match &self.sender { match &self.sender {
Some(ref sender) => { Some(ref sender) => {
writer.write_all(b":").await?; writer.write_all(b":").await?;

View File

@ -79,10 +79,6 @@ mod test {
fn test_fail_if_size_less_then_3() { fn test_fail_if_size_less_then_3() {
let orig = b"login\x00pass"; let orig = b"login\x00pass";
let encoded = general_purpose::STANDARD.encode(orig); let encoded = general_purpose::STANDARD.encode(orig);
let expected = AuthBody {
login: "login".to_string(),
password: "pass".to_string(),
};
let result = AuthBody::from_str(encoded.as_bytes()); let result = AuthBody::from_str(encoded.as_bytes());
assert!(result.is_err()); assert!(result.is_err());
@ -92,10 +88,6 @@ mod test {
fn test_fail_if_size_greater_then_3() { fn test_fail_if_size_greater_then_3() {
let orig = b"first\x00login\x00pass\x00other"; let orig = b"first\x00login\x00pass\x00other";
let encoded = general_purpose::STANDARD.encode(orig); let encoded = general_purpose::STANDARD.encode(orig);
let expected = AuthBody {
login: "login".to_string(),
password: "pass".to_string(),
};
let result = AuthBody::from_str(encoded.as_bytes()); let result = AuthBody::from_str(encoded.as_bytes());
assert!(result.is_err()); assert!(result.is_err());

View File

@ -12,6 +12,7 @@ use prometheus::{Encoder, Registry as MetricsRegistry, TextEncoder};
use serde::{Deserialize, Serialize}; use serde::{Deserialize, Serialize};
use tokio::net::TcpListener; use tokio::net::TcpListener;
use lavina_core::auth::{Authenticator, UpdatePasswordResult};
use lavina_core::prelude::*; use lavina_core::prelude::*;
use lavina_core::repo::Storage; use lavina_core::repo::Storage;
use lavina_core::room::RoomRegistry; use lavina_core::room::RoomRegistry;
@ -141,17 +142,20 @@ async fn endpoint_set_password(
*response.status_mut() = StatusCode::BAD_REQUEST; *response.status_mut() = StatusCode::BAD_REQUEST;
return Ok(response); return Ok(response);
}; };
let Some(_) = storage.set_password(&res.player_name, &res.password).await? else { let verdict = Authenticator::new(&storage).set_password(&res.player_name, &res.password).await?;
let payload = ErrorResponse { match verdict {
code: errors::PLAYER_NOT_FOUND, UpdatePasswordResult::PasswordUpdated => {}
message: "No such player exists", UpdatePasswordResult::UserNotFound => {
let payload = ErrorResponse {
code: errors::PLAYER_NOT_FOUND,
message: "No such player exists",
}
.to_body();
let mut response = Response::new(payload);
*response.status_mut() = StatusCode::UNPROCESSABLE_ENTITY;
return Ok(response);
} }
.to_body(); }
let mut response = Response::new(payload);
*response.status_mut() = StatusCode::UNPROCESSABLE_ENTITY;
return Ok(response);
};
log::info!("Password changed for player {}", res.player_name);
let mut response = Response::new(Full::<Bytes>::default()); let mut response = Response::new(Full::<Bytes>::default());
*response.status_mut() = StatusCode::NO_CONTENT; *response.status_mut() = StatusCode::NO_CONTENT;
Ok(response) Ok(response)