forked from lavina/lavina
1
0
Fork 0

add management API endpoints

This commit is contained in:
Nikita Vilunov 2023-09-24 22:59:34 +02:00
parent df6cdd4861
commit 58f6a5d90a
10 changed files with 223 additions and 40 deletions

8
Cargo.lock generated
View File

@ -855,6 +855,7 @@ dependencies = [
"http-body-util",
"hyper 1.0.0-rc.3",
"lazy_static",
"mgmt-api",
"nom",
"nonempty",
"prometheus",
@ -942,6 +943,13 @@ version = "2.5.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "2dffe52ecf27772e601905b7522cb4ef790d2cc203488bbd0e2fe85fcb74566d"
[[package]]
name = "mgmt-api"
version = "0.0.1-dev"
dependencies = [
"serde",
]
[[package]]
name = "mime"
version = "0.3.17"

View File

@ -1,7 +1,8 @@
[workspace]
members = [
".",
"crates/proto-irc"
"crates/proto-irc",
"crates/mgmt-api",
]
[workspace.package]
@ -15,6 +16,7 @@ futures-util = "0.3.25"
anyhow = "1.0.68" # error utils
nonempty = "0.8.1"
clap = { version = "4.4.4", features = ["derive"] }
serde = { version = "1.0.152", features = ["rc", "serde_derive"] }
[package]
name = "lavina"
@ -27,7 +29,7 @@ anyhow.workspace = true
figment = { version = "0.10.8", features = ["env", "toml"] } # configuration files
hyper = { version = "1.0.0-rc.3,<1.0.0-rc.4", features = ["server", "http1"] } # http server
http-body-util = "0.1.0-rc.3"
serde = { version = "1.0.152", features = ["rc", "serde_derive"] }
serde.workspace = true
serde_json = "1.0.93"
tokio.workspace = true
tracing = "0.1.37" # logging & tracing api
@ -45,6 +47,7 @@ derive_more = "0.99.17"
uuid = { version = "1.3.0", features = ["v4"] }
sqlx = { version = "0.7.0-alpha.2", features = ["sqlite", "migrate"] }
proto-irc = { path = "crates/proto-irc" }
mgmt-api = { path = "crates/mgmt-api" }
clap.workspace = true
[dev-dependencies]

View File

@ -0,0 +1,8 @@
[package]
name = "mgmt-api"
edition = "2021"
version.workspace = true
publish = false
[dependencies]
serde.workspace = true

View File

@ -0,0 +1,29 @@
use serde::{Deserialize, Serialize};
#[derive(Serialize, Deserialize)]
pub struct ErrorResponse<'a> {
pub code: &'a str,
pub message: &'a str,
}
#[derive(Serialize, Deserialize)]
pub struct CreatePlayerRequest<'a> {
pub name: &'a str,
}
#[derive(Serialize, Deserialize)]
pub struct ChangePasswordRequest<'a> {
pub player_name: &'a str,
pub password: &'a str,
}
pub mod paths {
pub const CREATE_PLAYER: &'static str = "/mgmt/create_player";
pub const SET_PASSWORD: &'static str = "/mgmt/set_password";
}
pub mod errors {
pub const INVALID_PATH: &'static str = "invalid_path";
pub const MALFORMED_REQUEST: &'static str = "malformed_request";
pub const PLAYER_NOT_FOUND: &'static str = "player_not_found";
}

1
rustfmt.toml Normal file
View File

@ -0,0 +1 @@
max_width = 120

View File

@ -5,7 +5,7 @@ use std::sync::Arc;
use serde::Deserialize;
use sqlx::sqlite::SqliteConnectOptions;
use sqlx::{ConnectOptions, Connection, FromRow, SqliteConnection};
use sqlx::{ConnectOptions, Connection, FromRow, Sqlite, SqliteConnection, Transaction};
use tokio::sync::Mutex;
use crate::prelude::*;
@ -102,6 +102,50 @@ impl Storage {
res.close().await?;
Ok(())
}
pub async fn create_user(&mut self, name: &str) -> Result<()> {
let query = sqlx::query(
"insert into users(name)
values (?);",
)
.bind(name);
let mut executor = self.conn.lock().await;
query.execute(&mut *executor).await?;
Ok(())
}
pub async fn set_password<'a>(&'a mut self, name: &'a str, pwd: &'a str) -> Result<Option<()>> {
async fn inner(txn: &mut Transaction<'_, Sqlite>, name: &str, pwd: &str) -> Result<Option<()>> {
let id: Option<(u32,)> = sqlx::query_as("select * from users where name = ? limit 1;")
.bind(name)
.fetch_optional(&mut **txn)
.await?;
let Some((id,)) = id else {
return Ok(None);
};
sqlx::query("insert or replace into challenges_plain_password(user_id, password) values (?, ?);")
.bind(id)
.bind(pwd)
.execute(&mut **txn)
.await?;
Ok(Some(()))
}
let mut executor = self.conn.lock().await;
let mut tx = executor.begin().await?;
let res = inner(&mut tx, name, pwd).await;
match res {
Ok(e) => {
tx.commit().await?;
Ok(e)
}
Err(e) => {
tx.rollback().await?;
Err(e)
}
}
}
}
#[derive(FromRow)]

View File

@ -65,7 +65,7 @@ async fn main() -> Result<()> {
let rooms = RoomRegistry::new(&mut metrics, storage.clone())?;
let mut players = PlayerRegistry::empty(rooms.clone(), &mut metrics)?;
let telemetry_terminator =
util::telemetry::launch(telemetry_config, metrics.clone(), rooms.clone()).await?;
util::telemetry::launch(telemetry_config, metrics.clone(), rooms.clone(), storage.clone()).await?;
let irc = projections::irc::launch(irc_config, players.clone(), rooms.clone(), metrics.clone(), storage.clone()).await?;
let xmpp = projections::xmpp::launch(xmpp_config, players.clone(), rooms.clone(), metrics.clone()).await?;
tracing::info!("Started");

View File

@ -1,10 +0,0 @@
use std::convert::Infallible;
use http_body_util::Full;
use hyper::{body::Bytes, Response, StatusCode};
pub fn not_found() -> std::result::Result<Response<Full<Bytes>>, Infallible> {
let mut response = Response::new(Full::new(Bytes::from("404")));
*response.status_mut() = StatusCode::NOT_FOUND;
Ok(response)
}

View File

@ -1,6 +1,5 @@
use crate::prelude::*;
pub mod http;
pub mod table;
pub mod telemetry;
#[cfg(test)]

View File

@ -6,18 +6,18 @@ use http_body_util::{BodyExt, Full};
use hyper::body::Bytes;
use hyper::server::conn::http1;
use hyper::service::service_fn;
use hyper::{Method, Request, Response};
use hyper::{Method, Request, Response, StatusCode};
use prometheus::{Encoder, Registry as MetricsRegistry, TextEncoder};
use serde::Deserialize;
use serde::{Deserialize, Serialize};
use tokio::net::TcpListener;
use crate::core::repo::Storage;
use crate::core::room::RoomRegistry;
use crate::prelude::*;
use crate::util::http::*;
use crate::util::Terminator;
type BoxBody = http_body_util::combinators::BoxBody<Bytes, Infallible>;
use mgmt_api::*;
type HttpResult<T> = std::result::Result<T, Infallible>;
#[derive(Deserialize, Debug)]
@ -29,11 +29,12 @@ pub async fn launch(
config: ServerConfig,
metrics: MetricsRegistry,
rooms: RoomRegistry,
storage: Storage,
) -> Result<Terminator> {
log::info!("Starting the telemetry service");
let listener = TcpListener::bind(config.listen_on).await?;
log::debug!("Listener started");
let terminator = Terminator::spawn(|rx| main_loop(listener, metrics, rooms, rx.map(|_| ())));
let terminator = Terminator::spawn(|rx| main_loop(listener, metrics, rooms, storage, rx.map(|_| ())));
Ok(terminator)
}
@ -41,6 +42,7 @@ async fn main_loop(
listener: TcpListener,
metrics: MetricsRegistry,
rooms: RoomRegistry,
storage: Storage,
termination: impl Future<Output = ()>,
) -> Result<()> {
pin!(termination);
@ -52,10 +54,12 @@ async fn main_loop(
let (stream, _) = result?;
let metrics = metrics.clone();
let rooms = rooms.clone();
let storage = storage.clone();
tokio::task::spawn(async move {
let registry = metrics.clone();
let rooms = rooms.clone();
let server = http1::Builder::new().serve_connection(stream, service_fn(move |r| route(registry.clone(), rooms.clone(), r)));
let storage = storage.clone();
let server = http1::Builder::new().serve_connection(stream, service_fn(move |r| route(registry.clone(), rooms.clone(), storage.clone(), r)));
if let Err(err) = server.await {
tracing::error!("Error serving connection: {:?}", err);
}
@ -70,27 +74,124 @@ async fn main_loop(
async fn route(
registry: MetricsRegistry,
rooms: RoomRegistry,
storage: Storage,
request: Request<hyper::body::Incoming>,
) -> std::result::Result<Response<BoxBody>, Infallible> {
match (request.method(), request.uri().path()) {
(&Method::GET, "/metrics") => Ok(endpoint_metrics(registry)?.map(BodyExt::boxed)),
(&Method::GET, "/rooms") => Ok(endpoint_rooms(rooms).await?.map(BodyExt::boxed)),
_ => Ok(not_found()?.map(BodyExt::boxed)),
}
) -> HttpResult<Response<Full<Bytes>>> {
let res = match (request.method(), request.uri().path()) {
(&Method::GET, "/metrics") => endpoint_metrics(registry),
(&Method::GET, "/rooms") => endpoint_rooms(rooms).await,
(&Method::POST, paths::CREATE_PLAYER) => endpoint_create_player(request, storage).await.or5xx(),
(&Method::POST, paths::SET_PASSWORD) => endpoint_set_password(request, storage).await.or5xx(),
_ => not_found(),
};
Ok(res)
}
fn endpoint_metrics(registry: MetricsRegistry) -> HttpResult<Response<Full<Bytes>>> {
fn endpoint_metrics(registry: MetricsRegistry) -> Response<Full<Bytes>> {
let mf = registry.gather();
let mut buffer = vec![];
TextEncoder
.encode(&mf, &mut buffer)
.expect("write to vec cannot fail");
Ok(Response::new(Full::new(Bytes::from(buffer))))
TextEncoder.encode(&mf, &mut buffer).expect("write to vec cannot fail");
Response::new(Full::new(Bytes::from(buffer)))
}
async fn endpoint_rooms(rooms: RoomRegistry) -> HttpResult<Response<Full<Bytes>>> {
let room_list = rooms.get_all_rooms().await;
let mut buffer = vec![];
serde_json::to_writer(&mut buffer, &room_list).expect("unexpected fail when writing to vec");
Ok(Response::new(Full::new(Bytes::from(buffer))))
async fn endpoint_rooms(rooms: RoomRegistry) -> Response<Full<Bytes>> {
// TODO introduce management API types independent from core-domain types
// TODO remove `Serialize` implementations from all core-domain types
let room_list = rooms.get_all_rooms().await.to_body();
Response::new(room_list)
}
async fn endpoint_create_player(
request: Request<hyper::body::Incoming>,
mut storage: Storage,
) -> Result<Response<Full<Bytes>>> {
let str = request.collect().await?.to_bytes();
let Ok(res) = serde_json::from_slice::<CreatePlayerRequest>(&str[..]) else {
let payload = ErrorResponse {
code: errors::MALFORMED_REQUEST,
message: "The request payload contains incorrect JSON value",
}
.to_body();
let mut response = Response::new(payload);
*response.status_mut() = StatusCode::BAD_REQUEST;
return Ok(response);
};
storage.create_user(&res.name).await?;
log::info!("Player {} created", res.name);
let mut response = Response::new(Full::<Bytes>::default());
*response.status_mut() = StatusCode::CREATED;
Ok(response)
}
async fn endpoint_set_password(
request: Request<hyper::body::Incoming>,
mut storage: Storage,
) -> Result<Response<Full<Bytes>>> {
let str = request.collect().await?.to_bytes();
let Ok(res) = serde_json::from_slice::<ChangePasswordRequest>(&str[..]) else {
let payload = ErrorResponse {
code: errors::MALFORMED_REQUEST,
message: "The request payload contains incorrect JSON value",
}
.to_body();
let mut response = Response::new(payload);
*response.status_mut() = StatusCode::BAD_REQUEST;
return Ok(response);
};
let Some(_) = storage.set_password(&res.player_name, &res.password).await? else {
let payload = ErrorResponse {
code: errors::PLAYER_NOT_FOUND,
message: "No such player exists",
}
.to_body();
let mut response = Response::new(payload);
*response.status_mut() = StatusCode::UNPROCESSABLE_ENTITY;
return Ok(response);
};
log::info!("Password changed for player {}", res.player_name);
let mut response = Response::new(Full::<Bytes>::default());
*response.status_mut() = StatusCode::NO_CONTENT;
Ok(response)
}
pub fn not_found() -> Response<Full<Bytes>> {
let payload = ErrorResponse {
code: errors::INVALID_PATH,
message: "The path does not exist",
}
.to_body();
let mut response = Response::new(payload);
*response.status_mut() = StatusCode::NOT_FOUND;
response
}
trait Or5xx {
fn or5xx(self) -> Response<Full<Bytes>>;
}
impl Or5xx for Result<Response<Full<Bytes>>> {
fn or5xx(self) -> Response<Full<Bytes>> {
match self {
Ok(e) => e,
Err(e) => {
let mut response = Response::new(Full::new(e.to_string().into()));
*response.status_mut() = StatusCode::INTERNAL_SERVER_ERROR;
response
}
}
}
}
trait ToBody {
fn to_body(&self) -> Full<Bytes>;
}
impl<T> ToBody for T
where
T: Serialize,
{
fn to_body(&self) -> Full<Bytes> {
let mut buffer = vec![];
serde_json::to_writer(&mut buffer, self).expect("unexpected fail when writing to vec");
Full::new(Bytes::from(buffer))
}
}