diff --git a/src/core/repo/mod.rs b/src/core/repo/mod.rs index ffc53cc..85c5e72 100644 --- a/src/core/repo/mod.rs +++ b/src/core/repo/mod.rs @@ -1,10 +1,12 @@ //! Storage and persistence logic. use std::str::FromStr; +use std::sync::Arc; use serde::Deserialize; use sqlx::sqlite::SqliteConnectOptions; -use sqlx::{ConnectOptions, SqliteConnection}; +use sqlx::{ConnectOptions, Connection, FromRow, SqliteConnection}; +use tokio::sync::Mutex; use crate::prelude::*; @@ -13,8 +15,9 @@ pub struct StorageConfig { pub db_path: String, } +#[derive(Clone)] pub struct Storage { - conn: SqliteConnection, + conn: Arc>, } impl Storage { pub async fn open(config: StorageConfig) -> Result { @@ -26,6 +29,38 @@ impl Storage { migrator.run(&mut conn).await?; log::info!("Migrations passed"); + let conn = Arc::new(Mutex::new(conn)); Ok(Storage { conn }) } + + pub async fn retrieve_user_by_name(&mut self, name: &str) -> Result> { + let mut executor = self.conn.lock().await; + let res = sqlx::query_as( + "select u.id, u.name, c.password + from users u left join challenges_plain_password c on u.id = c.user_id + where u.name = ?;", + ) + .bind(name) + .fetch_optional(&mut *executor) + .await?; + + Ok(res) + } + + pub async fn close(mut self) -> Result<()> { + let res = match Arc::try_unwrap(self.conn) { + Ok(e) => e, + Err(_) => return Err(fail("failed to acquire DB ownership on shutdown")), + }; + let res = res.into_inner(); + res.close().await?; + Ok(()) + } +} + +#[derive(FromRow)] +pub struct StoredUser { + pub id: u32, + pub name: String, + pub password: Option, } diff --git a/src/main.rs b/src/main.rs index acd4fd9..93a170c 100644 --- a/src/main.rs +++ b/src/main.rs @@ -57,7 +57,7 @@ async fn main() -> Result<()> { let players = PlayerRegistry::empty(rooms.clone(), &mut metrics)?; let telemetry_terminator = util::telemetry::launch(telemetry_config, metrics.clone(), rooms.clone()).await?; - let irc = projections::irc::launch(irc_config, players.clone(), rooms.clone(), metrics.clone()).await?; + let irc = projections::irc::launch(irc_config, players.clone(), rooms.clone(), metrics.clone(), storage.clone()).await?; let xmpp = projections::xmpp::launch(xmpp_config, players, rooms.clone(), metrics.clone()).await?; tracing::info!("Started"); @@ -66,6 +66,7 @@ async fn main() -> Result<()> { tracing::info!("Begin shutdown"); xmpp.terminate().await?; irc.terminate().await?; + storage.close().await?; telemetry_terminator.terminate().await?; tracing::info!("Shutdown complete"); Ok(()) diff --git a/src/projections/irc/mod.rs b/src/projections/irc/mod.rs index afbf10f..e80a562 100644 --- a/src/projections/irc/mod.rs +++ b/src/projections/irc/mod.rs @@ -10,6 +10,7 @@ use tokio::net::{TcpListener, TcpStream}; use tokio::sync::mpsc::channel; use crate::core::player::*; +use crate::core::repo::Storage; use crate::core::room::{RoomId, RoomInfo, RoomRegistry}; use crate::prelude::*; use crate::protos::irc::client::{client_message, ClientMessage}; @@ -45,6 +46,7 @@ async fn handle_socket( players: PlayerRegistry, rooms: RoomRegistry, termination: Deferred<()>, // TODO use it to stop the connection gracefully + mut storage: Storage, ) -> Result<()> { let (reader, writer) = stream.split(); let mut reader: BufReader = BufReader::new(reader); @@ -64,7 +66,8 @@ async fn handle_socket( writer.flush().await?; let registered_user: Result = - handle_registration(&mut reader, &mut writer).await; + handle_registration(&mut reader, &mut writer, &mut storage).await; + match registered_user { Ok(user) => { handle_registered_socket(config, players, rooms, &mut reader, &mut writer, user) @@ -80,13 +83,16 @@ async fn handle_socket( async fn handle_registration<'a>( reader: &mut BufReader>, writer: &mut BufWriter>, + storage: &mut Storage, ) -> Result { let mut buffer = vec![]; let mut future_nickname: Option = None; let mut future_username: Option<(Str, Str)> = None; - loop { + let mut pass: Option = None; + + let user = loop { let res = reader.read_until(b'\n', &mut buffer).await; let res = match res { Ok(len) => { @@ -110,6 +116,9 @@ async fn handle_registration<'a>( Ok((_, msg)) => { log::debug!("Incoming IRC message: {msg:?}"); match msg { + ClientMessage::Pass { password } => { + pass = Some(password); + } ClientMessage::Nick { nickname } => { if let Some((username, realname)) = future_username { break Ok(RegisteredUser { @@ -140,7 +149,29 @@ async fn handle_registration<'a>( } } buffer.clear(); + }?; + + + let stored_user = storage.retrieve_user_by_name(&*user.nickname).await?; + + let stored_user = match stored_user { + Some(u) => u, + None => { + log::info!("User '{}' not found", user.nickname); + return Err(fail("no user found")); + } + }; + if stored_user.password.is_none() { + log::info!("Password not defined for user '{}'", user.nickname); + return Err(fail("password is not defined")); } + if stored_user.password.as_deref() != pass.as_deref() { + log::info!("Incorrect password supplied for user '{}'", user.nickname); + return Err(fail("passwords do not match")); + } + // TODO properly implement session temination + + Ok(user) } async fn handle_registered_socket<'a>( @@ -663,6 +694,7 @@ pub async fn launch( players: PlayerRegistry, rooms: RoomRegistry, metrics: MetricsRegistry, + storage: Storage, ) -> Result { log::info!("Starting IRC projection"); let (stopped_tx, mut stopped_rx) = channel(32); @@ -708,8 +740,9 @@ pub async fn launch( let rooms = rooms.clone(); let current_connections_clone = current_connections.clone(); let stopped_tx = stopped_tx.clone(); + let storage = storage.clone(); async move { - match handle_socket(config, stream, &socket_addr, players, rooms, termination).await { + match handle_socket(config, stream, &socket_addr, players, rooms, termination, storage).await { Ok(_) => log::info!("Connection terminated"), Err(err) => log::warn!("Connection failed: {err}"), } diff --git a/src/protos/irc/client.rs b/src/protos/irc/client.rs index 7d73c87..827fe74 100644 --- a/src/protos/irc/client.rs +++ b/src/protos/irc/client.rs @@ -21,6 +21,10 @@ pub enum ClientMessage { Nick { nickname: Str, }, + /// PASS + Pass { + password: Str, + }, /// USER 0 * : User { username: Str, @@ -62,6 +66,7 @@ pub fn client_message(input: &str) -> IResult<&str, ClientMessage> { client_message_ping, client_message_pong, client_message_nick, + client_message_pass, client_message_user, client_message_join, client_message_mode, @@ -115,6 +120,21 @@ fn client_message_nick(input: &str) -> IResult<&str, ClientMessage> { }, )) } +fn client_message_pass(input: &str) -> IResult<&str, ClientMessage> { + let (input, _) = tag("PASS ")(input)?; + let (input, r) = opt(tag(":"))(input)?; + let (input, password) = match r { + Some(_) => token(input)?, + None => receiver(input)?, + }; + + Ok(( + input, + ClientMessage::Pass { + password: password.into(), + }, + )) +} fn client_message_user(input: &str) -> IResult<&str, ClientMessage> { let (input, _) = tag("USER ")(input)?;