move repo methods into submodules and clean up warnings

This commit is contained in:
Nikita Vilunov 2024-05-10 23:50:34 +02:00
parent 6749103726
commit d1a72e7978
10 changed files with 204 additions and 200 deletions

View File

@ -8,7 +8,6 @@
//! A player actor is a serial handler of commands from a single player. It is preferable to run all per-player validations in the player actor,
//! so that they don't overload the room actor.
use std::collections::{HashMap, HashSet};
use std::ops::Deref;
use std::sync::Arc;
use chrono::{DateTime, Utc};

View File

@ -3,11 +3,9 @@
use std::str::FromStr;
use std::sync::Arc;
use anyhow::anyhow;
use chrono::{DateTime, Utc};
use serde::Deserialize;
use sqlx::sqlite::SqliteConnectOptions;
use sqlx::{ConnectOptions, Connection, Execute, FromRow, Sqlite, SqliteConnection, Transaction};
use sqlx::{ConnectOptions, Connection, SqliteConnection};
use tokio::sync::Mutex;
use crate::prelude::*;
@ -40,98 +38,6 @@ impl Storage {
Ok(Storage { conn })
}
#[tracing::instrument(skip(self), name = "Storage::retrieve_user_by_name")]
pub async fn retrieve_user_by_name(&self, name: &str) -> Result<Option<StoredUser>> {
let mut executor = self.conn.lock().await;
let res = sqlx::query_as(
"select u.id, u.name, c.password, a.hash as argon2_hash
from users u left join challenges_plain_password c on u.id = c.user_id
left join challenges_argon2_password a on u.id = a.user_id
where u.name = ?;",
)
.bind(name)
.fetch_optional(&mut *executor)
.await?;
Ok(res)
}
#[tracing::instrument(skip(self), name = "Storage::check_user_existence")]
pub async fn check_user_existence(&self, username: &str) -> Result<bool> {
let mut executor = self.conn.lock().await;
let result: Option<(String,)> = sqlx::query_as("select name from users where name = ?;")
.bind(username)
.fetch_optional(&mut *executor)
.await?;
Ok(result.is_some())
}
#[tracing::instrument(skip(self), name = "Storage::retrieve_room_by_name")]
pub async fn retrieve_room_by_name(&self, name: &str) -> Result<Option<StoredRoom>> {
let mut executor = self.conn.lock().await;
let res = sqlx::query_as(
"select id, name, topic, message_count
from rooms
where name = ?;",
)
.bind(name)
.fetch_optional(&mut *executor)
.await?;
Ok(res)
}
#[tracing::instrument(skip(self, topic), name = "Storage::create_new_room")]
pub async fn create_new_room(&mut self, name: &str, topic: &str) -> Result<u32> {
let mut executor = self.conn.lock().await;
let (id,): (u32,) = sqlx::query_as(
"insert into rooms(name, topic)
values (?, ?)
returning id;",
)
.bind(name)
.bind(topic)
.fetch_one(&mut *executor)
.await?;
Ok(id)
}
#[tracing::instrument(skip(self, content, created_at), name = "Storage::insert_message")]
pub async fn insert_message(
&mut self,
room_id: u32,
id: u32,
content: &str,
author_id: &str,
created_at: &DateTime<Utc>,
) -> Result<()> {
let mut executor = self.conn.lock().await;
let res: Option<(u32,)> = sqlx::query_as("select id from users where name = ?;")
.bind(author_id)
.fetch_optional(&mut *executor)
.await?;
let Some((author_id,)) = res else {
return Err(anyhow!("No such user"));
};
sqlx::query(
"insert into messages(room_id, id, content, author_id, created_at)
values (?, ?, ?, ?, ?);
update rooms set message_count = message_count + 1 where id = ?;",
)
.bind(room_id)
.bind(id)
.bind(content)
.bind(author_id)
.bind(created_at.to_string())
.bind(room_id)
.execute(&mut *executor)
.await?;
Ok(())
}
pub async fn close(self) -> Result<()> {
let res = match Arc::try_unwrap(self.conn) {
Ok(e) => e,
@ -141,66 +47,4 @@ impl Storage {
res.close().await?;
Ok(())
}
#[tracing::instrument(skip(self), name = "Storage::create_user")]
pub async fn create_user(&self, name: &str) -> Result<()> {
let query = sqlx::query(
"insert into users(name)
values (?);",
)
.bind(name);
let mut executor = self.conn.lock().await;
query.execute(&mut *executor).await?;
Ok(())
}
#[tracing::instrument(skip(self, pwd), name = "Storage::set_password")]
pub async fn set_password<'a>(&'a self, name: &'a str, pwd: &'a str) -> Result<Option<()>> {
async fn inner(txn: &mut Transaction<'_, Sqlite>, name: &str, pwd: &str) -> Result<Option<()>> {
let id: Option<(u32,)> = sqlx::query_as("select * from users where name = ? limit 1;")
.bind(name)
.fetch_optional(&mut **txn)
.await?;
let Some((id,)) = id else {
return Ok(None);
};
sqlx::query("insert or replace into challenges_plain_password(user_id, password) values (?, ?);")
.bind(id)
.bind(pwd)
.execute(&mut **txn)
.await?;
Ok(Some(()))
}
let mut executor = self.conn.lock().await;
let mut tx = executor.begin().await?;
let res = inner(&mut tx, name, pwd).await;
match res {
Ok(e) => {
tx.commit().await?;
Ok(e)
}
Err(e) => {
tx.rollback().await?;
Err(e)
}
}
}
}
#[derive(FromRow)]
pub struct StoredUser {
pub id: u32,
pub name: String,
pub password: Option<String>,
pub argon2_hash: Option<Box<str>>,
}
#[derive(FromRow)]
pub struct StoredRoom {
pub id: u32,
pub name: String,
pub topic: String,
pub message_count: u32,
}

View File

@ -1,9 +1,84 @@
use crate::player::PlayerConnection;
use anyhow::Result;
use anyhow::{anyhow, Result};
use chrono::{DateTime, Utc};
use sqlx::FromRow;
use crate::repo::Storage;
use crate::room::RoomId;
#[derive(FromRow)]
pub struct StoredRoom {
pub id: u32,
pub name: String,
pub topic: String,
pub message_count: u32,
}
impl Storage {
#[tracing::instrument(skip(self), name = "Storage::retrieve_room_by_name")]
pub async fn retrieve_room_by_name(&self, name: &str) -> Result<Option<StoredRoom>> {
let mut executor = self.conn.lock().await;
let res = sqlx::query_as(
"select id, name, topic, message_count
from rooms
where name = ?;",
)
.bind(name)
.fetch_optional(&mut *executor)
.await?;
Ok(res)
}
#[tracing::instrument(skip(self, topic), name = "Storage::create_new_room")]
pub async fn create_new_room(&mut self, name: &str, topic: &str) -> Result<u32> {
let mut executor = self.conn.lock().await;
let (id,): (u32,) = sqlx::query_as(
"insert into rooms(name, topic)
values (?, ?)
returning id;",
)
.bind(name)
.bind(topic)
.fetch_one(&mut *executor)
.await?;
Ok(id)
}
#[tracing::instrument(skip(self, content, created_at), name = "Storage::insert_room_message")]
pub async fn insert_room_message(
&mut self,
room_id: u32,
id: u32,
content: &str,
author_id: &str,
created_at: &DateTime<Utc>,
) -> Result<()> {
let mut executor = self.conn.lock().await;
let res: Option<(u32,)> = sqlx::query_as("select id from users where name = ?;")
.bind(author_id)
.fetch_optional(&mut *executor)
.await?;
let Some((author_id,)) = res else {
return Err(anyhow!("No such user"));
};
sqlx::query(
"insert into messages(room_id, id, content, author_id, created_at)
values (?, ?, ?, ?, ?);
update rooms set message_count = message_count + 1 where id = ?;",
)
.bind(room_id)
.bind(id)
.bind(content)
.bind(author_id)
.bind(created_at.to_string())
.bind(room_id)
.execute(&mut *executor)
.await?;
Ok(())
}
#[tracing::instrument(skip(self), name = "Storage::is_room_member")]
pub async fn is_room_member(&self, room_id: u32, player_id: u32) -> Result<bool> {
let mut executor = self.conn.lock().await;
@ -70,6 +145,7 @@ impl Storage {
Ok(())
}
#[tracing::instrument(skip(self), name = "Storage::create_or_retrieve_room_id_by_name")]
pub async fn create_or_retrieve_room_id_by_name(&self, name: &str) -> Result<u32> {
// TODO we don't need any info except the name on non-owning nodes, should remove stubs here
let mut executor = self.conn.lock().await;
@ -85,4 +161,19 @@ impl Storage {
Ok(res.0)
}
#[tracing::instrument(skip(self), name = "Storage::get_rooms_of_a_user")]
pub async fn get_rooms_of_a_user(&self, user_id: u32) -> Result<Vec<RoomId>> {
let mut executor = self.conn.lock().await;
let res: Vec<(String,)> = sqlx::query_as(
"select r.name
from memberships m inner join rooms r on m.room_id = r.id
where m.user_id = ?;",
)
.bind(user_id)
.fetch_all(&mut *executor)
.await?;
res.into_iter().map(|(room_id,)| RoomId::from(room_id)).collect()
}
}

View File

@ -1,9 +1,91 @@
use anyhow::Result;
use sqlx::{Connection, FromRow, Sqlite, Transaction};
use crate::repo::Storage;
use crate::room::RoomId;
#[derive(FromRow)]
pub struct StoredUser {
pub id: u32,
pub name: String,
pub password: Option<String>,
pub argon2_hash: Option<Box<str>>,
}
impl Storage {
#[tracing::instrument(skip(self), name = "Storage::retrieve_user_by_name")]
pub async fn retrieve_user_by_name(&self, name: &str) -> Result<Option<StoredUser>> {
let mut executor = self.conn.lock().await;
let res = sqlx::query_as(
"select u.id, u.name, c.password, a.hash as argon2_hash
from users u left join challenges_plain_password c on u.id = c.user_id
left join challenges_argon2_password a on u.id = a.user_id
where u.name = ?;",
)
.bind(name)
.fetch_optional(&mut *executor)
.await?;
Ok(res)
}
#[tracing::instrument(skip(self), name = "Storage::check_user_existence")]
pub async fn check_user_existence(&self, username: &str) -> Result<bool> {
let mut executor = self.conn.lock().await;
let result: Option<(String,)> = sqlx::query_as("select name from users where name = ?;")
.bind(username)
.fetch_optional(&mut *executor)
.await?;
Ok(result.is_some())
}
#[tracing::instrument(skip(self), name = "Storage::create_user")]
pub async fn create_user(&self, name: &str) -> Result<()> {
let query = sqlx::query(
"insert into users(name)
values (?);",
)
.bind(name);
let mut executor = self.conn.lock().await;
query.execute(&mut *executor).await?;
Ok(())
}
#[tracing::instrument(skip(self, pwd), name = "Storage::set_password")]
pub async fn set_password(&self, name: &str, pwd: &str) -> Result<Option<()>> {
async fn inner(txn: &mut Transaction<'_, Sqlite>, name: &str, pwd: &str) -> Result<Option<()>> {
let id: Option<(u32,)> = sqlx::query_as("select * from users where name = ? limit 1;")
.bind(name)
.fetch_optional(&mut **txn)
.await?;
let Some((id,)) = id else {
return Ok(None);
};
sqlx::query("insert or replace into challenges_plain_password(user_id, password) values (?, ?);")
.bind(id)
.bind(pwd)
.execute(&mut **txn)
.await?;
Ok(Some(()))
}
let mut executor = self.conn.lock().await;
let mut tx = executor.begin().await?;
let res = inner(&mut tx, name, pwd).await;
match res {
Ok(e) => {
tx.commit().await?;
Ok(e)
}
Err(e) => {
tx.rollback().await?;
Err(e)
}
}
}
#[tracing::instrument(skip(self), name = "Storage::retrieve_user_id_by_name")]
pub async fn retrieve_user_id_by_name(&self, name: &str) -> Result<Option<u32>> {
let mut executor = self.conn.lock().await;
let res: Option<(u32,)> = sqlx::query_as("select u.id from users u where u.name = ?;")
@ -14,6 +96,7 @@ impl Storage {
Ok(res.map(|(id,)| id))
}
#[tracing::instrument(skip(self), name = "Storage::create_or_retrieve_user_id_by_name")]
pub async fn create_or_retrieve_user_id_by_name(&self, name: &str) -> Result<u32> {
let mut executor = self.conn.lock().await;
let res: (u32,) = sqlx::query_as(
@ -28,18 +111,4 @@ impl Storage {
Ok(res.0)
}
pub async fn get_rooms_of_a_user(&self, user_id: u32) -> Result<Vec<RoomId>> {
let mut executor = self.conn.lock().await;
let res: Vec<(String,)> = sqlx::query_as(
"select r.name
from memberships m inner join rooms r on m.room_id = r.id
where m.user_id = ?;",
)
.bind(user_id)
.fetch_all(&mut *executor)
.await?;
res.into_iter().map(|(room_id,)| RoomId::from(room_id)).collect()
}
}

View File

@ -243,7 +243,7 @@ impl Room {
async fn send_message(&mut self, author_id: &PlayerId, body: Str, created_at: DateTime<Utc>) -> Result<()> {
tracing::info!("Adding a message to room");
self.storage
.insert_message(
.insert_room_message(
self.storage_id,
self.message_count,
&body,

View File

@ -1,7 +1,10 @@
use lavina_core::{player::PlayerConnection, prelude::Str, LavinaCore};
use std::future::Future;
use tokio::io::AsyncWrite;
use lavina_core::player::PlayerConnection;
use lavina_core::prelude::Str;
pub struct IrcConnection<'a, T: AsyncWrite + Unpin> {
pub server_name: Str,
/// client is nick of requester

View File

@ -114,8 +114,8 @@ impl TestServer {
listen_on: "127.0.0.1:0".parse().unwrap(),
server_name: "testserver".into(),
};
let mut metrics = MetricsRegistry::new();
let mut storage = Storage::open(StorageConfig {
let metrics = MetricsRegistry::new();
let storage = Storage::open(StorageConfig {
db_path: ":memory:".into(),
})
.await?;
@ -153,7 +153,7 @@ impl TestServer {
let TestServer {
metrics: _,
storage,
mut core,
core,
server,
} = self;
server.terminate().await?;
@ -179,7 +179,7 @@ impl TestServer {
#[tokio::test]
async fn scenario_basic() -> Result<()> {
let mut server = TestServer::start().await?;
let server = TestServer::start().await?;
// test scenario
@ -208,7 +208,7 @@ async fn scenario_basic() -> Result<()> {
#[tokio::test]
async fn scenario_join_and_reboot() -> Result<()> {
let mut server = TestServer::start().await?;
let server = TestServer::start().await?;
// test scenario
@ -278,7 +278,7 @@ async fn scenario_join_and_reboot() -> Result<()> {
#[tokio::test]
async fn scenario_force_join_msg() -> Result<()> {
let mut server = TestServer::start().await?;
let server = TestServer::start().await?;
// test scenario
@ -344,7 +344,7 @@ async fn scenario_force_join_msg() -> Result<()> {
#[tokio::test]
async fn scenario_two_users() -> Result<()> {
let mut server = TestServer::start().await?;
let server = TestServer::start().await?;
// test scenario
@ -420,7 +420,7 @@ AUTHENTICATE doc: https://modern.ircdocs.horse/#authenticate-message
*/
#[tokio::test]
async fn scenario_cap_full_negotiation() -> Result<()> {
let mut server = TestServer::start().await?;
let server = TestServer::start().await?;
// test scenario
@ -460,7 +460,7 @@ async fn scenario_cap_full_negotiation() -> Result<()> {
#[tokio::test]
async fn scenario_cap_full_negotiation_nick_last() -> Result<()> {
let mut server = TestServer::start().await?;
let server = TestServer::start().await?;
// test scenario
@ -499,7 +499,7 @@ async fn scenario_cap_full_negotiation_nick_last() -> Result<()> {
#[tokio::test]
async fn scenario_cap_short_negotiation() -> Result<()> {
let mut server = TestServer::start().await?;
let server = TestServer::start().await?;
// test scenario
@ -537,7 +537,7 @@ async fn scenario_cap_short_negotiation() -> Result<()> {
#[tokio::test]
async fn scenario_cap_sasl_fail() -> Result<()> {
let mut server = TestServer::start().await?;
let server = TestServer::start().await?;
// test scenario
@ -581,7 +581,7 @@ async fn scenario_cap_sasl_fail() -> Result<()> {
#[tokio::test]
async fn terminate_socket_scenario() -> Result<()> {
let mut server = TestServer::start().await?;
let server = TestServer::start().await?;
// test scenario
@ -606,7 +606,7 @@ async fn terminate_socket_scenario() -> Result<()> {
#[tokio::test]
async fn server_time_capability() -> Result<()> {
let mut server = TestServer::start().await?;
let server = TestServer::start().await?;
// test scenario
@ -675,7 +675,7 @@ async fn server_time_capability() -> Result<()> {
#[tokio::test]
async fn scenario_two_players_dialog() -> Result<()> {
let mut server = TestServer::start().await?;
let server = TestServer::start().await?;
// test scenario

View File

@ -86,9 +86,9 @@ impl<'a> XmppConnection<'a> {
#[cfg(test)]
mod tests {
use crate::testkit::{expect_user_authenticated, TestServer};
use crate::{Authenticated, XmppConnection};
use crate::Authenticated;
use lavina_core::player::PlayerId;
use proto_xmpp::bind::{BindRequest, BindResponse, Jid, Name, Resource, Server};
use proto_xmpp::bind::{Jid, Name, Resource, Server};
use proto_xmpp::client::Presence;
#[tokio::test]

View File

@ -23,7 +23,6 @@ use lavina_core::clustering::{ClusterConfig, ClusterMetadata};
use lavina_core::repo::{Storage, StorageConfig};
use lavina_core::LavinaCore;
use projection_xmpp::{launch, RunningServer, ServerConfig};
use proto_xmpp::xml::{Continuation, FromXml, Parser};
fn element_name<'a>(local_name: &LocalName<'a>) -> &'a str {
from_utf8(local_name.into_inner()).unwrap()
@ -190,7 +189,7 @@ impl TestServer {
#[tokio::test]
async fn scenario_basic() -> Result<()> {
let mut server = TestServer::start().await?;
let server = TestServer::start().await?;
// test scenario
@ -258,7 +257,7 @@ async fn scenario_basic() -> Result<()> {
#[tokio::test]
async fn scenario_wrong_password() -> Result<()> {
let mut server = TestServer::start().await?;
let server = TestServer::start().await?;
// test scenario
@ -313,7 +312,7 @@ async fn scenario_wrong_password() -> Result<()> {
#[tokio::test]
async fn scenario_basic_without_headers() -> Result<()> {
let mut server = TestServer::start().await?;
let server = TestServer::start().await?;
// test scenario
@ -359,7 +358,7 @@ async fn scenario_basic_without_headers() -> Result<()> {
#[tokio::test]
async fn terminate_socket() -> Result<()> {
let mut server = TestServer::start().await?;
let server = TestServer::start().await?;
// test scenario
@ -399,7 +398,7 @@ async fn terminate_socket() -> Result<()> {
#[tokio::test]
async fn test_message_archive_request() -> Result<()> {
let mut server = TestServer::start().await?;
let server = TestServer::start().await?;
// test scenario

View File

@ -1,7 +1,6 @@
use anyhow::{anyhow, Result};
use quick_xml::events::{BytesEnd, BytesStart, BytesText, Event};
use quick_xml::name::{Namespace, ResolveResult};
use std::io::Read;
use crate::xml::*;