//! Storage and persistence logic. use std::str::FromStr; use std::sync::Arc; use anyhow::anyhow; use serde::Deserialize; use sqlx::sqlite::SqliteConnectOptions; use sqlx::{ConnectOptions, Connection, FromRow, Sqlite, SqliteConnection, Transaction}; use tokio::sync::Mutex; use crate::prelude::*; #[derive(Deserialize, Debug, Clone)] pub struct StorageConfig { pub db_path: String, } #[derive(Clone)] pub struct Storage { conn: Arc>, } impl Storage { pub async fn open(config: StorageConfig) -> Result { let opts = SqliteConnectOptions::from_str(&*config.db_path)?.create_if_missing(true); let mut conn = opts.connect().await?; let migrator = sqlx::migrate!(); migrator.run(&mut conn).await?; log::info!("Migrations passed"); let conn = Arc::new(Mutex::new(conn)); Ok(Storage { conn }) } pub async fn retrieve_user_by_name(&mut self, name: &str) -> Result> { let mut executor = self.conn.lock().await; let res = sqlx::query_as( "select u.id, u.name, c.password from users u left join challenges_plain_password c on u.id = c.user_id where u.name = ?;", ) .bind(name) .fetch_optional(&mut *executor) .await?; Ok(res) } pub async fn retrieve_room_by_name(&mut self, name: &str) -> Result> { let mut executor = self.conn.lock().await; let res = sqlx::query_as( "select id, name, topic, message_count from rooms where name = ?;", ) .bind(name) .fetch_optional(&mut *executor) .await?; Ok(res) } pub async fn create_new_room(&mut self, name: &str, topic: &str) -> Result { let mut executor = self.conn.lock().await; let (id,): (u32,) = sqlx::query_as( "insert into rooms(name, topic) values (?, ?) returning id;", ) .bind(name) .bind(topic) .fetch_one(&mut *executor) .await?; Ok(id) } pub async fn insert_message(&mut self, room_id: u32, id: u32, content: &str, author_id: &str) -> Result<()> { let mut executor = self.conn.lock().await; let res: Option<(u32,)> = sqlx::query_as("select id from users where name = ?;") .bind(author_id) .fetch_optional(&mut *executor) .await?; let Some((author_id,)) = res else { return Err(anyhow!("No such user")); }; sqlx::query( "insert into messages(room_id, id, content, author_id) values (?, ?, ?, ?); update rooms set message_count = message_count + 1 where id = ?;", ) .bind(room_id) .bind(id) .bind(content) .bind(author_id) .bind(room_id) .execute(&mut *executor) .await?; Ok(()) } pub async fn close(self) -> Result<()> { let res = match Arc::try_unwrap(self.conn) { Ok(e) => e, Err(_) => return Err(fail("failed to acquire DB ownership on shutdown")), }; let res = res.into_inner(); res.close().await?; Ok(()) } pub async fn create_user(&mut self, name: &str) -> Result<()> { let query = sqlx::query( "insert into users(name) values (?);", ) .bind(name); let mut executor = self.conn.lock().await; query.execute(&mut *executor).await?; Ok(()) } pub async fn set_password<'a>(&'a mut self, name: &'a str, pwd: &'a str) -> Result> { async fn inner(txn: &mut Transaction<'_, Sqlite>, name: &str, pwd: &str) -> Result> { let id: Option<(u32,)> = sqlx::query_as("select * from users where name = ? limit 1;") .bind(name) .fetch_optional(&mut **txn) .await?; let Some((id,)) = id else { return Ok(None); }; sqlx::query("insert or replace into challenges_plain_password(user_id, password) values (?, ?);") .bind(id) .bind(pwd) .execute(&mut **txn) .await?; Ok(Some(())) } let mut executor = self.conn.lock().await; let mut tx = executor.begin().await?; let res = inner(&mut tx, name, pwd).await; match res { Ok(e) => { tx.commit().await?; Ok(e) } Err(e) => { tx.rollback().await?; Err(e) } } } } #[derive(FromRow)] pub struct StoredUser { pub id: u32, pub name: String, pub password: Option, } #[derive(FromRow)] pub struct StoredRoom { pub id: u32, pub name: String, pub topic: String, pub message_count: u32, }