diff --git a/Cargo.lock b/Cargo.lock index cd416ce..80a40aa 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -855,6 +855,7 @@ dependencies = [ "http-body-util", "hyper 1.0.0-rc.3", "lazy_static", + "mgmt-api", "nom", "nonempty", "prometheus", @@ -942,6 +943,13 @@ version = "2.5.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "2dffe52ecf27772e601905b7522cb4ef790d2cc203488bbd0e2fe85fcb74566d" +[[package]] +name = "mgmt-api" +version = "0.0.1-dev" +dependencies = [ + "serde", +] + [[package]] name = "mime" version = "0.3.17" diff --git a/Cargo.toml b/Cargo.toml index f2d6e9e..f8f0d55 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -1,7 +1,8 @@ [workspace] members = [ ".", - "crates/proto-irc" + "crates/proto-irc", + "crates/mgmt-api", ] [workspace.package] @@ -15,6 +16,7 @@ futures-util = "0.3.25" anyhow = "1.0.68" # error utils nonempty = "0.8.1" clap = { version = "4.4.4", features = ["derive"] } +serde = { version = "1.0.152", features = ["rc", "serde_derive"] } [package] name = "lavina" @@ -27,7 +29,7 @@ anyhow.workspace = true figment = { version = "0.10.8", features = ["env", "toml"] } # configuration files hyper = { version = "1.0.0-rc.3,<1.0.0-rc.4", features = ["server", "http1"] } # http server http-body-util = "0.1.0-rc.3" -serde = { version = "1.0.152", features = ["rc", "serde_derive"] } +serde.workspace = true serde_json = "1.0.93" tokio.workspace = true tracing = "0.1.37" # logging & tracing api @@ -45,6 +47,7 @@ derive_more = "0.99.17" uuid = { version = "1.3.0", features = ["v4"] } sqlx = { version = "0.7.0-alpha.2", features = ["sqlite", "migrate"] } proto-irc = { path = "crates/proto-irc" } +mgmt-api = { path = "crates/mgmt-api" } clap.workspace = true [dev-dependencies] diff --git a/crates/mgmt-api/Cargo.toml b/crates/mgmt-api/Cargo.toml new file mode 100644 index 0000000..030f3cf --- /dev/null +++ b/crates/mgmt-api/Cargo.toml @@ -0,0 +1,8 @@ +[package] +name = "mgmt-api" +edition = "2021" +version.workspace = true +publish = false + +[dependencies] +serde.workspace = true diff --git a/crates/mgmt-api/src/lib.rs b/crates/mgmt-api/src/lib.rs new file mode 100644 index 0000000..cfe5b69 --- /dev/null +++ b/crates/mgmt-api/src/lib.rs @@ -0,0 +1,29 @@ +use serde::{Deserialize, Serialize}; + +#[derive(Serialize, Deserialize)] +pub struct ErrorResponse<'a> { + pub code: &'a str, + pub message: &'a str, +} + +#[derive(Serialize, Deserialize)] +pub struct CreatePlayerRequest<'a> { + pub name: &'a str, +} + +#[derive(Serialize, Deserialize)] +pub struct ChangePasswordRequest<'a> { + pub player_name: &'a str, + pub password: &'a str, +} + +pub mod paths { + pub const CREATE_PLAYER: &'static str = "/mgmt/create_player"; + pub const SET_PASSWORD: &'static str = "/mgmt/set_password"; +} + +pub mod errors { + pub const INVALID_PATH: &'static str = "invalid_path"; + pub const MALFORMED_REQUEST: &'static str = "malformed_request"; + pub const PLAYER_NOT_FOUND: &'static str = "player_not_found"; +} diff --git a/rustfmt.toml b/rustfmt.toml new file mode 100644 index 0000000..866c756 --- /dev/null +++ b/rustfmt.toml @@ -0,0 +1 @@ +max_width = 120 \ No newline at end of file diff --git a/src/core/repo/mod.rs b/src/core/repo/mod.rs index 10b601a..57371ed 100644 --- a/src/core/repo/mod.rs +++ b/src/core/repo/mod.rs @@ -5,7 +5,7 @@ use std::sync::Arc; use serde::Deserialize; use sqlx::sqlite::SqliteConnectOptions; -use sqlx::{ConnectOptions, Connection, FromRow, SqliteConnection}; +use sqlx::{ConnectOptions, Connection, FromRow, Sqlite, SqliteConnection, Transaction}; use tokio::sync::Mutex; use crate::prelude::*; @@ -102,6 +102,50 @@ impl Storage { res.close().await?; Ok(()) } + + pub async fn create_user(&mut self, name: &str) -> Result<()> { + let query = sqlx::query( + "insert into users(name) + values (?);", + ) + .bind(name); + let mut executor = self.conn.lock().await; + query.execute(&mut *executor).await?; + + Ok(()) + } + + pub async fn set_password<'a>(&'a mut self, name: &'a str, pwd: &'a str) -> Result> { + async fn inner(txn: &mut Transaction<'_, Sqlite>, name: &str, pwd: &str) -> Result> { + let id: Option<(u32,)> = sqlx::query_as("select * from users where name = ? limit 1;") + .bind(name) + .fetch_optional(&mut **txn) + .await?; + let Some((id,)) = id else { + return Ok(None); + }; + sqlx::query("insert or replace into challenges_plain_password(user_id, password) values (?, ?);") + .bind(id) + .bind(pwd) + .execute(&mut **txn) + .await?; + Ok(Some(())) + } + + let mut executor = self.conn.lock().await; + let mut tx = executor.begin().await?; + let res = inner(&mut tx, name, pwd).await; + match res { + Ok(e) => { + tx.commit().await?; + Ok(e) + } + Err(e) => { + tx.rollback().await?; + Err(e) + } + } + } } #[derive(FromRow)] diff --git a/src/main.rs b/src/main.rs index b6f19da..d10c9ee 100644 --- a/src/main.rs +++ b/src/main.rs @@ -65,7 +65,7 @@ async fn main() -> Result<()> { let rooms = RoomRegistry::new(&mut metrics, storage.clone())?; let mut players = PlayerRegistry::empty(rooms.clone(), &mut metrics)?; let telemetry_terminator = - util::telemetry::launch(telemetry_config, metrics.clone(), rooms.clone()).await?; + util::telemetry::launch(telemetry_config, metrics.clone(), rooms.clone(), storage.clone()).await?; let irc = projections::irc::launch(irc_config, players.clone(), rooms.clone(), metrics.clone(), storage.clone()).await?; let xmpp = projections::xmpp::launch(xmpp_config, players.clone(), rooms.clone(), metrics.clone()).await?; tracing::info!("Started"); diff --git a/src/util/http.rs b/src/util/http.rs deleted file mode 100644 index 1f79e32..0000000 --- a/src/util/http.rs +++ /dev/null @@ -1,10 +0,0 @@ -use std::convert::Infallible; - -use http_body_util::Full; -use hyper::{body::Bytes, Response, StatusCode}; - -pub fn not_found() -> std::result::Result>, Infallible> { - let mut response = Response::new(Full::new(Bytes::from("404"))); - *response.status_mut() = StatusCode::NOT_FOUND; - Ok(response) -} diff --git a/src/util/mod.rs b/src/util/mod.rs index 0593dbd..2fde69e 100644 --- a/src/util/mod.rs +++ b/src/util/mod.rs @@ -1,6 +1,5 @@ use crate::prelude::*; -pub mod http; pub mod table; pub mod telemetry; #[cfg(test)] diff --git a/src/util/telemetry.rs b/src/util/telemetry.rs index 26387dd..3bb5ec9 100644 --- a/src/util/telemetry.rs +++ b/src/util/telemetry.rs @@ -6,18 +6,18 @@ use http_body_util::{BodyExt, Full}; use hyper::body::Bytes; use hyper::server::conn::http1; use hyper::service::service_fn; -use hyper::{Method, Request, Response}; +use hyper::{Method, Request, Response, StatusCode}; use prometheus::{Encoder, Registry as MetricsRegistry, TextEncoder}; -use serde::Deserialize; +use serde::{Deserialize, Serialize}; use tokio::net::TcpListener; +use crate::core::repo::Storage; use crate::core::room::RoomRegistry; use crate::prelude::*; - -use crate::util::http::*; use crate::util::Terminator; -type BoxBody = http_body_util::combinators::BoxBody; +use mgmt_api::*; + type HttpResult = std::result::Result; #[derive(Deserialize, Debug)] @@ -29,11 +29,12 @@ pub async fn launch( config: ServerConfig, metrics: MetricsRegistry, rooms: RoomRegistry, + storage: Storage, ) -> Result { log::info!("Starting the telemetry service"); let listener = TcpListener::bind(config.listen_on).await?; log::debug!("Listener started"); - let terminator = Terminator::spawn(|rx| main_loop(listener, metrics, rooms, rx.map(|_| ()))); + let terminator = Terminator::spawn(|rx| main_loop(listener, metrics, rooms, storage, rx.map(|_| ()))); Ok(terminator) } @@ -41,6 +42,7 @@ async fn main_loop( listener: TcpListener, metrics: MetricsRegistry, rooms: RoomRegistry, + storage: Storage, termination: impl Future, ) -> Result<()> { pin!(termination); @@ -52,10 +54,12 @@ async fn main_loop( let (stream, _) = result?; let metrics = metrics.clone(); let rooms = rooms.clone(); + let storage = storage.clone(); tokio::task::spawn(async move { let registry = metrics.clone(); let rooms = rooms.clone(); - let server = http1::Builder::new().serve_connection(stream, service_fn(move |r| route(registry.clone(), rooms.clone(), r))); + let storage = storage.clone(); + let server = http1::Builder::new().serve_connection(stream, service_fn(move |r| route(registry.clone(), rooms.clone(), storage.clone(), r))); if let Err(err) = server.await { tracing::error!("Error serving connection: {:?}", err); } @@ -70,27 +74,124 @@ async fn main_loop( async fn route( registry: MetricsRegistry, rooms: RoomRegistry, + storage: Storage, request: Request, -) -> std::result::Result, Infallible> { - match (request.method(), request.uri().path()) { - (&Method::GET, "/metrics") => Ok(endpoint_metrics(registry)?.map(BodyExt::boxed)), - (&Method::GET, "/rooms") => Ok(endpoint_rooms(rooms).await?.map(BodyExt::boxed)), - _ => Ok(not_found()?.map(BodyExt::boxed)), +) -> HttpResult>> { + let res = match (request.method(), request.uri().path()) { + (&Method::GET, "/metrics") => endpoint_metrics(registry), + (&Method::GET, "/rooms") => endpoint_rooms(rooms).await, + (&Method::POST, paths::CREATE_PLAYER) => endpoint_create_player(request, storage).await.or5xx(), + (&Method::POST, paths::SET_PASSWORD) => endpoint_set_password(request, storage).await.or5xx(), + _ => not_found(), + }; + Ok(res) +} + +fn endpoint_metrics(registry: MetricsRegistry) -> Response> { + let mf = registry.gather(); + let mut buffer = vec![]; + TextEncoder.encode(&mf, &mut buffer).expect("write to vec cannot fail"); + Response::new(Full::new(Bytes::from(buffer))) +} + +async fn endpoint_rooms(rooms: RoomRegistry) -> Response> { + // TODO introduce management API types independent from core-domain types + // TODO remove `Serialize` implementations from all core-domain types + let room_list = rooms.get_all_rooms().await.to_body(); + Response::new(room_list) +} + +async fn endpoint_create_player( + request: Request, + mut storage: Storage, +) -> Result>> { + let str = request.collect().await?.to_bytes(); + let Ok(res) = serde_json::from_slice::(&str[..]) else { + let payload = ErrorResponse { + code: errors::MALFORMED_REQUEST, + message: "The request payload contains incorrect JSON value", + } + .to_body(); + let mut response = Response::new(payload); + *response.status_mut() = StatusCode::BAD_REQUEST; + return Ok(response); + }; + storage.create_user(&res.name).await?; + log::info!("Player {} created", res.name); + let mut response = Response::new(Full::::default()); + *response.status_mut() = StatusCode::CREATED; + Ok(response) +} + +async fn endpoint_set_password( + request: Request, + mut storage: Storage, +) -> Result>> { + let str = request.collect().await?.to_bytes(); + let Ok(res) = serde_json::from_slice::(&str[..]) else { + let payload = ErrorResponse { + code: errors::MALFORMED_REQUEST, + message: "The request payload contains incorrect JSON value", + } + .to_body(); + let mut response = Response::new(payload); + *response.status_mut() = StatusCode::BAD_REQUEST; + return Ok(response); + }; + let Some(_) = storage.set_password(&res.player_name, &res.password).await? else { + let payload = ErrorResponse { + code: errors::PLAYER_NOT_FOUND, + message: "No such player exists", + } + .to_body(); + let mut response = Response::new(payload); + *response.status_mut() = StatusCode::UNPROCESSABLE_ENTITY; + return Ok(response); + }; + log::info!("Password changed for player {}", res.player_name); + let mut response = Response::new(Full::::default()); + *response.status_mut() = StatusCode::NO_CONTENT; + Ok(response) +} + +pub fn not_found() -> Response> { + let payload = ErrorResponse { + code: errors::INVALID_PATH, + message: "The path does not exist", + } + .to_body(); + + let mut response = Response::new(payload); + *response.status_mut() = StatusCode::NOT_FOUND; + response +} + +trait Or5xx { + fn or5xx(self) -> Response>; +} +impl Or5xx for Result>> { + fn or5xx(self) -> Response> { + match self { + Ok(e) => e, + Err(e) => { + let mut response = Response::new(Full::new(e.to_string().into())); + *response.status_mut() = StatusCode::INTERNAL_SERVER_ERROR; + response + } + } } } -fn endpoint_metrics(registry: MetricsRegistry) -> HttpResult>> { - let mf = registry.gather(); - let mut buffer = vec![]; - TextEncoder - .encode(&mf, &mut buffer) - .expect("write to vec cannot fail"); - Ok(Response::new(Full::new(Bytes::from(buffer)))) +trait ToBody { + fn to_body(&self) -> Full; } - -async fn endpoint_rooms(rooms: RoomRegistry) -> HttpResult>> { - let room_list = rooms.get_all_rooms().await; - let mut buffer = vec![]; - serde_json::to_writer(&mut buffer, &room_list).expect("unexpected fail when writing to vec"); - Ok(Response::new(Full::new(Bytes::from(buffer)))) +impl ToBody for T +where + T: Serialize, +{ + fn to_body(&self) -> Full { + let mut buffer = vec![]; + serde_json::to_writer(&mut buffer, self).expect("unexpected fail when writing to vec"); + Full::new(Bytes::from(buffer)) + } }