diff --git a/crates/lavina-core/src/player.rs b/crates/lavina-core/src/player.rs index 872e8bc..234719f 100644 --- a/crates/lavina-core/src/player.rs +++ b/crates/lavina-core/src/player.rs @@ -8,7 +8,6 @@ //! A player actor is a serial handler of commands from a single player. It is preferable to run all per-player validations in the player actor, //! so that they don't overload the room actor. use std::collections::{HashMap, HashSet}; -use std::ops::Deref; use std::sync::Arc; use chrono::{DateTime, Utc}; diff --git a/crates/lavina-core/src/repo/mod.rs b/crates/lavina-core/src/repo/mod.rs index 0089c22..059f332 100644 --- a/crates/lavina-core/src/repo/mod.rs +++ b/crates/lavina-core/src/repo/mod.rs @@ -3,11 +3,9 @@ use std::str::FromStr; use std::sync::Arc; -use anyhow::anyhow; -use chrono::{DateTime, Utc}; use serde::Deserialize; use sqlx::sqlite::SqliteConnectOptions; -use sqlx::{ConnectOptions, Connection, Execute, FromRow, Sqlite, SqliteConnection, Transaction}; +use sqlx::{ConnectOptions, Connection, SqliteConnection}; use tokio::sync::Mutex; use crate::prelude::*; @@ -40,98 +38,6 @@ impl Storage { Ok(Storage { conn }) } - #[tracing::instrument(skip(self), name = "Storage::retrieve_user_by_name")] - pub async fn retrieve_user_by_name(&self, name: &str) -> Result> { - let mut executor = self.conn.lock().await; - let res = sqlx::query_as( - "select u.id, u.name, c.password, a.hash as argon2_hash - from users u left join challenges_plain_password c on u.id = c.user_id - left join challenges_argon2_password a on u.id = a.user_id - where u.name = ?;", - ) - .bind(name) - .fetch_optional(&mut *executor) - .await?; - - Ok(res) - } - - #[tracing::instrument(skip(self), name = "Storage::check_user_existence")] - pub async fn check_user_existence(&self, username: &str) -> Result { - let mut executor = self.conn.lock().await; - let result: Option<(String,)> = sqlx::query_as("select name from users where name = ?;") - .bind(username) - .fetch_optional(&mut *executor) - .await?; - - Ok(result.is_some()) - } - - #[tracing::instrument(skip(self), name = "Storage::retrieve_room_by_name")] - pub async fn retrieve_room_by_name(&self, name: &str) -> Result> { - let mut executor = self.conn.lock().await; - let res = sqlx::query_as( - "select id, name, topic, message_count - from rooms - where name = ?;", - ) - .bind(name) - .fetch_optional(&mut *executor) - .await?; - - Ok(res) - } - - #[tracing::instrument(skip(self, topic), name = "Storage::create_new_room")] - pub async fn create_new_room(&mut self, name: &str, topic: &str) -> Result { - let mut executor = self.conn.lock().await; - let (id,): (u32,) = sqlx::query_as( - "insert into rooms(name, topic) - values (?, ?) - returning id;", - ) - .bind(name) - .bind(topic) - .fetch_one(&mut *executor) - .await?; - - Ok(id) - } - - #[tracing::instrument(skip(self, content, created_at), name = "Storage::insert_message")] - pub async fn insert_message( - &mut self, - room_id: u32, - id: u32, - content: &str, - author_id: &str, - created_at: &DateTime, - ) -> Result<()> { - let mut executor = self.conn.lock().await; - let res: Option<(u32,)> = sqlx::query_as("select id from users where name = ?;") - .bind(author_id) - .fetch_optional(&mut *executor) - .await?; - let Some((author_id,)) = res else { - return Err(anyhow!("No such user")); - }; - sqlx::query( - "insert into messages(room_id, id, content, author_id, created_at) - values (?, ?, ?, ?, ?); - update rooms set message_count = message_count + 1 where id = ?;", - ) - .bind(room_id) - .bind(id) - .bind(content) - .bind(author_id) - .bind(created_at.to_string()) - .bind(room_id) - .execute(&mut *executor) - .await?; - - Ok(()) - } - pub async fn close(self) -> Result<()> { let res = match Arc::try_unwrap(self.conn) { Ok(e) => e, @@ -141,66 +47,4 @@ impl Storage { res.close().await?; Ok(()) } - - #[tracing::instrument(skip(self), name = "Storage::create_user")] - pub async fn create_user(&self, name: &str) -> Result<()> { - let query = sqlx::query( - "insert into users(name) - values (?);", - ) - .bind(name); - let mut executor = self.conn.lock().await; - query.execute(&mut *executor).await?; - - Ok(()) - } - - #[tracing::instrument(skip(self, pwd), name = "Storage::set_password")] - pub async fn set_password<'a>(&'a self, name: &'a str, pwd: &'a str) -> Result> { - async fn inner(txn: &mut Transaction<'_, Sqlite>, name: &str, pwd: &str) -> Result> { - let id: Option<(u32,)> = sqlx::query_as("select * from users where name = ? limit 1;") - .bind(name) - .fetch_optional(&mut **txn) - .await?; - let Some((id,)) = id else { - return Ok(None); - }; - sqlx::query("insert or replace into challenges_plain_password(user_id, password) values (?, ?);") - .bind(id) - .bind(pwd) - .execute(&mut **txn) - .await?; - Ok(Some(())) - } - - let mut executor = self.conn.lock().await; - let mut tx = executor.begin().await?; - let res = inner(&mut tx, name, pwd).await; - match res { - Ok(e) => { - tx.commit().await?; - Ok(e) - } - Err(e) => { - tx.rollback().await?; - Err(e) - } - } - } -} - -#[derive(FromRow)] -pub struct StoredUser { - pub id: u32, - pub name: String, - pub password: Option, - pub argon2_hash: Option>, -} - -#[derive(FromRow)] -pub struct StoredRoom { - pub id: u32, - pub name: String, - pub topic: String, - pub message_count: u32, } diff --git a/crates/lavina-core/src/repo/room.rs b/crates/lavina-core/src/repo/room.rs index 95a4de6..5ca5168 100644 --- a/crates/lavina-core/src/repo/room.rs +++ b/crates/lavina-core/src/repo/room.rs @@ -1,9 +1,84 @@ -use crate::player::PlayerConnection; -use anyhow::Result; +use anyhow::{anyhow, Result}; +use chrono::{DateTime, Utc}; +use sqlx::FromRow; use crate::repo::Storage; +use crate::room::RoomId; + +#[derive(FromRow)] +pub struct StoredRoom { + pub id: u32, + pub name: String, + pub topic: String, + pub message_count: u32, +} impl Storage { + #[tracing::instrument(skip(self), name = "Storage::retrieve_room_by_name")] + pub async fn retrieve_room_by_name(&self, name: &str) -> Result> { + let mut executor = self.conn.lock().await; + let res = sqlx::query_as( + "select id, name, topic, message_count + from rooms + where name = ?;", + ) + .bind(name) + .fetch_optional(&mut *executor) + .await?; + + Ok(res) + } + + #[tracing::instrument(skip(self, topic), name = "Storage::create_new_room")] + pub async fn create_new_room(&mut self, name: &str, topic: &str) -> Result { + let mut executor = self.conn.lock().await; + let (id,): (u32,) = sqlx::query_as( + "insert into rooms(name, topic) + values (?, ?) + returning id;", + ) + .bind(name) + .bind(topic) + .fetch_one(&mut *executor) + .await?; + + Ok(id) + } + + #[tracing::instrument(skip(self, content, created_at), name = "Storage::insert_room_message")] + pub async fn insert_room_message( + &mut self, + room_id: u32, + id: u32, + content: &str, + author_id: &str, + created_at: &DateTime, + ) -> Result<()> { + let mut executor = self.conn.lock().await; + let res: Option<(u32,)> = sqlx::query_as("select id from users where name = ?;") + .bind(author_id) + .fetch_optional(&mut *executor) + .await?; + let Some((author_id,)) = res else { + return Err(anyhow!("No such user")); + }; + sqlx::query( + "insert into messages(room_id, id, content, author_id, created_at) + values (?, ?, ?, ?, ?); + update rooms set message_count = message_count + 1 where id = ?;", + ) + .bind(room_id) + .bind(id) + .bind(content) + .bind(author_id) + .bind(created_at.to_string()) + .bind(room_id) + .execute(&mut *executor) + .await?; + + Ok(()) + } + #[tracing::instrument(skip(self), name = "Storage::is_room_member")] pub async fn is_room_member(&self, room_id: u32, player_id: u32) -> Result { let mut executor = self.conn.lock().await; @@ -70,6 +145,7 @@ impl Storage { Ok(()) } + #[tracing::instrument(skip(self), name = "Storage::create_or_retrieve_room_id_by_name")] pub async fn create_or_retrieve_room_id_by_name(&self, name: &str) -> Result { // TODO we don't need any info except the name on non-owning nodes, should remove stubs here let mut executor = self.conn.lock().await; @@ -85,4 +161,19 @@ impl Storage { Ok(res.0) } + + #[tracing::instrument(skip(self), name = "Storage::get_rooms_of_a_user")] + pub async fn get_rooms_of_a_user(&self, user_id: u32) -> Result> { + let mut executor = self.conn.lock().await; + let res: Vec<(String,)> = sqlx::query_as( + "select r.name + from memberships m inner join rooms r on m.room_id = r.id + where m.user_id = ?;", + ) + .bind(user_id) + .fetch_all(&mut *executor) + .await?; + + res.into_iter().map(|(room_id,)| RoomId::from(room_id)).collect() + } } diff --git a/crates/lavina-core/src/repo/user.rs b/crates/lavina-core/src/repo/user.rs index a27c245..a5471d0 100644 --- a/crates/lavina-core/src/repo/user.rs +++ b/crates/lavina-core/src/repo/user.rs @@ -1,9 +1,91 @@ use anyhow::Result; +use sqlx::{Connection, FromRow, Sqlite, Transaction}; use crate::repo::Storage; -use crate::room::RoomId; + +#[derive(FromRow)] +pub struct StoredUser { + pub id: u32, + pub name: String, + pub password: Option, + pub argon2_hash: Option>, +} impl Storage { + #[tracing::instrument(skip(self), name = "Storage::retrieve_user_by_name")] + pub async fn retrieve_user_by_name(&self, name: &str) -> Result> { + let mut executor = self.conn.lock().await; + let res = sqlx::query_as( + "select u.id, u.name, c.password, a.hash as argon2_hash + from users u left join challenges_plain_password c on u.id = c.user_id + left join challenges_argon2_password a on u.id = a.user_id + where u.name = ?;", + ) + .bind(name) + .fetch_optional(&mut *executor) + .await?; + + Ok(res) + } + + #[tracing::instrument(skip(self), name = "Storage::check_user_existence")] + pub async fn check_user_existence(&self, username: &str) -> Result { + let mut executor = self.conn.lock().await; + let result: Option<(String,)> = sqlx::query_as("select name from users where name = ?;") + .bind(username) + .fetch_optional(&mut *executor) + .await?; + + Ok(result.is_some()) + } + + #[tracing::instrument(skip(self), name = "Storage::create_user")] + pub async fn create_user(&self, name: &str) -> Result<()> { + let query = sqlx::query( + "insert into users(name) + values (?);", + ) + .bind(name); + let mut executor = self.conn.lock().await; + query.execute(&mut *executor).await?; + + Ok(()) + } + + #[tracing::instrument(skip(self, pwd), name = "Storage::set_password")] + pub async fn set_password(&self, name: &str, pwd: &str) -> Result> { + async fn inner(txn: &mut Transaction<'_, Sqlite>, name: &str, pwd: &str) -> Result> { + let id: Option<(u32,)> = sqlx::query_as("select * from users where name = ? limit 1;") + .bind(name) + .fetch_optional(&mut **txn) + .await?; + let Some((id,)) = id else { + return Ok(None); + }; + sqlx::query("insert or replace into challenges_plain_password(user_id, password) values (?, ?);") + .bind(id) + .bind(pwd) + .execute(&mut **txn) + .await?; + Ok(Some(())) + } + + let mut executor = self.conn.lock().await; + let mut tx = executor.begin().await?; + let res = inner(&mut tx, name, pwd).await; + match res { + Ok(e) => { + tx.commit().await?; + Ok(e) + } + Err(e) => { + tx.rollback().await?; + Err(e) + } + } + } + + #[tracing::instrument(skip(self), name = "Storage::retrieve_user_id_by_name")] pub async fn retrieve_user_id_by_name(&self, name: &str) -> Result> { let mut executor = self.conn.lock().await; let res: Option<(u32,)> = sqlx::query_as("select u.id from users u where u.name = ?;") @@ -14,6 +96,7 @@ impl Storage { Ok(res.map(|(id,)| id)) } + #[tracing::instrument(skip(self), name = "Storage::create_or_retrieve_user_id_by_name")] pub async fn create_or_retrieve_user_id_by_name(&self, name: &str) -> Result { let mut executor = self.conn.lock().await; let res: (u32,) = sqlx::query_as( @@ -28,18 +111,4 @@ impl Storage { Ok(res.0) } - - pub async fn get_rooms_of_a_user(&self, user_id: u32) -> Result> { - let mut executor = self.conn.lock().await; - let res: Vec<(String,)> = sqlx::query_as( - "select r.name - from memberships m inner join rooms r on m.room_id = r.id - where m.user_id = ?;", - ) - .bind(user_id) - .fetch_all(&mut *executor) - .await?; - - res.into_iter().map(|(room_id,)| RoomId::from(room_id)).collect() - } } diff --git a/crates/lavina-core/src/room.rs b/crates/lavina-core/src/room.rs index bbffd15..54d550e 100644 --- a/crates/lavina-core/src/room.rs +++ b/crates/lavina-core/src/room.rs @@ -243,7 +243,7 @@ impl Room { async fn send_message(&mut self, author_id: &PlayerId, body: Str, created_at: DateTime) -> Result<()> { tracing::info!("Adding a message to room"); self.storage - .insert_message( + .insert_room_message( self.storage_id, self.message_count, &body, diff --git a/crates/projection-irc/src/handler.rs b/crates/projection-irc/src/handler.rs index 2e7e0da..15a1e56 100644 --- a/crates/projection-irc/src/handler.rs +++ b/crates/projection-irc/src/handler.rs @@ -1,7 +1,10 @@ -use lavina_core::{player::PlayerConnection, prelude::Str, LavinaCore}; use std::future::Future; + use tokio::io::AsyncWrite; +use lavina_core::player::PlayerConnection; +use lavina_core::prelude::Str; + pub struct IrcConnection<'a, T: AsyncWrite + Unpin> { pub server_name: Str, /// client is nick of requester diff --git a/crates/projection-irc/tests/lib.rs b/crates/projection-irc/tests/lib.rs index 9b29646..db6b55f 100644 --- a/crates/projection-irc/tests/lib.rs +++ b/crates/projection-irc/tests/lib.rs @@ -114,8 +114,8 @@ impl TestServer { listen_on: "127.0.0.1:0".parse().unwrap(), server_name: "testserver".into(), }; - let mut metrics = MetricsRegistry::new(); - let mut storage = Storage::open(StorageConfig { + let metrics = MetricsRegistry::new(); + let storage = Storage::open(StorageConfig { db_path: ":memory:".into(), }) .await?; @@ -153,7 +153,7 @@ impl TestServer { let TestServer { metrics: _, storage, - mut core, + core, server, } = self; server.terminate().await?; @@ -179,7 +179,7 @@ impl TestServer { #[tokio::test] async fn scenario_basic() -> Result<()> { - let mut server = TestServer::start().await?; + let server = TestServer::start().await?; // test scenario @@ -208,7 +208,7 @@ async fn scenario_basic() -> Result<()> { #[tokio::test] async fn scenario_join_and_reboot() -> Result<()> { - let mut server = TestServer::start().await?; + let server = TestServer::start().await?; // test scenario @@ -278,7 +278,7 @@ async fn scenario_join_and_reboot() -> Result<()> { #[tokio::test] async fn scenario_force_join_msg() -> Result<()> { - let mut server = TestServer::start().await?; + let server = TestServer::start().await?; // test scenario @@ -344,7 +344,7 @@ async fn scenario_force_join_msg() -> Result<()> { #[tokio::test] async fn scenario_two_users() -> Result<()> { - let mut server = TestServer::start().await?; + let server = TestServer::start().await?; // test scenario @@ -420,7 +420,7 @@ AUTHENTICATE doc: https://modern.ircdocs.horse/#authenticate-message */ #[tokio::test] async fn scenario_cap_full_negotiation() -> Result<()> { - let mut server = TestServer::start().await?; + let server = TestServer::start().await?; // test scenario @@ -460,7 +460,7 @@ async fn scenario_cap_full_negotiation() -> Result<()> { #[tokio::test] async fn scenario_cap_full_negotiation_nick_last() -> Result<()> { - let mut server = TestServer::start().await?; + let server = TestServer::start().await?; // test scenario @@ -499,7 +499,7 @@ async fn scenario_cap_full_negotiation_nick_last() -> Result<()> { #[tokio::test] async fn scenario_cap_short_negotiation() -> Result<()> { - let mut server = TestServer::start().await?; + let server = TestServer::start().await?; // test scenario @@ -537,7 +537,7 @@ async fn scenario_cap_short_negotiation() -> Result<()> { #[tokio::test] async fn scenario_cap_sasl_fail() -> Result<()> { - let mut server = TestServer::start().await?; + let server = TestServer::start().await?; // test scenario @@ -581,7 +581,7 @@ async fn scenario_cap_sasl_fail() -> Result<()> { #[tokio::test] async fn terminate_socket_scenario() -> Result<()> { - let mut server = TestServer::start().await?; + let server = TestServer::start().await?; // test scenario @@ -606,7 +606,7 @@ async fn terminate_socket_scenario() -> Result<()> { #[tokio::test] async fn server_time_capability() -> Result<()> { - let mut server = TestServer::start().await?; + let server = TestServer::start().await?; // test scenario @@ -675,7 +675,7 @@ async fn server_time_capability() -> Result<()> { #[tokio::test] async fn scenario_two_players_dialog() -> Result<()> { - let mut server = TestServer::start().await?; + let server = TestServer::start().await?; // test scenario diff --git a/crates/projection-xmpp/src/presence.rs b/crates/projection-xmpp/src/presence.rs index c2d41b2..5dc9498 100644 --- a/crates/projection-xmpp/src/presence.rs +++ b/crates/projection-xmpp/src/presence.rs @@ -86,9 +86,9 @@ impl<'a> XmppConnection<'a> { #[cfg(test)] mod tests { use crate::testkit::{expect_user_authenticated, TestServer}; - use crate::{Authenticated, XmppConnection}; + use crate::Authenticated; use lavina_core::player::PlayerId; - use proto_xmpp::bind::{BindRequest, BindResponse, Jid, Name, Resource, Server}; + use proto_xmpp::bind::{Jid, Name, Resource, Server}; use proto_xmpp::client::Presence; #[tokio::test] diff --git a/crates/projection-xmpp/tests/lib.rs b/crates/projection-xmpp/tests/lib.rs index 411737c..ef5e979 100644 --- a/crates/projection-xmpp/tests/lib.rs +++ b/crates/projection-xmpp/tests/lib.rs @@ -23,7 +23,6 @@ use lavina_core::clustering::{ClusterConfig, ClusterMetadata}; use lavina_core::repo::{Storage, StorageConfig}; use lavina_core::LavinaCore; use projection_xmpp::{launch, RunningServer, ServerConfig}; -use proto_xmpp::xml::{Continuation, FromXml, Parser}; fn element_name<'a>(local_name: &LocalName<'a>) -> &'a str { from_utf8(local_name.into_inner()).unwrap() @@ -190,7 +189,7 @@ impl TestServer { #[tokio::test] async fn scenario_basic() -> Result<()> { - let mut server = TestServer::start().await?; + let server = TestServer::start().await?; // test scenario @@ -258,7 +257,7 @@ async fn scenario_basic() -> Result<()> { #[tokio::test] async fn scenario_wrong_password() -> Result<()> { - let mut server = TestServer::start().await?; + let server = TestServer::start().await?; // test scenario @@ -313,7 +312,7 @@ async fn scenario_wrong_password() -> Result<()> { #[tokio::test] async fn scenario_basic_without_headers() -> Result<()> { - let mut server = TestServer::start().await?; + let server = TestServer::start().await?; // test scenario @@ -359,7 +358,7 @@ async fn scenario_basic_without_headers() -> Result<()> { #[tokio::test] async fn terminate_socket() -> Result<()> { - let mut server = TestServer::start().await?; + let server = TestServer::start().await?; // test scenario @@ -399,7 +398,7 @@ async fn terminate_socket() -> Result<()> { #[tokio::test] async fn test_message_archive_request() -> Result<()> { - let mut server = TestServer::start().await?; + let server = TestServer::start().await?; // test scenario diff --git a/crates/proto-xmpp/src/mam.rs b/crates/proto-xmpp/src/mam.rs index c8151f2..7c4a9f9 100644 --- a/crates/proto-xmpp/src/mam.rs +++ b/crates/proto-xmpp/src/mam.rs @@ -1,7 +1,6 @@ use anyhow::{anyhow, Result}; use quick_xml::events::{BytesEnd, BytesStart, BytesText, Event}; use quick_xml::name::{Namespace, ResolveResult}; -use std::io::Read; use crate::xml::*;