//! Storage and persistence logic. use std::str::FromStr; use std::sync::Arc; use serde::Deserialize; use sqlx::sqlite::SqliteConnectOptions; use sqlx::{ConnectOptions, Connection, FromRow, SqliteConnection}; use tokio::sync::Mutex; use crate::prelude::*; #[derive(Deserialize, Debug, Clone)] pub struct StorageConfig { pub db_path: String, } #[derive(Clone)] pub struct Storage { conn: Arc>, } impl Storage { pub async fn open(config: StorageConfig) -> Result { let opts = SqliteConnectOptions::from_str(&*config.db_path)?.create_if_missing(true); let mut conn = opts.connect().await?; let migrator = sqlx::migrate!(); migrator.run(&mut conn).await?; log::info!("Migrations passed"); let conn = Arc::new(Mutex::new(conn)); Ok(Storage { conn }) } pub async fn retrieve_user_by_name(&mut self, name: &str) -> Result> { let mut executor = self.conn.lock().await; let res = sqlx::query_as( "select u.id, u.name, c.password from users u left join challenges_plain_password c on u.id = c.user_id where u.name = ?;", ) .bind(name) .fetch_optional(&mut *executor) .await?; Ok(res) } pub async fn retrieve_room_by_name(&mut self, name: &str) -> Result> { let mut executor = self.conn.lock().await; let res = sqlx::query_as( "select id, name, topic, message_count from rooms where name = ?;", ) .bind(name) .fetch_optional(&mut *executor) .await?; Ok(res) } pub async fn create_new_room(&mut self, name: &str, topic: &str) -> Result { let mut executor = self.conn.lock().await; let (id,): (u32,) = sqlx::query_as( "insert into rooms(name, topic) values (?, ?) returning id;", ) .bind(name) .bind(topic) .fetch_one(&mut *executor) .await?; Ok(id) } pub async fn insert_message(&mut self, room_id: u32, id: u32, content: &str) -> Result<()> { let mut executor = self.conn.lock().await; sqlx::query( "insert into messages(room_id, id, content) values (?, ?, ?); update rooms set message_count = message_count + 1 where id = ?;", ) .bind(room_id) .bind(id) .bind(content) .bind(room_id) .execute(&mut *executor) .await?; Ok(()) } pub async fn close(mut self) -> Result<()> { let res = match Arc::try_unwrap(self.conn) { Ok(e) => e, Err(e) => return Err(fail("failed to acquire DB ownership on shutdown")), }; let res = res.into_inner(); res.close().await?; Ok(()) } } #[derive(FromRow)] pub struct StoredUser { pub id: u32, pub name: String, pub password: Option, } #[derive(FromRow)] pub struct StoredRoom { pub id: u32, pub name: String, pub topic: String, pub message_count: u32, }