forked from lavina/lavina
				
			
		
			
				
	
	
		
			178 lines
		
	
	
		
			5.0 KiB
		
	
	
	
		
			Rust
		
	
	
	
			
		
		
	
	
			178 lines
		
	
	
		
			5.0 KiB
		
	
	
	
		
			Rust
		
	
	
	
| //! Storage and persistence logic.
 | |
| 
 | |
| use std::str::FromStr;
 | |
| use std::sync::Arc;
 | |
| 
 | |
| use anyhow::anyhow;
 | |
| use serde::Deserialize;
 | |
| use sqlx::sqlite::SqliteConnectOptions;
 | |
| use sqlx::{ConnectOptions, Connection, FromRow, Sqlite, SqliteConnection, Transaction};
 | |
| use tokio::sync::Mutex;
 | |
| 
 | |
| use crate::prelude::*;
 | |
| 
 | |
| mod room;
 | |
| mod user;
 | |
| 
 | |
| #[derive(Deserialize, Debug, Clone)]
 | |
| pub struct StorageConfig {
 | |
|     pub db_path: String,
 | |
| }
 | |
| 
 | |
| #[derive(Clone)]
 | |
| pub struct Storage {
 | |
|     conn: Arc<Mutex<SqliteConnection>>,
 | |
| }
 | |
| impl Storage {
 | |
|     pub async fn open(config: StorageConfig) -> Result<Storage> {
 | |
|         let opts = SqliteConnectOptions::from_str(&*config.db_path)?.create_if_missing(true);
 | |
|         let mut conn = opts.connect().await?;
 | |
| 
 | |
|         let migrator = sqlx::migrate!();
 | |
| 
 | |
|         migrator.run(&mut conn).await?;
 | |
|         log::info!("Migrations passed");
 | |
| 
 | |
|         let conn = Arc::new(Mutex::new(conn));
 | |
|         Ok(Storage { conn })
 | |
|     }
 | |
| 
 | |
|     pub async fn retrieve_user_by_name(&mut self, name: &str) -> Result<Option<StoredUser>> {
 | |
|         let mut executor = self.conn.lock().await;
 | |
|         let res = sqlx::query_as(
 | |
|             "select u.id, u.name, c.password
 | |
|             from users u left join challenges_plain_password c on u.id = c.user_id
 | |
|             where u.name = ?;",
 | |
|         )
 | |
|         .bind(name)
 | |
|         .fetch_optional(&mut *executor)
 | |
|         .await?;
 | |
| 
 | |
|         Ok(res)
 | |
|     }
 | |
| 
 | |
|     pub async fn retrieve_room_by_name(&self, name: &str) -> Result<Option<StoredRoom>> {
 | |
|         let mut executor = self.conn.lock().await;
 | |
|         let res = sqlx::query_as(
 | |
|             "select id, name, topic, message_count
 | |
|             from rooms
 | |
|             where name = ?;",
 | |
|         )
 | |
|         .bind(name)
 | |
|         .fetch_optional(&mut *executor)
 | |
|         .await?;
 | |
| 
 | |
|         Ok(res)
 | |
|     }
 | |
| 
 | |
|     pub async fn create_new_room(&mut self, name: &str, topic: &str) -> Result<u32> {
 | |
|         let mut executor = self.conn.lock().await;
 | |
|         let (id,): (u32,) = sqlx::query_as(
 | |
|             "insert into rooms(name, topic)
 | |
|             values (?, ?)
 | |
|             returning id;",
 | |
|         )
 | |
|         .bind(name)
 | |
|         .bind(topic)
 | |
|         .fetch_one(&mut *executor)
 | |
|         .await?;
 | |
| 
 | |
|         Ok(id)
 | |
|     }
 | |
| 
 | |
|     pub async fn insert_message(&mut self, room_id: u32, id: u32, content: &str, author_id: &str) -> Result<()> {
 | |
|         let mut executor = self.conn.lock().await;
 | |
|         let res: Option<(u32,)> = sqlx::query_as("select id from users where name = ?;")
 | |
|             .bind(author_id)
 | |
|             .fetch_optional(&mut *executor)
 | |
|             .await?;
 | |
|         let Some((author_id,)) = res else {
 | |
|             return Err(anyhow!("No such user"));
 | |
|         };
 | |
|         sqlx::query(
 | |
|             "insert into messages(room_id, id, content, author_id, created_at)
 | |
|             values (?, ?, ?, ?, ?);
 | |
|             update rooms set message_count = message_count + 1 where id = ?;",
 | |
|         )
 | |
|         .bind(room_id)
 | |
|         .bind(id)
 | |
|         .bind(content)
 | |
|         .bind(author_id)
 | |
|         .bind(chrono::Utc::now().to_string())
 | |
|         .bind(room_id)
 | |
|         .execute(&mut *executor)
 | |
|         .await?;
 | |
| 
 | |
|         Ok(())
 | |
|     }
 | |
| 
 | |
|     pub async fn close(self) -> Result<()> {
 | |
|         let res = match Arc::try_unwrap(self.conn) {
 | |
|             Ok(e) => e,
 | |
|             Err(_) => return Err(fail("failed to acquire DB ownership on shutdown")),
 | |
|         };
 | |
|         let res = res.into_inner();
 | |
|         res.close().await?;
 | |
|         Ok(())
 | |
|     }
 | |
| 
 | |
|     pub async fn create_user(&mut self, name: &str) -> Result<()> {
 | |
|         let query = sqlx::query(
 | |
|             "insert into users(name)
 | |
|             values (?);",
 | |
|         )
 | |
|         .bind(name);
 | |
|         let mut executor = self.conn.lock().await;
 | |
|         query.execute(&mut *executor).await?;
 | |
| 
 | |
|         Ok(())
 | |
|     }
 | |
| 
 | |
|     pub async fn set_password<'a>(&'a mut self, name: &'a str, pwd: &'a str) -> Result<Option<()>> {
 | |
|         async fn inner(txn: &mut Transaction<'_, Sqlite>, name: &str, pwd: &str) -> Result<Option<()>> {
 | |
|             let id: Option<(u32,)> = sqlx::query_as("select * from users where name = ? limit 1;")
 | |
|                 .bind(name)
 | |
|                 .fetch_optional(&mut **txn)
 | |
|                 .await?;
 | |
|             let Some((id,)) = id else {
 | |
|                 return Ok(None);
 | |
|             };
 | |
|             sqlx::query("insert or replace into challenges_plain_password(user_id, password) values (?, ?);")
 | |
|                 .bind(id)
 | |
|                 .bind(pwd)
 | |
|                 .execute(&mut **txn)
 | |
|                 .await?;
 | |
|             Ok(Some(()))
 | |
|         }
 | |
| 
 | |
|         let mut executor = self.conn.lock().await;
 | |
|         let mut tx = executor.begin().await?;
 | |
|         let res = inner(&mut tx, name, pwd).await;
 | |
|         match res {
 | |
|             Ok(e) => {
 | |
|                 tx.commit().await?;
 | |
|                 Ok(e)
 | |
|             }
 | |
|             Err(e) => {
 | |
|                 tx.rollback().await?;
 | |
|                 Err(e)
 | |
|             }
 | |
|         }
 | |
|     }
 | |
| }
 | |
| 
 | |
| #[derive(FromRow)]
 | |
| pub struct StoredUser {
 | |
|     pub id: u32,
 | |
|     pub name: String,
 | |
|     pub password: Option<String>,
 | |
| }
 | |
| 
 | |
| #[derive(FromRow)]
 | |
| pub struct StoredRoom {
 | |
|     pub id: u32,
 | |
|     pub name: String,
 | |
|     pub topic: String,
 | |
|     pub message_count: u32,
 | |
| }
 |