Compare commits

...

74 Commits

Author SHA1 Message Date
Nikita Vilunov 3476f2e367 update dependencies 2024-06-18 14:26:39 +02:00
Nikita Vilunov c0aaa26010 irc: split test timeouts and increase for password verification (#75)
Reviewed-on: lavina/lavina#75
2024-06-09 09:22:35 +00:00
Nikita Vilunov 26cc2f178c irc: remove "None" fake capability 2024-06-05 03:05:54 +02:00
Nikita Vilunov 25fe041698 irc: move CHATHISTORY message handling into a separate module 2024-06-05 02:59:22 +02:00
Nikita Vilunov a22cde0ea8 irc: rename Handler to IrcCommand and its method to handle_with 2024-06-05 02:57:20 +02:00
Nikita Vilunov 59528909c7 core: ADT for results of room history queries 2024-06-05 02:04:11 +02:00
Nikita Vilunov f3cd794431 core: do not ignore errors on sending to channels 2024-06-05 00:43:39 +02:00
Nikita Vilunov 2c828b482e core: ADT for results of room joins 2024-06-05 00:10:36 +02:00
Nikita Vilunov 4e8eb09184 reduce usage of unwraps (#70)
Reviewed-on: lavina/lavina#70
2024-06-04 21:54:57 +00:00
Nikita Vilunov d0420ec834 core: use members instead of subscribers in RoomInfo 2024-06-03 13:06:41 +02:00
Nikita Vilunov e48a6d3b0b irc: test scenario with reboot and two users to cover bug #72 2024-06-03 13:06:29 +02:00
Mikhail bb0b911e5e irc: basic `chathistory` capability support without `batch` markers (#73)
Reviewed-on: lavina/lavina#73
Co-authored-by: Mikhail <mikhail@liamets.dev>
Co-committed-by: Mikhail <mikhail@liamets.dev>
2024-06-01 11:34:53 +00:00
Nikita Vilunov 1a21c05d7d xmpp: add support leaving MUCs via unavailable presence (#71)
Reviewed-on: lavina/lavina#71
2024-05-27 14:24:23 +00:00
Nikita Vilunov 43d105ab23 core: subscribe to rooms on player actor startup 2024-05-26 18:56:28 +02:00
Nikita Vilunov f02971d407 xmpp: add tracing instrumentations 2024-05-26 13:55:28 +02:00
Mikhail 381b5650bc xmpp, core: Send message history on MUC join (#68)
Re-send the entire message history on MUC join. Contributes to #5.

Reviewed-on: lavina/lavina#68
Co-authored-by: Mikhail <mikhail@liamets.dev>
Co-committed-by: Mikhail <mikhail@liamets.dev>
2024-05-26 11:20:26 +00:00
Nikita Vilunov bce8b332d2 irc: handle repeated joins correctly 2024-05-25 12:40:58 +02:00
Nikita Vilunov 580923814b xmpp: send user-item and empty room subject on muc join (#69)
Co-authored-by: Mikhail <mikhail@liamets.dev>
Reviewed-on: lavina/lavina#69
2024-05-22 09:29:44 +00:00
Nikita Vilunov 1b59250042 xmpp: add x-user element to muc presence response (#67)
Reviewed-on: lavina/lavina#67
2024-05-14 14:44:49 +00:00
Nikita Vilunov 89918d9de1 xmpp: add item-not-found error condition to room disco#info iq 2024-05-13 16:42:52 +02:00
Nikita Vilunov 26720a2a08 core: separate the model from the logic implementation (#66)
This separates the core in two layers – the model objects and the `LavinaCore` service. Service is responsible for implementing the application logic and exposing it as core's public API to projections, while the model objects will be independent of each other and responsible only for managing and owning in-memory data.

The model objects include:
1. `Storage` – the open connection to the SQLite DB.
2. `PlayerRegistry` – creates, stores refs to, and stops player actors.
3. `RoomRegistry` – manages active rooms.
4. `DialogRegistry` – manages active dialogs.
5. `Broadcasting` – manages subscriptions of players to rooms on remote cluster nodes.
6. `LavinaClient` – manages HTTP connections to remote cluster nodes.
7. `ClusterMetadata` – read-only configuration of the cluster metadata, i.e. allocation of entities to nodes.

As a result:
1. Model objects will be fully independent of each other, e.g. it's no longer necessary to provide a `Storage` to all registries, or to provide `PlayerRegistry` and `DialogRegistry` to each other.
2. Model objects will no longer be `Arc`-wrapped; instead the whole `Services` object will be `Arc`ed and provided to projections.
3. The public API of `lavina-core` will be properly delimited by the APIs of `LavinaCore`, `PlayerConnection` and so on.
4. `LavinaCore` and `PlayerConnection` will also contain APIs of all features, unlike it was previously with `RoomRegistry` and `DialogRegistry`. This is unfortunate, but it could be improved in future.

Reviewed-on: lavina/lavina#66
2024-05-13 14:32:45 +00:00
Nikita Vilunov d1a72e7978 move repo methods into submodules and clean up warnings 2024-05-10 23:50:34 +02:00
Nikita Vilunov 6749103726 scalability: initial support for remote rooms (#61)
Reviewed-on: lavina/lavina#61
2024-05-10 20:44:24 +00:00
Mikhail 3b454ad7cd xmpp: unit-tests for resource bind it and muc presence
Reviewed-on: lavina/lavina#64
Co-authored-by: Mikhail <mikhail@liamets.dev>
Co-committed-by: Mikhail <mikhail@liamets.dev>
2024-05-10 13:35:34 +00:00
Mikhail 5512a74999 Check if user is a member before inserting a membership (#62)
It would typically fail on insertion due to uniqueness constraints: user id - room id.

Reviewed-on: lavina/lavina#62
Co-authored-by: Mikhail <mikhail@liamets.dev>
Co-committed-by: Mikhail <mikhail@liamets.dev>
2024-05-08 22:10:32 +00:00
Nikita Vilunov 7f2c6a1013 continue propagated traces in http request handlers 2024-05-05 19:24:58 +02:00
Nikita Vilunov bb0fe3bf0b use borrows in http endpoint handlers 2024-05-05 19:24:42 +02:00
Nikita Vilunov 8ac64ba8f5 get rid of storage usages in projections 2024-05-05 19:24:23 +02:00
homycdev abe9a26925 irc: implement WHOIS command (#43)
Reviewed-on: lavina/lavina#43
Co-authored-by: homycdev <abdulkhamid98@gmail.com>
Co-committed-by: homycdev <abdulkhamid98@gmail.com>
2024-05-05 17:21:40 +00:00
Mikhail adf1d8c14c xmpp: Implement Message Archive Management stub for XEP-0313 (#60)
https://xmpp.org/extensions/xep-0313.html
Reviewed-on: lavina/lavina#60
Co-authored-by: Mikhail <mikhail@liamets.dev>
Co-committed-by: Mikhail <mikhail@liamets.dev>
2024-05-05 15:12:58 +00:00
Nikita Vilunov 9a09ff717e management api endpoints for rooms 2024-05-01 17:30:31 +02:00
Nikita Vilunov a87f7c9d73 xmpp: extract common fragments of integration tests into functions 2024-04-29 23:56:18 +02:00
Nikita Vilunov 25605322a0 player shutdown API (#58)
Reviewed-on: lavina/lavina#58
2024-04-29 17:24:43 +00:00
Nikita Vilunov 31f9da9b05 xmpp: fix incorrect auth test 2024-04-29 19:13:32 +02:00
Nikita Vilunov c1dc2df150 xmpp: document xml parsing types 2024-04-28 17:29:31 +02:00
Nikita Vilunov c69513f38b xmpp: use mutable namespace and event in parser coroutines 2024-04-28 17:19:31 +02:00
Nikita Vilunov 8ec9ecfe2c xmpp: handle incorrect credentials by replying with an error 2024-04-28 17:11:29 +02:00
Nikita Vilunov a047d55ab5 xmpp: handle correctly unavailable self-presence and improve basic test scenario 2024-04-28 15:43:22 +02:00
Nikita Vilunov ea81ddadfc dialog message persistence 2024-04-27 12:58:27 +02:00
Nikita Vilunov 4b5ab02322 start next version 2024-04-26 13:43:43 +02:00
Nikita Vilunov 843d0e9c82 bump version 2024-04-26 13:31:47 +02:00
Nikita Vilunov 72f5010988 clean up http.rs a little 2024-04-26 12:28:13 +02:00
Nikita Vilunov 4ff09ea05f tracing otlp exporter and instrumentation annotations (#57)
Resolves #56

Reviewed-on: lavina/lavina#57
2024-04-26 10:16:23 +00:00
Nikita Vilunov ec49489ef1 validate that rooms and dialogs are owned exclusively on shutdown 2024-04-23 19:18:46 +02:00
Nikita Vilunov d305f5bf77 argon2-based password hashing (#55)
Reviewed-on: lavina/lavina#55
2024-04-23 16:31:00 +00:00
Nikita Vilunov 799da8366c basic dialog implementation with irc and xmpp support (#53)
Reviewed-on: lavina/lavina#53
2024-04-23 16:26:40 +00:00
Nikita Vilunov d805061d5b refactor auth logic into a common module (#54)
Reviewed-on: lavina/lavina#54
2024-04-23 10:10:10 +00:00
Nikita Vilunov 6c08d69f41 sasl: remove unused code 2024-04-23 00:41:54 +02:00
Nikita Vilunov 12d30ca5c2 irc: implement server-time capability for incoming messages (#52)
Spec: https://ircv3.net/specs/extensions/server-time
Reviewed-on: lavina/lavina#52
2024-04-21 21:00:44 +00:00
Nikita Vilunov ddb348bee9 refactor lavina core by grouping public services into a new LavinaCore struct.
this will be useful in future when additional services will be introduced and passed as dependencies
2024-04-21 19:45:50 +02:00
Nikita Vilunov 5a09b743c9 return AlreadyJoined when a player attempts to join a room they are already in 2024-04-20 17:09:44 +02:00
Nikita Vilunov cebe354179 update libraries 2024-04-19 14:27:19 +02:00
Nikita Vilunov 02a8309d9e xmpp: relax the jid regex a bit 2024-04-18 01:42:28 +02:00
Nikita Vilunov fbb3d4f4f9 xmpp: rewrite xml element parsers using coroutines 2024-04-16 17:44:34 +02:00
Nikita Vilunov 048660624d irc: support registration with different order of NICK/USER/CAP END commands (#51)
Resolves #33

Reviewed-on: lavina/lavina#51
2024-04-16 11:35:14 +00:00
Nikita Vilunov 6bba699d87 xmpp: disco-info iq for rooms 2024-04-15 23:08:43 +02:00
Nikita Vilunov 6d493d83a3 xmpp: use the Jid type in IQs' to and from fields, separate presence handling 2024-04-15 18:18:51 +02:00
Nikita Vilunov 757d7c5665 persistent room topics (#50)
Reviewed-on: lavina/lavina#50
2024-04-15 09:12:23 +00:00
Nikita Vilunov 0105a5b710 persistent memberships (#49)
Reviewed-on: lavina/lavina#49
2024-04-15 09:06:10 +00:00
Nikita Vilunov 57b6af8732 xmpp: configurable server hostname (#47)
Reviewed-on: lavina/lavina#47
2024-04-15 00:33:26 +00:00
Nikita Vilunov 0944c449ca xmpp: in integration tests extract server startup code 2024-04-13 02:32:41 +02:00
Mikhail fd694cd75c Add message timestamps (#41)
Resolves #38

Reviewed-on: lavina/lavina#41
Co-authored-by: Mikhail <mikhail@liamets.dev>
Co-committed-by: Mikhail <mikhail@liamets.dev>
2024-04-12 21:32:21 +00:00
Nikita Vilunov cccc05afe9 xmpp: ignore text elements with spaces at the stream root 2024-04-11 23:08:09 +02:00
Nikita Vilunov 8b099f9be2 xmpp: fix handling of the `bind` iq 2024-04-07 12:06:23 +00:00
Nikita Vilunov 36b0d50d51 irc: allow PART without a reason 2024-04-06 23:01:24 +00:00
Nikita Vilunov adece11fef xmpp: make xml-headers optional in the c2s stream 2024-04-06 22:37:27 +00:00
Nikita Vilunov ab61e975bf xmpp: send correct errors on unknown iqs 2024-04-06 22:37:27 +00:00
Nikita Vilunov fd437df67e xmpp: buffer data outgoing over a TLS-stream 2024-04-06 22:35:01 +00:00
Nikita Vilunov a325c7307c irc: improve tests and remove tail space in chan member list 2024-04-06 22:34:11 +00:00
Nikita Vilunov d436631450 improve docs and split command handlers into methods (#40) 2024-03-26 16:26:31 +00:00
Nikita Vilunov 878ec33cbb apply uniform formatting 2024-03-20 19:59:15 +01:00
Nikita Vilunov 1d9937319e update dependencies 2024-03-20 19:53:51 +01:00
homycdev 4b1958b5ae irc: remove hardcoded text from welcome messages
- use server name in welcome message
- use app version of crate in app_version field

Reviewed-on: lavina/lavina#35
Co-authored-by: homycdev <abdulkhamid98@gmail.com>
Co-committed-by: homycdev <abdulkhamid98@gmail.com>
2024-03-15 00:54:55 +00:00
JustTestingV c6fb74a848 termination usage for stopping the socket connection gracefully (#34)
Reviewed-on: lavina/lavina#34
Co-authored-by: JustTestingV <JustTestingV@gmail.com>
Co-committed-by: JustTestingV <JustTestingV@gmail.com>
2024-02-18 16:46:29 +00:00
77 changed files with 7089 additions and 1951 deletions

View File

@ -12,7 +12,7 @@ jobs:
uses: https://github.com/actions-rs/cargo@v1 uses: https://github.com/actions-rs/cargo@v1
with: with:
command: fmt command: fmt
args: "--check -p mgmt-api -p lavina-core -p projection-irc -p projection-xmpp -p sasl" args: "--check --all"
- name: cargo check - name: cargo check
uses: https://github.com/actions-rs/cargo@v1 uses: https://github.com/actions-rs/cargo@v1
with: with:

2
.gitignore vendored
View File

@ -1,4 +1,4 @@
/target /target
/db.sqlite *.sqlite
.idea/ .idea/
.DS_Store .DS_Store

21
.pre-commit-config.yaml Normal file
View File

@ -0,0 +1,21 @@
repos:
- repo: https://github.com/pre-commit/pre-commit-hooks
rev: v4.5.0
hooks:
- id: check-toml
- id: end-of-file-fixer
- id: fix-byte-order-marker
- id: mixed-line-ending
- id: trailing-whitespace
- repo: local
hooks:
- id: fmt
name: fmt
description: Format
entry: cargo fmt
language: system
args:
- --all
types: [ rust ]
pass_filenames: false

21
.run/Run lavina.run.xml Normal file
View File

@ -0,0 +1,21 @@
<component name="ProjectRunConfigurationManager">
<configuration default="false" name="Run lavina" type="CargoCommandRunConfiguration" factoryName="Cargo Command">
<option name="command" value="run --package lavina --bin lavina -- --config config.toml" />
<option name="workingDirectory" value="file://$PROJECT_DIR$" />
<envs>
<env name="RUST_LOG" value="debug" />
</envs>
<option name="emulateTerminal" value="true" />
<option name="channel" value="DEFAULT" />
<option name="requiredFeatures" value="true" />
<option name="allFeatures" value="false" />
<option name="withSudo" value="false" />
<option name="buildTarget" value="REMOTE" />
<option name="backtrace" value="FULL" />
<option name="isRedirectInput" value="false" />
<option name="redirectInputPath" value="" />
<method v="2">
<option name="CARGO.BUILD_TASK_PROVIDER" enabled="true" />
</method>
</configuration>
</component>

1570
Cargo.lock generated

File diff suppressed because it is too large Load Diff

View File

@ -10,7 +10,7 @@ members = [
] ]
[workspace.package] [workspace.package]
version = "0.0.2-dev" version = "0.0.3-dev"
[workspace.dependencies] [workspace.dependencies]
nom = "7.1.3" nom = "7.1.3"
@ -18,8 +18,8 @@ assert_matches = "1.5.0"
tokio = { version = "1.24.1", features = ["full"] } # async runtime tokio = { version = "1.24.1", features = ["full"] } # async runtime
futures-util = "0.3.25" futures-util = "0.3.25"
anyhow = "1.0.68" # error utils anyhow = "1.0.68" # error utils
nonempty = "0.8.1" nonempty = "0.10.0"
quick-xml = { version = "0.30.0", features = ["async-tokio"] } quick-xml = { version = "0.32.0", features = ["async-tokio"] }
lazy_static = "1.4.0" lazy_static = "1.4.0"
regex = "1.7.1" regex = "1.7.1"
derive_more = "0.99.17" derive_more = "0.99.17"
@ -27,10 +27,13 @@ clap = { version = "4.4.4", features = ["derive"] }
serde = { version = "1.0.152", features = ["rc", "serde_derive"] } serde = { version = "1.0.152", features = ["rc", "serde_derive"] }
tracing = "0.1.37" # logging & tracing api tracing = "0.1.37" # logging & tracing api
prometheus = { version = "0.13.3", default-features = false } prometheus = { version = "0.13.3", default-features = false }
base64 = "0.21.3" base64 = "0.22.0"
lavina-core = { path = "crates/lavina-core" } lavina-core = { path = "crates/lavina-core" }
tracing-subscriber = "0.3.16" tracing-subscriber = "0.3.16"
sasl = { path = "crates/sasl" } sasl = { path = "crates/sasl" }
chrono = "0.4.37"
reqwest = { version = "0.12.0", default-features = false, features = ["json"] }
opentelemetry = "0.22.0"
[package] [package]
name = "lavina" name = "lavina"
@ -58,8 +61,14 @@ projection-irc = { path = "crates/projection-irc" }
projection-xmpp = { path = "crates/projection-xmpp" } projection-xmpp = { path = "crates/projection-xmpp" }
mgmt-api = { path = "crates/mgmt-api" } mgmt-api = { path = "crates/mgmt-api" }
clap.workspace = true clap.workspace = true
opentelemetry.workspace = true
opentelemetry-semantic-conventions = "0.14.0"
opentelemetry_sdk = { version = "0.22.1", features = ["rt-tokio"] }
opentelemetry-otlp = "0.15.0"
tracing-opentelemetry = "0.23.0"
chrono.workspace = true
[dev-dependencies] [dev-dependencies]
assert_matches.workspace = true assert_matches.workspace = true
regex = "1.7.1" regex = "1.7.1"
reqwest = { version = "0.11", default-features = false } reqwest.workspace = true

30
config.0.toml Normal file
View File

@ -0,0 +1,30 @@
[telemetry]
listen_on = "127.0.0.1:8080"
[irc]
listen_on = "127.0.0.1:6667"
server_name = "irc.localhost"
[xmpp]
listen_on = "127.0.0.1:5222"
cert = "./certs/xmpp.pem"
key = "./certs/xmpp.key"
hostname = "localhost"
[storage]
db_path = "db.0.sqlite"
[cluster]
addresses = [
"127.0.0.1:8080",
"127.0.0.1:8081",
]
[cluster.metadata]
node_id = 0
main_owner = 0
rooms = { aaaaa = 1, test = 0 }
[tracing]
endpoint = "http://localhost:4317"
service_name = "lavina-0"

30
config.1.toml Normal file
View File

@ -0,0 +1,30 @@
[telemetry]
listen_on = "127.0.0.1:8081"
[irc]
listen_on = "127.0.0.1:6668"
server_name = "irc.localhost"
[xmpp]
listen_on = "127.0.0.1:5223"
cert = "./certs/xmpp.pem"
key = "./certs/xmpp.key"
hostname = "localhost"
[storage]
db_path = "db.1.sqlite"
[cluster]
addresses = [
"127.0.0.1:8080",
"127.0.0.1:8081",
]
[cluster.metadata]
node_id = 1
main_owner = 0
rooms = { aaaaa = 1, test = 0 }
[tracing]
endpoint = "http://localhost:4317"
service_name = "lavina-1"

View File

@ -9,6 +9,15 @@ server_name = "irc.localhost"
listen_on = "127.0.0.1:5222" listen_on = "127.0.0.1:5222"
cert = "./certs/xmpp.pem" cert = "./certs/xmpp.pem"
key = "./certs/xmpp.key" key = "./certs/xmpp.key"
hostname = "localhost"
[storage] [storage]
db_path = "db.sqlite" db_path = "db.sqlite"
[cluster]
addresses = []
[cluster.metadata]
node_id = 0
main_owner = 0
rooms = {}

View File

@ -5,8 +5,16 @@ version.workspace = true
[dependencies] [dependencies]
anyhow.workspace = true anyhow.workspace = true
sqlx = { version = "0.7.0-alpha.2", features = ["sqlite", "migrate"] } sqlx = { version = "0.7.4", features = ["sqlite", "migrate", "chrono"] }
serde.workspace = true serde.workspace = true
tokio.workspace = true tokio.workspace = true
tracing.workspace = true tracing.workspace = true
prometheus.workspace = true prometheus.workspace = true
chrono.workspace = true
argon2 = { version = "0.5.3" }
rand_core = { version = "0.6.4", features = ["getrandom"] }
reqwest.workspace = true
reqwest-middleware = { version = "0.3", features = ["json"] }
opentelemetry.workspace = true
mgmt-api = { path = "../mgmt-api" }
reqwest-tracing = { version = "0.5", features = ["opentelemetry_0_22"] }

View File

@ -0,0 +1 @@
alter table messages add column created_at text;

View File

@ -0,0 +1,17 @@
create table dialogs(
id integer primary key autoincrement not null,
participant_1 integer not null,
participant_2 integer not null,
created_at timestamp not null,
message_count integer not null default 0,
unique (participant_1, participant_2)
);
create table dialog_messages(
dialog_id integer not null,
id integer not null, -- unique per dialog, sequential in one dialog
author_id integer not null,
content string not null,
created_at timestamp not null,
primary key (dialog_id, id)
);

View File

@ -0,0 +1,4 @@
create table challenges_argon2_password(
user_id integer primary key not null,
hash string not null
);

View File

@ -0,0 +1,2 @@
alter table messages drop column created_at;
alter table messages add column created_at datetime default "1970-01-01T00:00:00Z";

View File

@ -0,0 +1,58 @@
use anyhow::{anyhow, Result};
use argon2::password_hash::{PasswordHash, PasswordHasher, PasswordVerifier, SaltString};
use argon2::Argon2;
use rand_core::OsRng;
use crate::LavinaCore;
pub enum Verdict {
Authenticated,
UserNotFound,
InvalidPassword,
}
pub enum UpdatePasswordResult {
PasswordUpdated,
UserNotFound,
}
impl LavinaCore {
#[tracing::instrument(skip(self, provided_password), name = "Services::authenticate")]
pub async fn authenticate(&self, login: &str, provided_password: &str) -> Result<Verdict> {
let Some(stored_user) = self.services.storage.retrieve_user_by_name(login).await? else {
return Ok(Verdict::UserNotFound);
};
if let Some(argon2_hash) = stored_user.argon2_hash {
let argon2 = Argon2::default();
let password_hash =
PasswordHash::new(&argon2_hash).map_err(|e| anyhow!("Failed to parse password hash: {e:?}"))?;
let password_verifier = argon2.verify_password(provided_password.as_bytes(), &password_hash);
if password_verifier.is_ok() {
return Ok(Verdict::Authenticated);
}
}
if let Some(expected_password) = stored_user.password {
if expected_password == provided_password {
return Ok(Verdict::Authenticated);
}
}
Ok(Verdict::InvalidPassword)
}
#[tracing::instrument(skip(self, provided_password), name = "Services::set_password")]
pub async fn set_password(&self, login: &str, provided_password: &str) -> Result<UpdatePasswordResult> {
let Some(u) = self.services.storage.retrieve_user_by_name(login).await? else {
return Ok(UpdatePasswordResult::UserNotFound);
};
let salt = SaltString::generate(&mut OsRng);
let argon2 = Argon2::default();
let password_hash = argon2
.hash_password(provided_password.as_bytes(), &salt)
.map_err(|e| anyhow!("Failed to hash password: {e:?}"))?;
self.services.storage.set_argon2_challenge(u.id, password_hash.to_string().as_str()).await?;
tracing::info!("Password changed for player {login}");
Ok(UpdatePasswordResult::PasswordUpdated)
}
}

View File

@ -0,0 +1,56 @@
use std::collections::HashMap;
use std::net::SocketAddr;
use anyhow::{anyhow, Result};
use reqwest::Client;
use reqwest_middleware::{ClientBuilder, ClientWithMiddleware};
use reqwest_tracing::{DefaultSpanBackend, TracingMiddleware};
use serde::{Deserialize, Serialize};
pub mod broadcast;
pub mod room;
type Addresses = Vec<SocketAddr>;
#[derive(Deserialize, Debug, Clone)]
pub struct ClusterConfig {
pub metadata: ClusterMetadata,
pub addresses: Addresses,
}
#[derive(Deserialize, Debug, Clone)]
pub struct ClusterMetadata {
/// The node id of the current node.
pub node_id: u32,
/// Owns all rooms in the cluster except the ones specified in `rooms`.
pub main_owner: u32,
pub rooms: HashMap<String, u32>,
}
pub struct LavinaClient {
addresses: Addresses,
client: ClientWithMiddleware,
}
impl LavinaClient {
pub fn new(addresses: Addresses) -> Self {
let client = ClientBuilder::new(Client::new()).with(TracingMiddleware::<DefaultSpanBackend>::new()).build();
Self { addresses, client }
}
async fn send_request(&self, node_id: u32, path: &str, req: impl Serialize) -> Result<()> {
let Some(address) = self.addresses.get(node_id as usize) else {
return Err(anyhow!("Unknown node"));
};
match self.client.post(format!("http://{}{}", address, path)).json(&req).send().await {
Ok(res) => {
if res.status().is_server_error() || res.status().is_client_error() {
tracing::error!("Cluster request failed: {:?}", res);
return Err(anyhow!("Server error"));
}
Ok(())
}
Err(e) => Err(e.into()),
}
}
}

View File

@ -0,0 +1,53 @@
use std::collections::{HashMap, HashSet};
use chrono::{DateTime, Utc};
use tokio::sync::Mutex;
use crate::player::{PlayerId, Updates};
use crate::prelude::Str;
use crate::room::RoomId;
use crate::Services;
/// Receives updates from other nodes and broadcasts them to local player actors.
struct BroadcastingInner {
subscriptions: HashMap<RoomId, HashSet<PlayerId>>,
}
pub struct Broadcasting(Mutex<BroadcastingInner>);
impl Broadcasting {
pub fn new() -> Self {
let inner = BroadcastingInner {
subscriptions: HashMap::new(),
};
Self(Mutex::new(inner))
}
}
impl Services {
/// Broadcasts the given update to subscribed player actors on local node.
#[tracing::instrument(skip(self, message, created_at))]
pub async fn broadcast(&self, room_id: RoomId, author_id: PlayerId, message: Str, created_at: DateTime<Utc>) {
let inner = self.broadcasting.0.lock().await;
let Some(subscribers) = inner.subscriptions.get(&room_id) else {
return;
};
let update = Updates::NewMessage {
room_id: room_id.clone(),
author_id: author_id.clone(),
body: message.clone(),
created_at: created_at.clone(),
};
for i in subscribers {
if i == &author_id {
continue;
}
let Some(player) = self.players.get_player(i).await else {
continue;
};
player.update(update.clone()).await;
}
}
pub async fn subscribe(&self, subscriber: PlayerId, room_id: RoomId) {
self.broadcasting.0.lock().await.subscriptions.entry(room_id).or_insert_with(HashSet::new).insert(subscriber);
}
}

View File

@ -0,0 +1,88 @@
use anyhow::Result;
use serde::{Deserialize, Serialize};
use crate::clustering::LavinaClient;
use crate::player::PlayerId;
use crate::prelude::Str;
use crate::room::RoomId;
use crate::LavinaCore;
pub mod paths {
pub const JOIN: &'static str = "/cluster/rooms/join";
pub const LEAVE: &'static str = "/cluster/rooms/leave";
pub const ADD_MESSAGE: &'static str = "/cluster/rooms/add_message";
pub const SET_TOPIC: &'static str = "/cluster/rooms/set_topic";
}
#[derive(Serialize, Deserialize, Debug)]
pub struct JoinRoomReq<'a> {
pub room_id: &'a str,
pub player_id: &'a str,
}
#[derive(Serialize, Deserialize, Debug)]
pub struct LeaveRoomReq<'a> {
pub room_id: &'a str,
pub player_id: &'a str,
}
#[derive(Serialize, Deserialize, Debug)]
pub struct SendMessageReq<'a> {
pub room_id: &'a str,
pub player_id: &'a str,
pub message: &'a str,
pub created_at: &'a str,
}
#[derive(Serialize, Deserialize, Debug)]
pub struct SetRoomTopicReq<'a> {
pub room_id: &'a str,
pub player_id: &'a str,
pub topic: &'a str,
}
impl LavinaClient {
#[tracing::instrument(skip(self, req), name = "LavinaClient::join_room")]
pub async fn join_room(&self, node_id: u32, req: JoinRoomReq<'_>) -> Result<()> {
self.send_request(node_id, paths::JOIN, req).await
}
#[tracing::instrument(skip(self, req), name = "LavinaClient::leave_room")]
pub async fn leave_room(&self, node_id: u32, req: LeaveRoomReq<'_>) -> Result<()> {
self.send_request(node_id, paths::LEAVE, req).await
}
#[tracing::instrument(skip(self, req), name = "LavinaClient::send_room_message")]
pub async fn send_room_message(&self, node_id: u32, req: SendMessageReq<'_>) -> Result<()> {
self.send_request(node_id, paths::ADD_MESSAGE, req).await
}
#[tracing::instrument(skip(self, req), name = "LavinaClient::set_room_topic")]
pub async fn set_room_topic(&self, node_id: u32, req: SetRoomTopicReq<'_>) -> Result<()> {
self.send_request(node_id, paths::SET_TOPIC, req).await
}
}
impl LavinaCore {
pub async fn cluster_join_room(&self, room_id: RoomId, player_id: &PlayerId) -> Result<()> {
let room_handle = self.services.rooms.get_or_create_room(&self.services, room_id).await?;
let storage_id =
self.services.storage.create_or_retrieve_user_id_by_name(player_id.as_inner().as_ref()).await?;
room_handle.add_member(&self.services, &player_id, storage_id).await;
Ok(())
}
pub async fn cluster_send_room_message(
&self,
room_id: RoomId,
player_id: &PlayerId,
message: Str,
created_at: chrono::DateTime<chrono::Utc>,
) -> Result<Option<()>> {
let Some(room_handle) = self.services.rooms.get_room(&self.services, &room_id).await? else {
return Ok(None);
};
room_handle.send_message(&self.services, &player_id, message, created_at).await?;
Ok(Some(()))
}
}

View File

@ -0,0 +1,151 @@
//! Domain of dialogs conversations between two participants.
//!
//! Dialogs are different from rooms in that they are always between two participants.
//! There are no admins or other roles in dialogs, both participants have equal rights.
use std::collections::HashMap;
use chrono::{DateTime, Utc};
use tokio::sync::RwLock as AsyncRwLock;
use crate::player::{PlayerId, Updates};
use crate::prelude::*;
use crate::Services;
/// Id of a conversation between two players.
///
/// Dialogs are identified by the pair of participants' ids. The order of ids does not matter.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub struct DialogId(PlayerId, PlayerId);
impl DialogId {
pub fn new(a: PlayerId, b: PlayerId) -> DialogId {
if a.as_inner() < b.as_inner() {
DialogId(a, b)
} else {
DialogId(b, a)
}
}
pub fn as_inner(&self) -> (&PlayerId, &PlayerId) {
(&self.0, &self.1)
}
pub fn into_inner(self) -> (PlayerId, PlayerId) {
(self.0, self.1)
}
}
struct Dialog {
storage_id: u32,
player_storage_id_1: u32,
player_storage_id_2: u32,
message_count: u32,
}
struct DialogRegistryInner {
dialogs: HashMap<DialogId, AsyncRwLock<Dialog>>,
}
pub(crate) struct DialogRegistry(AsyncRwLock<DialogRegistryInner>);
impl Services {
#[tracing::instrument(skip(self, body, created_at))]
pub async fn send_dialog_message(
&self,
from: PlayerId,
to: PlayerId,
body: Str,
created_at: &DateTime<Utc>,
) -> Result<()> {
let guard = self.dialogs.0.read().await;
let id = DialogId::new(from.clone(), to.clone());
let dialog = guard.dialogs.get(&id);
if let Some(d) = dialog {
let mut d = d.write().await;
self.storage
.insert_dialog_message(d.storage_id, d.message_count, from.as_inner(), &body, created_at)
.await?;
d.message_count += 1;
} else {
drop(guard);
let mut guard2 = self.dialogs.0.write().await;
// double check in case concurrent access has loaded this dialog
if let Some(d) = guard2.dialogs.get(&id) {
let mut d = d.write().await;
self.storage
.insert_dialog_message(d.storage_id, d.message_count, from.as_inner(), &body, created_at)
.await?;
d.message_count += 1;
} else {
let (p1, p2) = id.as_inner();
tracing::info!("Dialog {id:?} not found locally, trying to load from storage");
let stored_dialog = match self.storage.retrieve_dialog(p1.as_inner(), p2.as_inner()).await? {
Some(t) => t,
None => {
tracing::info!("Dialog {id:?} does not exist, creating a new one in storage");
self.storage.initialize_dialog(p1.as_inner(), p2.as_inner(), created_at).await?
}
};
tracing::info!("Dialog {id:?} loaded");
self.storage
.insert_dialog_message(
stored_dialog.id,
stored_dialog.message_count,
from.as_inner(),
&body,
created_at,
)
.await?;
let dialog = Dialog {
storage_id: stored_dialog.id,
player_storage_id_1: stored_dialog.participant_1,
player_storage_id_2: stored_dialog.participant_2,
message_count: stored_dialog.message_count + 1,
};
guard2.dialogs.insert(id.clone(), AsyncRwLock::new(dialog));
}
drop(guard2);
}
// TODO send message to the other player and persist it
let Some(player) = self.players.get_player(&to).await else {
tracing::debug!("Player {to:?} not active, not sending message");
return Ok(());
};
let update = Updates::NewDialogMessage {
sender: from.clone(),
receiver: to.clone(),
body: body.clone(),
created_at: created_at.clone(),
};
player.update(update).await;
return Ok(());
}
}
impl DialogRegistry {
pub fn new() -> DialogRegistry {
DialogRegistry(AsyncRwLock::new(DialogRegistryInner {
dialogs: HashMap::new(),
}))
}
pub fn shutdown(self) {}
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn test_dialog_id_new() -> Result<()> {
let a = PlayerId::from("a")?;
let b = PlayerId::from("b")?;
let id1 = DialogId::new(a.clone(), b.clone());
let id2 = DialogId::new(a.clone(), b.clone());
// Dialog ids are invariant with respect to the order of participants
assert_eq!(id1, id2);
assert_eq!(id1.as_inner(), (&a, &b));
assert_eq!(id2.as_inner(), (&a, &b));
Ok(())
}
}

View File

@ -1,4 +1,20 @@
//! Domain definitions and implementation of common chat logic. //! Domain definitions and implementation of common chat logic.
use std::ops::Deref;
use std::sync::Arc;
use anyhow::Result;
use prometheus::Registry as MetricsRegistry;
use crate::clustering::broadcast::Broadcasting;
use crate::clustering::{ClusterConfig, ClusterMetadata, LavinaClient};
use crate::dialog::DialogRegistry;
use crate::player::{PlayerConnectionResult, PlayerId, PlayerRegistry};
use crate::repo::Storage;
use crate::room::{RoomHandle, RoomId, RoomInfo, RoomRegistry};
pub mod auth;
pub mod clustering;
pub mod dialog;
pub mod player; pub mod player;
pub mod prelude; pub mod prelude;
pub mod repo; pub mod repo;
@ -6,3 +22,91 @@ pub mod room;
pub mod terminator; pub mod terminator;
mod table; mod table;
#[derive(Clone)]
pub struct LavinaCore {
services: Arc<Services>,
}
impl Deref for LavinaCore {
type Target = Services;
fn deref(&self) -> &Self::Target {
&self.services
}
}
impl LavinaCore {
pub async fn connect_to_player(&self, player_id: &PlayerId) -> Result<PlayerConnectionResult> {
self.services.players.connect_to_player(&self, player_id).await
}
pub async fn get_room(&self, room_id: &RoomId) -> Result<Option<RoomHandle>> {
self.services.rooms.get_room(&self.services, room_id).await
}
pub async fn create_player(&self, player_id: &PlayerId) -> Result<()> {
self.services.storage.create_user(player_id.as_inner()).await
}
pub async fn get_all_rooms(&self) -> Vec<RoomInfo> {
self.services.rooms.get_all_rooms().await
}
pub async fn stop_player(&self, player_id: &PlayerId) -> Result<Option<()>> {
self.services.players.stop_player(player_id).await
}
}
pub struct Services {
pub(crate) players: PlayerRegistry,
pub(crate) rooms: RoomRegistry,
pub(crate) dialogs: DialogRegistry,
pub(crate) broadcasting: Broadcasting,
pub(crate) client: LavinaClient,
pub(crate) storage: Storage,
pub(crate) cluster_metadata: ClusterMetadata,
}
impl LavinaCore {
pub async fn new(
metrics: &mut MetricsRegistry,
cluster_config: ClusterConfig,
storage: Storage,
) -> Result<LavinaCore> {
// TODO shutdown all services in reverse order on error
let broadcasting = Broadcasting::new();
let client = LavinaClient::new(cluster_config.addresses.clone());
let rooms = RoomRegistry::new(metrics)?;
let dialogs = DialogRegistry::new();
let players = PlayerRegistry::empty(metrics)?;
let services = Services {
players,
rooms,
dialogs,
broadcasting,
client,
storage,
cluster_metadata: cluster_config.metadata,
};
Ok(LavinaCore {
services: Arc::new(services),
})
}
pub async fn shutdown(self) -> Storage {
self.players.shutdown_all().await;
let services = match Arc::try_unwrap(self.services) {
Ok(e) => e,
Err(_) => {
panic!("failed to acquire services ownership on shutdown");
}
};
let _ = services.players.shutdown();
let _ = services.dialogs.shutdown();
let _ = services.rooms.shutdown();
services.storage
}
}

View File

@ -7,21 +7,21 @@
//! //!
//! A player actor is a serial handler of commands from a single player. It is preferable to run all per-player validations in the player actor, //! A player actor is a serial handler of commands from a single player. It is preferable to run all per-player validations in the player actor,
//! so that they don't overload the room actor. //! so that they don't overload the room actor.
use std::{ use std::collections::{HashMap, HashSet};
collections::{HashMap, HashSet},
sync::{Arc, RwLock},
};
use anyhow::anyhow;
use chrono::{DateTime, Utc};
use prometheus::{IntGauge, Registry as MetricsRegistry}; use prometheus::{IntGauge, Registry as MetricsRegistry};
use serde::Serialize; use serde::Serialize;
use tokio::{ use tokio::sync::mpsc::{channel, Receiver, Sender};
sync::mpsc::{channel, Receiver, Sender}, use tokio::sync::RwLock;
task::JoinHandle, use tracing::{Instrument, Span};
};
use crate::clustering::room::*;
use crate::prelude::*; use crate::prelude::*;
use crate::room::{RoomHandle, RoomId, RoomInfo, RoomRegistry}; use crate::room::{RoomHandle, RoomId, RoomInfo, StoredMessage};
use crate::table::{AnonTable, Key as AnonKey}; use crate::table::{AnonTable, Key as AnonKey};
use crate::LavinaCore;
/// Opaque player identifier. Cannot contain spaces, must be shorter than 32. /// Opaque player identifier. Cannot contain spaces, must be shorter than 32.
#[derive(Debug, Clone, PartialEq, Eq, Hash, Serialize)] #[derive(Debug, Clone, PartialEq, Eq, Hash, Serialize)]
@ -45,145 +45,219 @@ impl PlayerId {
} }
} }
/// Node-local identifier of a connection. It is used to address a connection within a player actor.
#[derive(Debug, Clone, PartialEq, Eq, Hash)] #[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub struct ConnectionId(pub AnonKey); pub struct ConnectionId(pub AnonKey);
/// Representation of an authenticated client connection.
/// The public API available to projections through which all client actions are executed.
///
/// The connection is used to send commands to the player actor and to receive updates that might be sent to the client.
pub struct PlayerConnection { pub struct PlayerConnection {
pub connection_id: ConnectionId, pub connection_id: ConnectionId,
pub receiver: Receiver<Updates>, pub receiver: Receiver<ConnectionMessage>,
player_handle: PlayerHandle, player_handle: PlayerHandle,
} }
impl PlayerConnection { impl PlayerConnection {
pub async fn send_message(&mut self, room_id: RoomId, body: Str) -> Result<()> { /// Handled in [Player::send_room_message].
self.player_handle #[tracing::instrument(skip(self, body), name = "PlayerConnection::send_message")]
.send_message(room_id, self.connection_id.clone(), body) pub async fn send_message(&mut self, room_id: RoomId, body: Str) -> Result<SendMessageResult> {
.await
}
pub async fn join_room(&mut self, room_id: RoomId) -> Result<JoinResult> {
self.player_handle.join_room(room_id, self.connection_id.clone()).await
}
pub async fn change_topic(&mut self, room_id: RoomId, new_topic: Str) -> Result<()> {
let (promise, deferred) = oneshot(); let (promise, deferred) = oneshot();
let cmd = Cmd::ChangeTopic { let cmd = ClientCommand::SendMessage { room_id, body, promise };
self.player_handle.send(ActorCommand::ClientCommand(cmd, self.connection_id.clone())).await?;
deferred.await?
}
/// Handled in [Player::join_room].
#[tracing::instrument(skip(self), name = "PlayerConnection::join_room")]
pub async fn join_room(&mut self, room_id: RoomId) -> Result<JoinResult> {
let (promise, deferred) = oneshot();
let cmd = ClientCommand::JoinRoom { room_id, promise };
self.player_handle.send(ActorCommand::ClientCommand(cmd, self.connection_id.clone())).await?;
deferred.await?
}
/// Handled in [Player::change_room_topic].
#[tracing::instrument(skip(self, new_topic), name = "PlayerConnection::change_topic")]
pub async fn change_topic(&mut self, room_id: RoomId, new_topic: Str) -> Result<ChangeRoomTopicResult> {
let (promise, deferred) = oneshot();
let cmd = ClientCommand::ChangeRoomTopic {
room_id, room_id,
new_topic, new_topic,
promise, promise,
}; };
self.player_handle self.player_handle.send(ActorCommand::ClientCommand(cmd, self.connection_id.clone())).await?;
.send(PlayerCommand::Cmd(cmd, self.connection_id.clone())) deferred.await?
.await;
Ok(deferred.await?)
} }
/// Handled in [Player::leave_room].
#[tracing::instrument(skip(self), name = "PlayerConnection::leave_room")]
pub async fn leave_room(&mut self, room_id: RoomId) -> Result<()> { pub async fn leave_room(&mut self, room_id: RoomId) -> Result<()> {
let (promise, deferred) = oneshot(); let (promise, deferred) = oneshot();
self.player_handle let cmd = ClientCommand::LeaveRoom { room_id, promise };
.send(PlayerCommand::Cmd( self.player_handle.send(ActorCommand::ClientCommand(cmd, self.connection_id.clone())).await?;
Cmd::LeaveRoom { room_id, promise }, deferred.await?
self.connection_id.clone(),
))
.await;
Ok(deferred.await?)
} }
pub async fn terminate(self) { pub async fn terminate(self) -> Result<()> {
self.player_handle self.player_handle.send(ActorCommand::TerminateConnection(self.connection_id)).await
.send(PlayerCommand::TerminateConnection(self.connection_id))
.await;
} }
/// Handled in [Player::get_rooms].
#[tracing::instrument(skip(self), name = "PlayerConnection::get_rooms")]
pub async fn get_rooms(&self) -> Result<Vec<RoomInfo>> { pub async fn get_rooms(&self) -> Result<Vec<RoomInfo>> {
let (promise, deferred) = oneshot(); let (promise, deferred) = oneshot();
self.player_handle.send(PlayerCommand::GetRooms(promise)).await; let cmd = ClientCommand::GetRooms { promise };
Ok(deferred.await?) self.player_handle.send(ActorCommand::ClientCommand(cmd, self.connection_id.clone())).await?;
deferred.await?
}
#[tracing::instrument(skip(self), name = "PlayerConnection::get_room_message_history")]
pub async fn get_room_message_history(&self, room_id: &RoomId, limit: u32) -> Result<RoomHistoryResult> {
let (promise, deferred) = oneshot();
let cmd = ClientCommand::GetRoomHistory {
room_id: room_id.clone(),
promise,
limit,
};
self.player_handle.send(ActorCommand::ClientCommand(cmd, self.connection_id.clone())).await?;
deferred.await?
}
/// Handler in [Player::send_dialog_message].
#[tracing::instrument(skip(self, body), name = "PlayerConnection::send_dialog_message")]
pub async fn send_dialog_message(&self, recipient: PlayerId, body: Str) -> Result<()> {
let (promise, deferred) = oneshot();
let cmd = ClientCommand::SendDialogMessage {
recipient,
body,
promise,
};
self.player_handle.send(ActorCommand::ClientCommand(cmd, self.connection_id.clone())).await?;
deferred.await?
}
/// Handler in [Player::check_user_existence].
#[tracing::instrument(skip(self), name = "PlayerConnection::check_user_existence")]
pub async fn check_user_existence(&self, recipient: PlayerId) -> Result<GetInfoResult> {
let (promise, deferred) = oneshot();
let cmd = ClientCommand::GetInfo { recipient, promise };
self.player_handle.send(ActorCommand::ClientCommand(cmd, self.connection_id.clone())).await?;
deferred.await?
} }
} }
/// Handle to a player actor. /// Handle to a player actor.
#[derive(Clone)] #[derive(Clone)]
pub struct PlayerHandle { pub struct PlayerHandle {
tx: Sender<PlayerCommand>, tx: Sender<(ActorCommand, Span)>,
} }
impl PlayerHandle { impl PlayerHandle {
pub async fn subscribe(&self) -> PlayerConnection { pub async fn subscribe(&self) -> Result<PlayerConnection> {
let (sender, receiver) = channel(32); let (sender, receiver) = channel(32);
let (promise, deferred) = oneshot(); let (promise, deferred) = oneshot();
let cmd = PlayerCommand::AddConnection { sender, promise }; let cmd = ActorCommand::AddConnection { sender, promise };
let _ = self.tx.send(cmd).await; self.send(cmd).await?;
let connection_id = deferred.await.unwrap(); let connection_id = deferred.await?;
PlayerConnection { Ok(PlayerConnection {
connection_id, connection_id,
player_handle: self.clone(), player_handle: self.clone(),
receiver, receiver,
} })
} }
pub async fn send_message(&self, room_id: RoomId, connection_id: ConnectionId, body: Str) -> Result<()> { async fn send(&self, command: ActorCommand) -> Result<()> {
let (promise, deferred) = oneshot(); let span = tracing::span!(tracing::Level::INFO, "PlayerHandle::send");
let cmd = Cmd::SendMessage { room_id, body, promise }; self.tx.send((command, span)).await?;
let _ = self.tx.send(PlayerCommand::Cmd(cmd, connection_id)).await; Ok(())
Ok(deferred.await?)
}
pub async fn join_room(&self, room_id: RoomId, connection_id: ConnectionId) -> Result<JoinResult> {
let (promise, deferred) = oneshot();
let cmd = Cmd::JoinRoom { room_id, promise };
let _ = self.tx.send(PlayerCommand::Cmd(cmd, connection_id)).await;
Ok(deferred.await?)
}
async fn send(&self, command: PlayerCommand) {
let _ = self.tx.send(command).await;
} }
pub async fn update(&self, update: Updates) { pub async fn update(&self, update: Updates) {
self.send(PlayerCommand::Update(update)).await; let _ = self.send(ActorCommand::Update(update)).await;
} }
} }
enum PlayerCommand { /// Messages sent to the player actor.
/** Commands from connections */ enum ActorCommand {
/// Establish a new connection.
AddConnection { AddConnection {
sender: Sender<Updates>, sender: Sender<ConnectionMessage>,
promise: Promise<ConnectionId>, promise: Promise<ConnectionId>,
}, },
/// Terminate an existing connection.
TerminateConnection(ConnectionId), TerminateConnection(ConnectionId),
Cmd(Cmd, ConnectionId), /// Player-issued command.
/// Query - responds with a list of rooms the player is a member of. ClientCommand(ClientCommand, ConnectionId),
GetRooms(Promise<Vec<RoomInfo>>), /// Update which is sent from a room the player is member of.
/** Events from rooms */
Update(Updates), Update(Updates),
Stop, Stop,
} }
pub enum Cmd { /// Client-issued command sent to the player actor. The actor will respond with by fulfilling the promise.
pub enum ClientCommand {
JoinRoom { JoinRoom {
room_id: RoomId, room_id: RoomId,
promise: Promise<JoinResult>, promise: Promise<Result<JoinResult>>,
}, },
LeaveRoom { LeaveRoom {
room_id: RoomId, room_id: RoomId,
promise: Promise<()>, promise: Promise<Result<()>>,
}, },
SendMessage { SendMessage {
room_id: RoomId, room_id: RoomId,
body: Str, body: Str,
promise: Promise<()>, promise: Promise<Result<SendMessageResult>>,
}, },
ChangeTopic { ChangeRoomTopic {
room_id: RoomId, room_id: RoomId,
new_topic: Str, new_topic: Str,
promise: Promise<()>, promise: Promise<Result<ChangeRoomTopicResult>>,
}, },
GetRooms {
promise: Promise<Result<Vec<RoomInfo>>>,
},
SendDialogMessage {
recipient: PlayerId,
body: Str,
promise: Promise<Result<()>>,
},
GetInfo {
recipient: PlayerId,
promise: Promise<Result<GetInfoResult>>,
},
GetRoomHistory {
room_id: RoomId,
limit: u32,
promise: Promise<Result<RoomHistoryResult>>,
},
}
pub enum GetInfoResult {
UserExists,
UserDoesntExist,
} }
pub enum JoinResult { pub enum JoinResult {
Success(RoomInfo), Success(RoomInfo),
AlreadyJoined,
Banned, Banned,
} }
pub enum ChangeRoomTopicResult {
Success,
NoSuchRoom,
}
pub enum SendMessageResult {
Success(DateTime<Utc>),
NoSuchRoom,
}
pub enum RoomHistoryResult {
Success(Vec<StoredMessage>),
NoSuchRoom,
}
/// Player update event type which is sent to a player actor and from there to a connection handler. /// Player update event type which is sent to a player actor and from there to a connection handler.
#[derive(Clone, Debug)] #[derive(Clone, Debug)]
pub enum Updates { pub enum Updates {
@ -195,6 +269,7 @@ pub enum Updates {
room_id: RoomId, room_id: RoomId,
author_id: PlayerId, author_id: PlayerId,
body: Str, body: Str,
created_at: DateTime<Utc>,
}, },
RoomJoined { RoomJoined {
room_id: RoomId, room_id: RoomId,
@ -206,215 +281,521 @@ pub enum Updates {
}, },
/// The player was banned from the room and left it immediately. /// The player was banned from the room and left it immediately.
BannedFrom(RoomId), BannedFrom(RoomId),
NewDialogMessage {
sender: PlayerId,
receiver: PlayerId,
body: Str,
created_at: DateTime<Utc>,
},
} }
/// Handle to a player registry — a shared data structure containing information about players. /// Handle to a player registry — a shared data structure containing information about players.
#[derive(Clone)] pub(crate) struct PlayerRegistry(RwLock<PlayerRegistryInner>);
pub struct PlayerRegistry(Arc<RwLock<PlayerRegistryInner>>);
impl PlayerRegistry { impl PlayerRegistry {
pub fn empty(room_registry: RoomRegistry, metrics: &mut MetricsRegistry) -> Result<PlayerRegistry> { pub fn empty(metrics: &mut MetricsRegistry) -> Result<PlayerRegistry> {
let metric_active_players = IntGauge::new("chat_players_active", "Number of alive player actors")?; let metric_active_players = IntGauge::new("chat_players_active", "Number of alive player actors")?;
metrics.register(Box::new(metric_active_players.clone()))?; metrics.register(Box::new(metric_active_players.clone()))?;
let inner = PlayerRegistryInner { let inner = PlayerRegistryInner {
room_registry,
players: HashMap::new(), players: HashMap::new(),
metric_active_players, metric_active_players,
}; };
Ok(PlayerRegistry(Arc::new(RwLock::new(inner)))) Ok(PlayerRegistry(RwLock::new(inner)))
} }
pub async fn get_or_create_player(&mut self, id: PlayerId) -> PlayerHandle { pub fn shutdown(self) {
let mut inner = self.0.write().unwrap(); let res = self.0.into_inner();
if let Some((handle, _)) = inner.players.get(&id) { drop(res);
handle.clone() }
#[tracing::instrument(skip(self), name = "PlayerRegistry::get_player")]
pub async fn get_player(&self, id: &PlayerId) -> Option<PlayerHandle> {
let inner = self.0.read().await;
inner.players.get(id).map(|(handle, _)| handle.clone())
}
#[tracing::instrument(skip(self), name = "PlayerRegistry::stop_player")]
pub async fn stop_player(&self, id: &PlayerId) -> Result<Option<()>> {
let mut inner = self.0.write().await;
if let Some((handle, fiber)) = inner.players.remove(id) {
if let Err(_) = handle.send(ActorCommand::Stop).await {
log::warn!("Failed to send Stop to the player actor #{id:?}. Ignoring, it is probably stopped already");
}
drop(handle);
fiber.await?;
inner.metric_active_players.dec();
Ok(Some(()))
} else { } else {
let (handle, fiber) = Player::launch(id.clone(), inner.room_registry.clone()); Ok(None)
inner.players.insert(id, (handle.clone(), fiber));
inner.metric_active_players.inc();
handle
} }
} }
pub async fn connect_to_player(&mut self, id: PlayerId) -> PlayerConnection { #[tracing::instrument(skip(self, core), name = "PlayerRegistry::get_or_launch_player")]
let player_handle = self.get_or_create_player(id).await; async fn get_or_launch_player(&self, core: &LavinaCore, id: &PlayerId) -> Result<Option<PlayerHandle>> {
player_handle.subscribe().await let inner = self.0.read().await;
if let Some((handle, _)) = inner.players.get(id) {
Ok(Some(handle.clone()))
} else {
drop(inner);
let mut inner = self.0.write().await;
if let Some((handle, _)) = inner.players.get(id) {
Ok(Some(handle.clone()))
} else {
let Some((handle, fiber)) = Player::launch(id.clone(), core.clone()).await? else {
return Ok(None);
};
inner.players.insert(id.clone(), (handle.clone(), fiber));
inner.metric_active_players.inc();
Ok(Some(handle))
}
}
} }
pub async fn shutdown_all(&mut self) -> Result<()> { #[tracing::instrument(skip(self, core), name = "PlayerRegistry::connect_to_player")]
let mut inner = self.0.write().unwrap(); pub async fn connect_to_player(&self, core: &LavinaCore, id: &PlayerId) -> Result<PlayerConnectionResult> {
for (i, (k, j)) in inner.players.drain() { let Some(player_handle) = self.get_or_launch_player(core, id).await? else {
k.send(PlayerCommand::Stop).await; return Ok(PlayerConnectionResult::PlayerNotFound);
drop(k); };
j.await?; let new_conn = player_handle.subscribe().await?;
log::debug!("Player stopped #{i:?}") Ok(PlayerConnectionResult::Success(new_conn))
}
pub async fn shutdown_all(&self) {
let mut inner = self.0.write().await;
for (id, (handle, task)) in inner.players.drain() {
let _ = handle.send(ActorCommand::Stop).await;
drop(handle);
match task.await {
Ok(_) => log::debug!("Player stopped #{id:?}"),
Err(e) => log::error!("Player #{id:?} failed to stop: {e}"),
}
} }
log::debug!("All players stopped"); log::debug!("All players stopped");
Ok(())
} }
} }
/// The player registry state representation. /// The player registry state representation.
struct PlayerRegistryInner { struct PlayerRegistryInner {
room_registry: RoomRegistry, /// Active player actors.
players: HashMap<PlayerId, (PlayerHandle, JoinHandle<Player>)>, players: HashMap<PlayerId, (PlayerHandle, JoinHandle<Player>)>,
metric_active_players: IntGauge, metric_active_players: IntGauge,
} }
enum RoomRef {
Local(RoomHandle),
Remote { node_id: u32 },
}
/// Player actor inner state representation. /// Player actor inner state representation.
struct Player { struct Player {
player_id: PlayerId, player_id: PlayerId,
connections: AnonTable<Sender<Updates>>, storage_id: u32,
my_rooms: HashMap<RoomId, RoomHandle>, connections: AnonTable<Sender<ConnectionMessage>>,
my_rooms: HashMap<RoomId, RoomRef>,
banned_from: HashSet<RoomId>, banned_from: HashSet<RoomId>,
rx: Receiver<PlayerCommand>, rx: Receiver<(ActorCommand, Span)>,
handle: PlayerHandle, handle: PlayerHandle,
rooms: RoomRegistry, services: LavinaCore,
} }
impl Player { impl Player {
fn launch(player_id: PlayerId, rooms: RoomRegistry) -> (PlayerHandle, JoinHandle<Player>) { async fn launch(player_id: PlayerId, core: LavinaCore) -> Result<Option<(PlayerHandle, JoinHandle<Player>)>> {
let (tx, rx) = channel(32); let (tx, rx) = channel(32);
let handle = PlayerHandle { tx }; let handle = PlayerHandle { tx };
let handle_clone = handle.clone(); let handle_clone = handle.clone();
let Some(storage_id) = core.services.storage.retrieve_user_id_by_name(player_id.as_inner()).await? else {
return Ok(None);
};
let player = Player { let player = Player {
player_id, player_id,
storage_id,
// connections are empty when the actor is just started
connections: AnonTable::new(), connections: AnonTable::new(),
// room handlers will be loaded later in the started task
my_rooms: HashMap::new(), my_rooms: HashMap::new(),
banned_from: HashSet::from([RoomId::from("Empty").unwrap()]), // TODO implement and load bans
banned_from: HashSet::new(),
rx, rx,
handle, handle,
rooms, services: core,
}; };
let fiber = tokio::task::spawn(player.main_loop()); let fiber = tokio::task::spawn(player.main_loop());
(handle_clone, fiber) Ok(Some((handle_clone, fiber)))
}
fn room_location(&self, room_id: &RoomId) -> Option<u32> {
let res = self.services.cluster_metadata.rooms.get(room_id.as_inner().as_ref()).copied();
let node = res.unwrap_or(self.services.cluster_metadata.main_owner);
if node == self.services.cluster_metadata.node_id {
None
} else {
Some(node)
}
} }
async fn main_loop(mut self) -> Self { async fn main_loop(mut self) -> Self {
let rooms = self.services.storage.get_rooms_of_a_user(self.storage_id).await.unwrap();
for room_id in rooms {
if let Some(remote_node) = self.room_location(&room_id) {
self.my_rooms.insert(room_id.clone(), RoomRef::Remote { node_id: remote_node });
self.services.subscribe(self.player_id.clone(), room_id).await;
} else {
let room = self.services.rooms.get_room(&self.services, &room_id).await.unwrap();
if let Some(room) = room {
room.subscribe(&self.player_id, self.handle.clone()).await;
self.my_rooms.insert(room_id, RoomRef::Local(room));
} else {
tracing::error!("Room #{room_id:?} not found");
}
}
}
while let Some(cmd) = self.rx.recv().await { while let Some(cmd) = self.rx.recv().await {
match cmd { let (cmd, span) = cmd;
PlayerCommand::AddConnection { sender, promise } => { let should_stop = async {
let connection_id = self.connections.insert(sender); match cmd {
if let Err(connection_id) = promise.send(ConnectionId(connection_id)) { ActorCommand::AddConnection { sender, promise } => {
log::warn!("Connection {connection_id:?} terminated before finalization"); let connection_id = self.connections.insert(sender);
self.terminate_connection(connection_id); if let Err(connection_id) = promise.send(ConnectionId(connection_id)) {
} log::warn!("Connection {connection_id:?} terminated before finalization");
} self.terminate_connection(connection_id);
PlayerCommand::TerminateConnection(connection_id) => {
self.terminate_connection(connection_id);
}
PlayerCommand::GetRooms(promise) => {
let mut response = vec![];
for (_, handle) in &self.my_rooms {
response.push(handle.get_room_info().await);
}
let _ = promise.send(response);
}
PlayerCommand::Update(update) => {
log::info!(
"Player received an update, broadcasting to {} connections",
self.connections.len()
);
match update {
Updates::BannedFrom(ref room_id) => {
self.banned_from.insert(room_id.clone());
self.my_rooms.remove(room_id);
} }
_ => {} false
} }
for (_, connection) in &self.connections { ActorCommand::TerminateConnection(connection_id) => {
let _ = connection.send(update.clone()).await; self.terminate_connection(connection_id);
false
} }
ActorCommand::Update(update) => {
self.handle_update(update).await;
false
}
ActorCommand::ClientCommand(cmd, connection_id) => {
self.handle_cmd(cmd, connection_id).await;
false
}
ActorCommand::Stop => true,
} }
PlayerCommand::Cmd(cmd, connection_id) => self.handle_cmd(cmd, connection_id).await, }
PlayerCommand::Stop => break, .instrument(span)
.await;
if should_stop {
break;
} }
} }
log::debug!("Shutting down player actor #{:?}", self.player_id); log::debug!("Shutting down player actor #{:?}", self.player_id);
self self
} }
/// Handle an incoming update by changing the internal state and broadcasting it to all connections if necessary.
#[tracing::instrument(skip(self, update), name = "Player::handle_update")]
async fn handle_update(&mut self, update: Updates) {
log::debug!(
"Player received an update, broadcasting to {} connections",
self.connections.len()
);
match update {
Updates::BannedFrom(ref room_id) => {
self.banned_from.insert(room_id.clone());
self.my_rooms.remove(room_id);
}
_ => {}
}
for (_, connection) in &self.connections {
let _ = connection.send(ConnectionMessage::Update(update.clone())).await;
}
}
fn terminate_connection(&mut self, connection_id: ConnectionId) { fn terminate_connection(&mut self, connection_id: ConnectionId) {
if let None = self.connections.pop(connection_id.0) { if let None = self.connections.pop(connection_id.0) {
log::warn!("Connection {connection_id:?} already terminated"); log::warn!("Connection {connection_id:?} already terminated");
} }
} }
async fn handle_cmd(&mut self, cmd: Cmd, connection_id: ConnectionId) { /// Dispatches a client command to the appropriate handler.
async fn handle_cmd(&mut self, cmd: ClientCommand, connection_id: ConnectionId) {
match cmd { match cmd {
Cmd::JoinRoom { room_id, promise } => { ClientCommand::JoinRoom { room_id, promise } => {
if self.banned_from.contains(&room_id) { let result = self.join_room(connection_id, room_id).await;
let _ = promise.send(JoinResult::Banned); let _ = promise.send(result);
return;
}
let room = match self.rooms.get_or_create_room(room_id.clone()).await {
Ok(room) => room,
Err(e) => {
log::error!("Failed to get or create room: {e}");
return;
}
};
room.subscribe(self.player_id.clone(), self.handle.clone()).await;
self.my_rooms.insert(room_id.clone(), room.clone());
let room_info = room.get_room_info().await;
let _ = promise.send(JoinResult::Success(room_info));
let update = Updates::RoomJoined {
room_id,
new_member_id: self.player_id.clone(),
};
self.broadcast_update(update, connection_id).await;
} }
Cmd::LeaveRoom { room_id, promise } => { ClientCommand::LeaveRoom { room_id, promise } => {
let room = self.my_rooms.remove(&room_id); let result = self.leave_room(connection_id, room_id).await;
if let Some(room) = room { let _ = promise.send(result);
room.unsubscribe(&self.player_id).await;
let room_info = room.get_room_info().await;
}
let _ = promise.send(());
let update = Updates::RoomLeft {
room_id,
former_member_id: self.player_id.clone(),
};
self.broadcast_update(update, connection_id).await;
} }
Cmd::SendMessage { room_id, body, promise } => { ClientCommand::SendMessage { room_id, body, promise } => {
let room = self.rooms.get_room(&room_id).await; let result = self.send_room_message(connection_id, room_id, body).await;
if let Some(room) = room { let _ = promise.send(result);
room.send_message(self.player_id.clone(), body.clone()).await;
} else {
tracing::info!("no room found");
}
let _ = promise.send(());
let update = Updates::NewMessage {
room_id,
author_id: self.player_id.clone(),
body,
};
self.broadcast_update(update, connection_id).await;
} }
Cmd::ChangeTopic { ClientCommand::ChangeRoomTopic {
room_id, room_id,
new_topic, new_topic,
promise, promise,
} => { } => {
let room = self.rooms.get_room(&room_id).await; let result = self.change_room_topic(connection_id, room_id, new_topic).await;
if let Some(mut room) = room { let _ = promise.send(result);
room.set_topic(self.player_id.clone(), new_topic.clone()).await; }
} else { ClientCommand::GetRooms { promise } => {
tracing::info!("no room found"); let result = self.get_rooms().await;
} let _ = promise.send(Ok(result));
let _ = promise.send(()); }
let update = Updates::RoomTopicChanged { room_id, new_topic }; ClientCommand::SendDialogMessage {
self.broadcast_update(update, connection_id).await; recipient,
body,
promise,
} => {
let result = self.send_dialog_message(connection_id, recipient, body).await;
let _ = promise.send(result);
}
ClientCommand::GetInfo { recipient, promise } => {
let result = self.check_user_existence(recipient).await;
let _ = promise.send(result);
}
ClientCommand::GetRoomHistory {
room_id,
limit,
promise,
} => {
let result = self.get_room_history(room_id, limit).await;
let _ = promise.send(result);
} }
} }
} }
#[tracing::instrument(skip(self), name = "Player::join_room")]
async fn join_room(&mut self, connection_id: ConnectionId, room_id: RoomId) -> Result<JoinResult> {
if self.banned_from.contains(&room_id) {
return Ok(JoinResult::Banned);
}
if self.my_rooms.contains_key(&room_id) {
return Ok(JoinResult::AlreadyJoined);
}
if let Some(remote_node) = self.room_location(&room_id) {
let req = JoinRoomReq {
room_id: room_id.as_inner(),
player_id: self.player_id.as_inner(),
};
self.services.client.join_room(remote_node, req).await?;
let room_storage_id = self.services.storage.create_or_retrieve_room_id_by_name(room_id.as_inner()).await?;
self.services.storage.add_room_member(room_storage_id, self.storage_id).await?;
self.my_rooms.insert(room_id.clone(), RoomRef::Remote { node_id: remote_node });
Ok(JoinResult::Success(RoomInfo {
id: room_id,
topic: "unknown".into(),
members: vec![],
}))
} else {
let room = match self.services.rooms.get_or_create_room(&self.services, room_id.clone()).await {
Ok(room) => room,
Err(e) => {
log::error!("Failed to get or create room: {e}");
todo!();
}
};
room.add_member(&self.services, &self.player_id, self.storage_id).await;
room.subscribe(&self.player_id, self.handle.clone()).await;
self.my_rooms.insert(room_id.clone(), RoomRef::Local(room.clone()));
let room_info = room.get_room_info().await;
let update = Updates::RoomJoined {
room_id,
new_member_id: self.player_id.clone(),
};
self.broadcast_update(update, connection_id).await;
Ok(JoinResult::Success(room_info))
}
}
#[tracing::instrument(skip(self), name = "Player::retrieve_room_history")]
async fn get_room_history(&mut self, room_id: RoomId, limit: u32) -> Result<RoomHistoryResult> {
let room = self.my_rooms.get(&room_id);
if let Some(room) = room {
match room {
RoomRef::Local(room) => {
let res = room.get_message_history(&self.services, limit).await?;
Ok(RoomHistoryResult::Success(res))
}
RoomRef::Remote { node_id: _ } => {
tracing::error!("TODO Room #{room_id:?} is remote, cannot retrieve history");
Err(anyhow!("Not implemented"))
}
}
} else {
tracing::debug!("Room #{room_id:?} not found");
Ok(RoomHistoryResult::NoSuchRoom)
}
}
#[tracing::instrument(skip(self), name = "Player::leave_room")]
async fn leave_room(&mut self, connection_id: ConnectionId, room_id: RoomId) -> Result<()> {
let room = self.my_rooms.remove(&room_id);
if let Some(room) = room {
match room {
RoomRef::Local(room) => {
room.unsubscribe(&self.player_id).await;
room.remove_member(&self.services, &self.player_id, self.storage_id).await;
}
RoomRef::Remote { node_id } => {
let req = LeaveRoomReq {
room_id: room_id.as_inner(),
player_id: self.player_id.as_inner(),
};
self.services.client.leave_room(node_id, req).await?;
let room_storage_id =
self.services.storage.create_or_retrieve_room_id_by_name(room_id.as_inner()).await?;
self.services.storage.remove_room_member(room_storage_id, self.storage_id).await?;
}
}
}
let update = Updates::RoomLeft {
room_id,
former_member_id: self.player_id.clone(),
};
self.broadcast_update(update, connection_id).await;
Ok(())
}
#[tracing::instrument(skip(self, body), name = "Player::send_room_message")]
async fn send_room_message(
&mut self,
connection_id: ConnectionId,
room_id: RoomId,
body: Str,
) -> Result<SendMessageResult> {
let Some(room) = self.my_rooms.get(&room_id) else {
tracing::info!("Room with ID {room_id:?} not found");
return Ok(SendMessageResult::NoSuchRoom);
};
let created_at = Utc::now();
match room {
RoomRef::Local(room) => {
room.send_message(&self.services, &self.player_id, body.clone(), created_at.clone()).await?;
}
RoomRef::Remote { node_id } => {
let req = SendMessageReq {
room_id: room_id.as_inner(),
player_id: self.player_id.as_inner(),
message: &*body,
created_at: &*created_at.to_rfc3339(),
};
self.services.client.send_room_message(*node_id, req).await?;
self.services
.broadcast(
room_id.clone(),
self.player_id.clone(),
body.clone(),
created_at.clone(),
)
.await;
}
}
let update = Updates::NewMessage {
room_id,
author_id: self.player_id.clone(),
body,
created_at,
};
self.broadcast_update(update, connection_id).await;
Ok(SendMessageResult::Success(created_at))
}
#[tracing::instrument(skip(self, new_topic), name = "Player::change_room_topic")]
async fn change_room_topic(
&mut self,
connection_id: ConnectionId,
room_id: RoomId,
new_topic: Str,
) -> Result<ChangeRoomTopicResult> {
let Some(room) = self.my_rooms.get(&room_id) else {
tracing::debug!("Room with ID {room_id:?} not found");
return Ok(ChangeRoomTopicResult::NoSuchRoom);
};
match room {
RoomRef::Local(room) => {
room.set_topic(&self.services, &self.player_id, new_topic.clone()).await?;
}
RoomRef::Remote { node_id } => {
let req = SetRoomTopicReq {
room_id: room_id.as_inner(),
player_id: self.player_id.as_inner(),
topic: &*new_topic,
};
self.services.client.set_room_topic(*node_id, req).await?;
}
}
let update = Updates::RoomTopicChanged { room_id, new_topic };
self.broadcast_update(update, connection_id).await;
Ok(ChangeRoomTopicResult::Success)
}
#[tracing::instrument(skip(self), name = "Player::get_rooms")]
async fn get_rooms(&self) -> Vec<RoomInfo> {
let mut response = vec![];
for (room_id, handle) in &self.my_rooms {
match handle {
RoomRef::Local(handle) => {
response.push(handle.get_room_info().await);
}
RoomRef::Remote { .. } => {
let room_info = RoomInfo {
id: room_id.clone(),
topic: "unknown".into(),
members: vec![],
};
response.push(room_info);
}
}
}
response
}
#[tracing::instrument(skip(self, body), name = "Player::send_dialog_message")]
async fn send_dialog_message(&self, connection_id: ConnectionId, recipient: PlayerId, body: Str) -> Result<()> {
let created_at = Utc::now();
self.services.send_dialog_message(self.player_id.clone(), recipient.clone(), body.clone(), &created_at).await?;
let update = Updates::NewDialogMessage {
sender: self.player_id.clone(),
receiver: recipient.clone(),
body,
created_at,
};
self.broadcast_update(update, connection_id).await;
Ok(())
}
#[tracing::instrument(skip(self), name = "Player::check_user_existence")]
async fn check_user_existence(&self, recipient: PlayerId) -> Result<GetInfoResult> {
if self.services.storage.check_user_existence(recipient.as_inner().as_ref()).await? {
Ok(GetInfoResult::UserExists)
} else {
Ok(GetInfoResult::UserDoesntExist)
}
}
/// Broadcasts an update to all connections except the one with the given id.
///
/// This is called after handling a client command.
/// Sending the update to the connection which sent the command is handled by the connection itself.
#[tracing::instrument(skip(self, update), name = "Player::broadcast_update")]
async fn broadcast_update(&self, update: Updates, except: ConnectionId) { async fn broadcast_update(&self, update: Updates, except: ConnectionId) {
for (a, b) in &self.connections { for (a, b) in &self.connections {
if ConnectionId(a) == except { if ConnectionId(a) == except {
continue; continue;
} }
let _ = b.send(update.clone()).await; let _ = b.send(ConnectionMessage::Update(update.clone())).await;
} }
} }
} }
pub enum ConnectionMessage {
Update(Updates),
Stop(StopReason),
}
#[derive(Debug)]
pub enum StopReason {
ServerShutdown,
InternalError,
}
pub enum PlayerConnectionResult {
Success(PlayerConnection),
PlayerNotFound,
}

View File

@ -0,0 +1,20 @@
use anyhow::Result;
use crate::repo::Storage;
impl Storage {
#[tracing::instrument(skip(self), name = "Storage::set_argon2_challenge")]
pub async fn set_argon2_challenge(&self, user_id: u32, hash: &str) -> Result<()> {
let mut executor = self.conn.lock().await;
sqlx::query(
"insert into challenges_argon2_password(user_id, hash)
values (?, ?)
on conflict(user_id) do update set hash = excluded.hash;",
)
.bind(user_id)
.bind(hash)
.execute(&mut *executor)
.await?;
Ok(())
}
}

View File

@ -0,0 +1,91 @@
use anyhow::{anyhow, Result};
use chrono::{DateTime, Utc};
use sqlx::FromRow;
use crate::repo::Storage;
impl Storage {
#[tracing::instrument(skip(self), name = "Storage::retrieve_dialog")]
pub async fn retrieve_dialog(&self, participant_1: &str, participant_2: &str) -> Result<Option<StoredDialog>> {
let mut executor = self.conn.lock().await;
let res = sqlx::query_as(
"select r.id, r.participant_1, r.participant_2, r.message_count
from dialogs r join users u1 on r.participant_1 = u1.id join users u2 on r.participant_2 = u2.id
where u1.name = ? and u2.name = ?;",
)
.bind(participant_1)
.bind(participant_2)
.fetch_optional(&mut *executor)
.await?;
Ok(res)
}
#[tracing::instrument(skip(self, content, created_at), name = "Storage::insert_dialog_message")]
pub async fn insert_dialog_message(
&self,
dialog_id: u32,
id: u32,
author_id: &str,
content: &str,
created_at: &DateTime<Utc>,
) -> Result<()> {
let mut executor = self.conn.lock().await;
let res: Option<(u32,)> = sqlx::query_as("select id from users where name = ?;")
.bind(author_id)
.fetch_optional(&mut *executor)
.await?;
let Some((author_id,)) = res else {
return Err(anyhow!("No such user"));
};
sqlx::query(
"insert into dialog_messages(dialog_id, id, author_id, content, created_at)
values (?, ?, ?, ?, ?);
update dialogs set message_count = message_count + 1 where id = ?;",
)
.bind(dialog_id)
.bind(id)
.bind(author_id)
.bind(content)
.bind(created_at)
.bind(dialog_id)
.execute(&mut *executor)
.await?;
Ok(())
}
#[tracing::instrument(skip(self, created_at), name = "Storage::initialize_dialog")]
pub async fn initialize_dialog(
&self,
participant_1: &str,
participant_2: &str,
created_at: &DateTime<Utc>,
) -> Result<StoredDialog> {
let mut executor = self.conn.lock().await;
let res: StoredDialog = sqlx::query_as(
"insert into dialogs(participant_1, participant_2, created_at)
values (
(select id from users where name = ?),
(select id from users where name = ?),
?
)
returning id, participant_1, participant_2, message_count;",
)
.bind(participant_1)
.bind(participant_2)
.bind(&created_at)
.fetch_one(&mut *executor)
.await?;
Ok(res)
}
}
#[derive(FromRow)]
pub struct StoredDialog {
pub id: u32,
pub participant_1: u32,
pub participant_2: u32,
pub message_count: u32,
}

View File

@ -1,24 +1,26 @@
//! Storage and persistence logic. //! Storage and persistence logic.
use std::str::FromStr; use std::str::FromStr;
use std::sync::Arc;
use anyhow::anyhow;
use serde::Deserialize; use serde::Deserialize;
use sqlx::sqlite::SqliteConnectOptions; use sqlx::sqlite::SqliteConnectOptions;
use sqlx::{ConnectOptions, Connection, FromRow, Sqlite, SqliteConnection, Transaction}; use sqlx::{ConnectOptions, Connection, SqliteConnection};
use tokio::sync::Mutex; use tokio::sync::Mutex;
use crate::prelude::*; use crate::prelude::*;
mod auth;
mod dialog;
mod room;
mod user;
#[derive(Deserialize, Debug, Clone)] #[derive(Deserialize, Debug, Clone)]
pub struct StorageConfig { pub struct StorageConfig {
pub db_path: String, pub db_path: String,
} }
#[derive(Clone)]
pub struct Storage { pub struct Storage {
conn: Arc<Mutex<SqliteConnection>>, conn: Mutex<SqliteConnection>,
} }
impl Storage { impl Storage {
pub async fn open(config: StorageConfig) -> Result<Storage> { pub async fn open(config: StorageConfig) -> Result<Storage> {
@ -30,144 +32,17 @@ impl Storage {
migrator.run(&mut conn).await?; migrator.run(&mut conn).await?;
log::info!("Migrations passed"); log::info!("Migrations passed");
let conn = Arc::new(Mutex::new(conn)); let conn = Mutex::new(conn);
Ok(Storage { conn }) Ok(Storage { conn })
} }
pub async fn retrieve_user_by_name(&mut self, name: &str) -> Result<Option<StoredUser>> { pub async fn close(self) {
let mut executor = self.conn.lock().await; let res = self.conn.into_inner();
let res = sqlx::query_as( match res.close().await {
"select u.id, u.name, c.password Ok(_) => {}
from users u left join challenges_plain_password c on u.id = c.user_id
where u.name = ?;",
)
.bind(name)
.fetch_optional(&mut *executor)
.await?;
Ok(res)
}
pub async fn retrieve_room_by_name(&mut self, name: &str) -> Result<Option<StoredRoom>> {
let mut executor = self.conn.lock().await;
let res = sqlx::query_as(
"select id, name, topic, message_count
from rooms
where name = ?;",
)
.bind(name)
.fetch_optional(&mut *executor)
.await?;
Ok(res)
}
pub async fn create_new_room(&mut self, name: &str, topic: &str) -> Result<u32> {
let mut executor = self.conn.lock().await;
let (id,): (u32,) = sqlx::query_as(
"insert into rooms(name, topic)
values (?, ?)
returning id;",
)
.bind(name)
.bind(topic)
.fetch_one(&mut *executor)
.await?;
Ok(id)
}
pub async fn insert_message(&mut self, room_id: u32, id: u32, content: &str, author_id: &str) -> Result<()> {
let mut executor = self.conn.lock().await;
let res: Option<(u32,)> = sqlx::query_as("select id from users where name = ?;")
.bind(author_id)
.fetch_optional(&mut *executor)
.await?;
let Some((author_id,)) = res else {
return Err(anyhow!("No such user"));
};
sqlx::query(
"insert into messages(room_id, id, content, author_id)
values (?, ?, ?, ?);
update rooms set message_count = message_count + 1 where id = ?;",
)
.bind(room_id)
.bind(id)
.bind(content)
.bind(author_id)
.bind(room_id)
.execute(&mut *executor)
.await?;
Ok(())
}
pub async fn close(self) -> Result<()> {
let res = match Arc::try_unwrap(self.conn) {
Ok(e) => e,
Err(_) => return Err(fail("failed to acquire DB ownership on shutdown")),
};
let res = res.into_inner();
res.close().await?;
Ok(())
}
pub async fn create_user(&mut self, name: &str) -> Result<()> {
let query = sqlx::query(
"insert into users(name)
values (?);",
)
.bind(name);
let mut executor = self.conn.lock().await;
query.execute(&mut *executor).await?;
Ok(())
}
pub async fn set_password<'a>(&'a mut self, name: &'a str, pwd: &'a str) -> Result<Option<()>> {
async fn inner(txn: &mut Transaction<'_, Sqlite>, name: &str, pwd: &str) -> Result<Option<()>> {
let id: Option<(u32,)> = sqlx::query_as("select * from users where name = ? limit 1;")
.bind(name)
.fetch_optional(&mut **txn)
.await?;
let Some((id,)) = id else {
return Ok(None);
};
sqlx::query("insert or replace into challenges_plain_password(user_id, password) values (?, ?);")
.bind(id)
.bind(pwd)
.execute(&mut **txn)
.await?;
Ok(Some(()))
}
let mut executor = self.conn.lock().await;
let mut tx = executor.begin().await?;
let res = inner(&mut tx, name, pwd).await;
match res {
Ok(e) => {
tx.commit().await?;
Ok(e)
}
Err(e) => { Err(e) => {
tx.rollback().await?; tracing::error!("Failed to close the DB connection: {e:?}");
Err(e)
} }
} }
} }
} }
#[derive(FromRow)]
pub struct StoredUser {
pub id: u32,
pub name: String,
pub password: Option<String>,
}
#[derive(FromRow)]
pub struct StoredRoom {
pub id: u32,
pub name: String,
pub topic: String,
pub message_count: u32,
}

View File

@ -0,0 +1,230 @@
use anyhow::{anyhow, Result};
use chrono::{DateTime, Utc};
use sqlx::FromRow;
use crate::player::PlayerId;
use crate::repo::Storage;
use crate::room::{RoomId, StoredMessage};
#[derive(FromRow)]
pub struct StoredRoom {
pub id: u32,
pub name: String,
pub topic: String,
pub message_count: u32,
}
impl Storage {
#[tracing::instrument(skip(self), name = "Storage::retrieve_room_by_name")]
pub async fn retrieve_room_by_name(&self, name: &str) -> Result<Option<StoredRoom>> {
let mut executor = self.conn.lock().await;
let res = sqlx::query_as(
"select id, name, topic, message_count
from rooms
where name = ?;",
)
.bind(name)
.fetch_optional(&mut *executor)
.await?;
Ok(res)
}
#[tracing::instrument(skip(self), name = "Storage::retrieve_room_message_history")]
pub async fn get_room_message_history(&self, room_id: u32, limit: u32) -> Result<Vec<StoredMessage>> {
let mut executor = self.conn.lock().await;
let res = sqlx::query_as(
"
select
*
from (
select
messages.id as id,
content,
created_at,
users.name as author_name
from
messages
join
users
on messages.author_id = users.id
where
room_id = ?
order by
messages.id desc
limit ?
)
order by
id asc;
",
)
.bind(room_id)
.bind(limit)
.fetch_all(&mut *executor)
.await?;
Ok(res)
}
#[tracing::instrument(skip(self, topic), name = "Storage::create_new_room")]
pub async fn create_new_room(&self, name: &str, topic: &str) -> Result<u32> {
let mut executor = self.conn.lock().await;
let (id,): (u32,) = sqlx::query_as(
"insert into rooms(name, topic)
values (?, ?)
returning id;",
)
.bind(name)
.bind(topic)
.fetch_one(&mut *executor)
.await?;
Ok(id)
}
#[tracing::instrument(skip(self, content, created_at), name = "Storage::insert_room_message")]
pub async fn insert_room_message(
&self,
room_id: u32,
id: u32,
content: &str,
author_id: &str,
created_at: &DateTime<Utc>,
) -> Result<()> {
let mut executor = self.conn.lock().await;
let res: Option<(u32,)> = sqlx::query_as("select id from users where name = ?;")
.bind(author_id)
.fetch_optional(&mut *executor)
.await?;
let Some((author_id,)) = res else {
return Err(anyhow!("No such user"));
};
sqlx::query(
"insert into messages(room_id, id, content, author_id, created_at)
values (?, ?, ?, ?, ?);
update rooms set message_count = message_count + 1 where id = ?;",
)
.bind(room_id)
.bind(id)
.bind(content)
.bind(author_id)
.bind(created_at)
.bind(room_id)
.execute(&mut *executor)
.await?;
Ok(())
}
#[tracing::instrument(skip(self), name = "Storage::is_room_member")]
pub async fn is_room_member(&self, room_id: u32, player_id: u32) -> Result<bool> {
let mut executor = self.conn.lock().await;
let res: (u32,) = sqlx::query_as(
"
select
count(*)
from
memberships
where
user_id = ? and room_id = ?;
",
)
.bind(player_id)
.bind(room_id)
.fetch_one(&mut *executor)
.await?;
Ok(res.0 > 0)
}
#[tracing::instrument(skip(self), name = "Storage::add_room_member")]
pub async fn add_room_member(&self, room_id: u32, player_id: u32) -> Result<()> {
let mut executor = self.conn.lock().await;
sqlx::query(
"insert into memberships(user_id, room_id, status)
values (?, ?, 1);",
)
.bind(player_id)
.bind(room_id)
.execute(&mut *executor)
.await?;
Ok(())
}
#[tracing::instrument(skip(self), name = "Storage::remove_room_member")]
pub async fn remove_room_member(&self, room_id: u32, player_id: u32) -> Result<()> {
let mut executor = self.conn.lock().await;
sqlx::query(
"delete from memberships
where user_id = ? and room_id = ?;",
)
.bind(player_id)
.bind(room_id)
.execute(&mut *executor)
.await?;
Ok(())
}
#[tracing::instrument(skip(self, topic), name = "Storage::set_room_topic")]
pub async fn set_room_topic(&self, id: u32, topic: &str) -> Result<()> {
let mut executor = self.conn.lock().await;
sqlx::query(
"update rooms
set topic = ?
where id = ?;",
)
.bind(topic)
.bind(id)
.fetch_optional(&mut *executor)
.await?;
Ok(())
}
#[tracing::instrument(skip(self), name = "Storage::create_or_retrieve_room_id_by_name")]
pub async fn create_or_retrieve_room_id_by_name(&self, name: &str) -> Result<u32> {
// TODO we don't need any info except the name on non-owning nodes, should remove stubs here
let mut executor = self.conn.lock().await;
let res: (u32,) = sqlx::query_as(
"insert into rooms(name, topic)
values (?, '')
on conflict(name) do update set name = excluded.name
returning id;",
)
.bind(name)
.fetch_one(&mut *executor)
.await?;
Ok(res.0)
}
#[tracing::instrument(skip(self), name = "Storage::get_rooms_of_a_user")]
pub async fn get_rooms_of_a_user(&self, user_id: u32) -> Result<Vec<RoomId>> {
let mut executor = self.conn.lock().await;
let res: Vec<(String,)> = sqlx::query_as(
"select r.name
from memberships m inner join rooms r on m.room_id = r.id
where m.user_id = ?;",
)
.bind(user_id)
.fetch_all(&mut *executor)
.await?;
res.into_iter().map(|(room_id,)| RoomId::try_from(room_id)).collect()
}
pub async fn get_users_of_a_room(&self, room_id: u32) -> Result<Vec<PlayerId>> {
let mut executor = self.conn.lock().await;
let res: Vec<(String,)> = sqlx::query_as(
"select u.name
from memberships m inner join users u on m.user_id = u.id
where m.room_id = ?;",
)
.bind(room_id)
.fetch_all(&mut *executor)
.await?;
Ok(res.into_iter()).and_then(|iter| iter.map(|(user_id,)| PlayerId::from(user_id)).collect())
}
}

View File

@ -0,0 +1,114 @@
use anyhow::Result;
use sqlx::{Connection, FromRow, Sqlite, Transaction};
use crate::repo::Storage;
#[derive(FromRow)]
pub struct StoredUser {
pub id: u32,
pub name: String,
pub password: Option<String>,
pub argon2_hash: Option<Box<str>>,
}
impl Storage {
#[tracing::instrument(skip(self), name = "Storage::retrieve_user_by_name")]
pub async fn retrieve_user_by_name(&self, name: &str) -> Result<Option<StoredUser>> {
let mut executor = self.conn.lock().await;
let res = sqlx::query_as(
"select u.id, u.name, c.password, a.hash as argon2_hash
from users u left join challenges_plain_password c on u.id = c.user_id
left join challenges_argon2_password a on u.id = a.user_id
where u.name = ?;",
)
.bind(name)
.fetch_optional(&mut *executor)
.await?;
Ok(res)
}
#[tracing::instrument(skip(self), name = "Storage::check_user_existence")]
pub async fn check_user_existence(&self, username: &str) -> Result<bool> {
let mut executor = self.conn.lock().await;
let result: Option<(String,)> = sqlx::query_as("select name from users where name = ?;")
.bind(username)
.fetch_optional(&mut *executor)
.await?;
Ok(result.is_some())
}
#[tracing::instrument(skip(self), name = "Storage::create_user")]
pub async fn create_user(&self, name: &str) -> Result<()> {
let query = sqlx::query(
"insert into users(name)
values (?);",
)
.bind(name);
let mut executor = self.conn.lock().await;
query.execute(&mut *executor).await?;
Ok(())
}
#[tracing::instrument(skip(self, pwd), name = "Storage::set_password")]
pub async fn set_password(&self, name: &str, pwd: &str) -> Result<Option<()>> {
async fn inner(txn: &mut Transaction<'_, Sqlite>, name: &str, pwd: &str) -> Result<Option<()>> {
let id: Option<(u32,)> = sqlx::query_as("select * from users where name = ? limit 1;")
.bind(name)
.fetch_optional(&mut **txn)
.await?;
let Some((id,)) = id else {
return Ok(None);
};
sqlx::query("insert or replace into challenges_plain_password(user_id, password) values (?, ?);")
.bind(id)
.bind(pwd)
.execute(&mut **txn)
.await?;
Ok(Some(()))
}
let mut executor = self.conn.lock().await;
let mut tx = executor.begin().await?;
let res = inner(&mut tx, name, pwd).await;
match res {
Ok(e) => {
tx.commit().await?;
Ok(e)
}
Err(e) => {
tx.rollback().await?;
Err(e)
}
}
}
#[tracing::instrument(skip(self), name = "Storage::retrieve_user_id_by_name")]
pub async fn retrieve_user_id_by_name(&self, name: &str) -> Result<Option<u32>> {
let mut executor = self.conn.lock().await;
let res: Option<(u32,)> = sqlx::query_as("select u.id from users u where u.name = ?;")
.bind(name)
.fetch_optional(&mut *executor)
.await?;
Ok(res.map(|(id,)| id))
}
#[tracing::instrument(skip(self), name = "Storage::create_or_retrieve_user_id_by_name")]
pub async fn create_or_retrieve_user_id_by_name(&self, name: &str) -> Result<u32> {
let mut executor = self.conn.lock().await;
let res: (u32,) = sqlx::query_as(
"insert into users(name)
values (?)
on conflict(name) do update set name = excluded.name
returning id;",
)
.bind(name)
.fetch_one(&mut *executor)
.await?;
Ok(res.0)
}
}

View File

@ -1,19 +1,24 @@
//! Domain of rooms — chats with multiple participants. //! Domain of rooms — chats with multiple participants.
use std::collections::HashSet;
use std::{collections::HashMap, hash::Hash, sync::Arc}; use std::{collections::HashMap, hash::Hash, sync::Arc};
use chrono::{DateTime, Utc};
use prometheus::{IntGauge, Registry as MetricRegistry}; use prometheus::{IntGauge, Registry as MetricRegistry};
use serde::Serialize; use serde::Serialize;
use sqlx::sqlite::SqliteRow;
use sqlx::{FromRow, Row};
use tokio::sync::RwLock as AsyncRwLock; use tokio::sync::RwLock as AsyncRwLock;
use crate::player::{PlayerHandle, PlayerId, Updates}; use crate::player::{PlayerHandle, PlayerId, Updates};
use crate::prelude::*; use crate::prelude::*;
use crate::repo::Storage; use crate::{LavinaCore, Services};
/// Opaque room id /// Opaque room id
#[derive(Debug, Clone, PartialEq, Eq, Hash, Serialize)] #[derive(Debug, Clone, PartialEq, Eq, Hash, Serialize)]
pub struct RoomId(Str); pub struct RoomId(Str);
impl RoomId { impl RoomId {
pub fn from(str: impl Into<Str>) -> Result<RoomId> { pub fn try_from(str: impl Into<Str>) -> Result<RoomId> {
let bytes = str.into(); let bytes = str.into();
if bytes.len() > 32 { if bytes.len() > 32 {
return Err(anyhow::Error::msg("Room name cannot be longer than 32 symbols")); return Err(anyhow::Error::msg("Room name cannot be longer than 32 symbols"));
@ -31,54 +36,40 @@ impl RoomId {
} }
} }
/// Shared datastructure for storing metadata about rooms. /// Shared data structure for storing metadata about rooms.
#[derive(Clone)] pub(crate) struct RoomRegistry(AsyncRwLock<RoomRegistryInner>);
pub struct RoomRegistry(Arc<AsyncRwLock<RoomRegistryInner>>);
impl RoomRegistry { impl RoomRegistry {
pub fn new(metrics: &mut MetricRegistry, storage: Storage) -> Result<RoomRegistry> { pub fn new(metrics: &mut MetricRegistry) -> Result<RoomRegistry> {
let metric_active_rooms = IntGauge::new("chat_rooms_active", "Number of alive room actors")?; let metric_active_rooms = IntGauge::new("chat_rooms_active", "Number of alive room actors")?;
metrics.register(Box::new(metric_active_rooms.clone()))?; metrics.register(Box::new(metric_active_rooms.clone()))?;
let inner = RoomRegistryInner { let inner = RoomRegistryInner {
rooms: HashMap::new(), rooms: HashMap::new(),
metric_active_rooms, metric_active_rooms,
storage,
}; };
Ok(RoomRegistry(Arc::new(AsyncRwLock::new(inner)))) Ok(RoomRegistry(AsyncRwLock::new(inner)))
} }
pub async fn get_or_create_room(&mut self, room_id: RoomId) -> Result<RoomHandle> { pub fn shutdown(self) {
// TODO iterate over rooms and stop them
}
#[tracing::instrument(skip(self, services), name = "RoomRegistry::get_or_create_room")]
pub async fn get_or_create_room(&self, services: &Services, room_id: RoomId) -> Result<RoomHandle> {
let mut inner = self.0.write().await; let mut inner = self.0.write().await;
if let Some(room_handle) = inner.rooms.get(&room_id) { if let Some(room_handle) = inner.get_or_load_room(services, &room_id).await? {
// room was already loaded into memory
log::debug!("Room {} was loaded already", &room_id.0);
Ok(room_handle.clone()) Ok(room_handle.clone())
} else if let Some(stored_room) = inner.storage.retrieve_room_by_name(&*room_id.0).await? {
// room exists, but was not loaded
log::debug!("Loading room {}...", &room_id.0);
let room = Room {
storage_id: stored_room.id,
room_id: room_id.clone(),
subscriptions: HashMap::new(), // TODO figure out how to populate subscriptions
topic: stored_room.topic.into(),
message_count: stored_room.message_count,
storage: inner.storage.clone(),
};
let room_handle = RoomHandle(Arc::new(AsyncRwLock::new(room)));
inner.rooms.insert(room_id, room_handle.clone());
inner.metric_active_rooms.inc();
Ok(room_handle)
} else { } else {
// room does not exist, create it and load
log::debug!("Creating room {}...", &room_id.0); log::debug!("Creating room {}...", &room_id.0);
let topic = "New room"; let topic = "New room";
let id = inner.storage.create_new_room(&*room_id.0, &*topic).await?; let id = services.storage.create_new_room(&*room_id.0, &*topic).await?;
let room = Room { let room = Room {
storage_id: id, storage_id: id,
room_id: room_id.clone(), room_id: room_id.clone(),
subscriptions: HashMap::new(), subscriptions: HashMap::new(),
members: HashSet::new(),
topic: topic.into(), topic: topic.into(),
message_count: 0, message_count: 0,
storage: inner.storage.clone(),
}; };
let room_handle = RoomHandle(Arc::new(AsyncRwLock::new(room))); let room_handle = RoomHandle(Arc::new(AsyncRwLock::new(room)));
inner.rooms.insert(room_id, room_handle.clone()); inner.rooms.insert(room_id, room_handle.clone());
@ -87,12 +78,13 @@ impl RoomRegistry {
} }
} }
pub async fn get_room(&self, room_id: &RoomId) -> Option<RoomHandle> { #[tracing::instrument(skip(self, services), name = "RoomRegistry::get_room")]
let inner = self.0.read().await; pub async fn get_room(&self, services: &Services, room_id: &RoomId) -> Result<Option<RoomHandle>> {
let res = inner.rooms.get(room_id); let mut inner = self.0.write().await;
res.map(|r| r.clone()) inner.get_or_load_room(services, room_id).await
} }
#[tracing::instrument(skip(self), name = "RoomRegistry::get_all_rooms")]
pub async fn get_all_rooms(&self) -> Vec<RoomInfo> { pub async fn get_all_rooms(&self) -> Vec<RoomInfo> {
let handles = { let handles = {
let inner = self.0.read().await; let inner = self.0.read().await;
@ -110,20 +102,82 @@ impl RoomRegistry {
struct RoomRegistryInner { struct RoomRegistryInner {
rooms: HashMap<RoomId, RoomHandle>, rooms: HashMap<RoomId, RoomHandle>,
metric_active_rooms: IntGauge, metric_active_rooms: IntGauge,
storage: Storage, }
impl RoomRegistryInner {
#[tracing::instrument(skip(self, services), name = "RoomRegistryInner::get_or_load_room")]
async fn get_or_load_room(&mut self, services: &Services, room_id: &RoomId) -> Result<Option<RoomHandle>> {
if let Some(room_handle) = self.rooms.get(room_id) {
log::debug!("Room {} was loaded already", &room_id.0);
Ok(Some(room_handle.clone()))
} else if let Some(stored_room) = services.storage.retrieve_room_by_name(&*room_id.0).await? {
log::debug!("Loading room {}...", &room_id.0);
let users = services.storage.get_users_of_a_room(stored_room.id).await?;
let room = Room {
storage_id: stored_room.id,
room_id: room_id.clone(),
subscriptions: HashMap::new(),
members: HashSet::from_iter(users.into_iter()),
topic: stored_room.topic.into(),
message_count: stored_room.message_count,
};
let room_handle = RoomHandle(Arc::new(AsyncRwLock::new(room)));
self.rooms.insert(room_id.clone(), room_handle.clone());
self.metric_active_rooms.inc();
Ok(Some(room_handle))
} else {
tracing::debug!("Room {} does not exist", &room_id.0);
Ok(None)
}
}
} }
#[derive(Clone)] #[derive(Clone)]
pub struct RoomHandle(Arc<AsyncRwLock<Room>>); pub struct RoomHandle(Arc<AsyncRwLock<Room>>);
impl RoomHandle { impl RoomHandle {
pub async fn subscribe(&self, player_id: PlayerId, player_handle: PlayerHandle) { #[tracing::instrument(skip(self, player_handle), name = "RoomHandle::subscribe")]
pub async fn subscribe(&self, player_id: &PlayerId, player_handle: PlayerHandle) {
let mut lock = self.0.write().await; let mut lock = self.0.write().await;
lock.add_subscriber(player_id, player_handle).await; tracing::info!("Adding a subscriber to a room");
lock.subscriptions.insert(player_id.clone(), player_handle);
} }
#[tracing::instrument(skip(self, services), name = "RoomHandle::add_member")]
pub async fn add_member(&self, services: &Services, player_id: &PlayerId, player_storage_id: u32) {
let mut lock = self.0.write().await;
tracing::info!("Adding a new member to a room");
let room_storage_id = lock.storage_id;
if !services.storage.is_room_member(room_storage_id, player_storage_id).await.unwrap() {
services.storage.add_room_member(room_storage_id, player_storage_id).await.unwrap();
} else {
tracing::warn!("User {:#?} has already been added to the room.", player_id);
}
lock.members.insert(player_id.clone());
let update = Updates::RoomJoined {
room_id: lock.room_id.clone(),
new_member_id: player_id.clone(),
};
lock.broadcast_update(update, player_id).await;
}
pub async fn get_message_history(&self, services: &Services, limit: u32) -> Result<Vec<StoredMessage>> {
services.storage.get_room_message_history(self.0.read().await.storage_id, limit).await
}
#[tracing::instrument(skip(self), name = "RoomHandle::unsubscribe")]
pub async fn unsubscribe(&self, player_id: &PlayerId) { pub async fn unsubscribe(&self, player_id: &PlayerId) {
let mut lock = self.0.write().await; let mut lock = self.0.write().await;
lock.subscriptions.remove(player_id); lock.subscriptions.remove(player_id);
}
#[tracing::instrument(skip(self, services), name = "RoomHandle::remove_member")]
pub async fn remove_member(&self, services: &Services, player_id: &PlayerId, player_storage_id: u32) {
let mut lock = self.0.write().await;
tracing::info!("Removing a member from a room");
let room_storage_id = lock.storage_id;
services.storage.remove_room_member(room_storage_id, player_storage_id).await.unwrap();
lock.members.remove(player_id);
let update = Updates::RoomLeft { let update = Updates::RoomLeft {
room_id: lock.room_id.clone(), room_id: lock.room_id.clone(),
former_member_id: player_id.clone(), former_member_id: player_id.clone(),
@ -131,68 +185,97 @@ impl RoomHandle {
lock.broadcast_update(update, player_id).await; lock.broadcast_update(update, player_id).await;
} }
pub async fn send_message(&self, player_id: PlayerId, body: Str) { #[tracing::instrument(skip(self, services, body, created_at), name = "RoomHandle::send_message")]
pub async fn send_message(
&self,
services: &Services,
player_id: &PlayerId,
body: Str,
created_at: DateTime<Utc>,
) -> Result<()> {
let mut lock = self.0.write().await; let mut lock = self.0.write().await;
let res = lock.send_message(player_id, body).await; let res = lock.send_message(services, player_id, body, created_at).await;
if let Err(err) = res { if let Err(err) = &res {
log::warn!("Failed to send message: {err:?}"); tracing::error!("Failed to send message: {err:?}");
} }
res
} }
#[tracing::instrument(skip(self), name = "RoomHandle::get_room_info")]
pub async fn get_room_info(&self) -> RoomInfo { pub async fn get_room_info(&self) -> RoomInfo {
let lock = self.0.read().await; let lock = self.0.read().await;
RoomInfo { RoomInfo {
id: lock.room_id.clone(), id: lock.room_id.clone(),
members: lock.subscriptions.keys().map(|x| x.clone()).collect::<Vec<_>>(), members: lock.members.iter().map(|x| x.clone()).collect::<Vec<_>>(),
topic: lock.topic.clone(), topic: lock.topic.clone(),
} }
} }
pub async fn set_topic(&mut self, changer_id: PlayerId, new_topic: Str) { #[tracing::instrument(skip(self, services, new_topic), name = "RoomHandle::set_topic")]
pub async fn set_topic(&self, services: &Services, changer_id: &PlayerId, new_topic: Str) -> Result<()> {
let mut lock = self.0.write().await; let mut lock = self.0.write().await;
let storage_id = lock.storage_id;
lock.topic = new_topic.clone(); lock.topic = new_topic.clone();
services.storage.set_room_topic(storage_id, &new_topic).await?;
let update = Updates::RoomTopicChanged { let update = Updates::RoomTopicChanged {
room_id: lock.room_id.clone(), room_id: lock.room_id.clone(),
new_topic: new_topic.clone(), new_topic: new_topic.clone(),
}; };
lock.broadcast_update(update, &changer_id).await; lock.broadcast_update(update, changer_id).await;
Ok(())
} }
} }
struct Room { struct Room {
/// The numeric node-local id of the room as it is stored in the database.
storage_id: u32, storage_id: u32,
/// The cluster-global id of the room.
room_id: RoomId, room_id: RoomId,
/// Player actors on the local node which are subscribed to this room's updates.
subscriptions: HashMap<PlayerId, PlayerHandle>, subscriptions: HashMap<PlayerId, PlayerHandle>,
/// Members of the room.
members: HashSet<PlayerId>,
/// The total number of messages. Used to calculate the id of the new message.
message_count: u32, message_count: u32,
topic: Str, topic: Str,
storage: Storage,
} }
impl Room {
async fn add_subscriber(&mut self, player_id: PlayerId, player_handle: PlayerHandle) {
tracing::info!("Adding a subscriber to room");
self.subscriptions.insert(player_id.clone(), player_handle);
let update = Updates::RoomJoined {
room_id: self.room_id.clone(),
new_member_id: player_id.clone(),
};
self.broadcast_update(update, &player_id).await;
}
async fn send_message(&mut self, author_id: PlayerId, body: Str) -> Result<()> { impl Room {
#[tracing::instrument(skip(self, services, body, created_at), name = "Room::send_message")]
async fn send_message(
&mut self,
services: &Services,
author_id: &PlayerId,
body: Str,
created_at: DateTime<Utc>,
) -> Result<()> {
tracing::info!("Adding a message to room"); tracing::info!("Adding a message to room");
self.storage services
.insert_message(self.storage_id, self.message_count, &body, &*author_id.as_inner()) .storage
.insert_room_message(
self.storage_id,
self.message_count,
&body,
&*author_id.as_inner(),
&created_at,
)
.await?; .await?;
self.message_count += 1; self.message_count += 1;
let update = Updates::NewMessage { let update = Updates::NewMessage {
room_id: self.room_id.clone(), room_id: self.room_id.clone(),
author_id: author_id.clone(), author_id: author_id.clone(),
body, body,
created_at,
}; };
self.broadcast_update(update, &author_id).await; self.broadcast_update(update, author_id).await;
Ok(()) Ok(())
} }
/// Broadcasts an update to all players except the one who caused the update.
///
/// This is called after handling a client command.
/// Sending the update to the player who sent the command is handled by the player actor.
#[tracing::instrument(skip(self, update), name = "Room::broadcast_update")]
async fn broadcast_update(&self, update: Updates, except: &PlayerId) { async fn broadcast_update(&self, update: Updates, except: &PlayerId) {
tracing::debug!("Broadcasting an update to {} subs", self.subscriptions.len()); tracing::debug!("Broadcasting an update to {} subs", self.subscriptions.len());
for (player_id, sub) in &self.subscriptions { for (player_id, sub) in &self.subscriptions {
@ -211,3 +294,11 @@ pub struct RoomInfo {
pub members: Vec<PlayerId>, pub members: Vec<PlayerId>,
pub topic: Str, pub topic: Str,
} }
#[derive(Debug, FromRow)]
pub struct StoredMessage {
pub id: u32,
pub author_name: String,
pub content: String,
pub created_at: DateTime<Utc>,
}

View File

@ -1,5 +1,7 @@
use serde::{Deserialize, Serialize}; use serde::{Deserialize, Serialize};
pub mod rooms;
#[derive(Serialize, Deserialize)] #[derive(Serialize, Deserialize)]
pub struct ErrorResponse<'a> { pub struct ErrorResponse<'a> {
pub code: &'a str, pub code: &'a str,
@ -11,6 +13,11 @@ pub struct CreatePlayerRequest<'a> {
pub name: &'a str, pub name: &'a str,
} }
#[derive(Serialize, Deserialize)]
pub struct StopPlayerRequest<'a> {
pub name: &'a str,
}
#[derive(Serialize, Deserialize)] #[derive(Serialize, Deserialize)]
pub struct ChangePasswordRequest<'a> { pub struct ChangePasswordRequest<'a> {
pub player_name: &'a str, pub player_name: &'a str,
@ -19,6 +26,7 @@ pub struct ChangePasswordRequest<'a> {
pub mod paths { pub mod paths {
pub const CREATE_PLAYER: &'static str = "/mgmt/create_player"; pub const CREATE_PLAYER: &'static str = "/mgmt/create_player";
pub const STOP_PLAYER: &'static str = "/mgmt/stop_player";
pub const SET_PASSWORD: &'static str = "/mgmt/set_password"; pub const SET_PASSWORD: &'static str = "/mgmt/set_password";
} }

View File

@ -0,0 +1,24 @@
use serde::{Deserialize, Serialize};
#[derive(Serialize, Deserialize)]
pub struct SendMessageReq<'a> {
pub room_id: &'a str,
pub author_id: &'a str,
pub message: &'a str,
}
#[derive(Serialize, Deserialize)]
pub struct SetTopicReq<'a> {
pub room_id: &'a str,
pub author_id: &'a str,
pub topic: &'a str,
}
pub mod paths {
pub const SEND_MESSAGE: &'static str = "/mgmt/rooms/send_message";
pub const SET_TOPIC: &'static str = "/mgmt/rooms/set_topic";
}
pub mod errors {
pub const ROOM_NOT_FOUND: &'static str = "room_not_found";
}

View File

@ -12,6 +12,7 @@ tokio.workspace = true
prometheus.workspace = true prometheus.workspace = true
futures-util.workspace = true futures-util.workspace = true
nonempty.workspace = true nonempty.workspace = true
chrono.workspace = true
bitflags = "2.4.1" bitflags = "2.4.1"
proto-irc = { path = "../proto-irc" } proto-irc = { path = "../proto-irc" }
sasl = { path = "../sasl" } sasl = { path = "../sasl" }

View File

@ -1,2 +0,0 @@
max_width = 120
chain_width = 120

View File

@ -1,9 +1,10 @@
use bitflags::bitflags; use bitflags::bitflags;
bitflags! { bitflags! {
#[derive(Debug)] #[derive(Debug, Clone, Copy)]
pub struct Capabilities: u32 { pub struct Capabilities: u32 {
const None = 0;
const Sasl = 1 << 0; const Sasl = 1 << 0;
const ServerTime = 1 << 1;
const ChatHistory = 1 << 2;
} }
} }

View File

@ -0,0 +1,20 @@
use std::future::Future;
use anyhow::Result;
use tokio::io::AsyncWrite;
use crate::RegisteredUser;
use lavina_core::player::PlayerConnection;
use lavina_core::prelude::Str;
pub struct IrcConnection<'a, T: AsyncWrite + Unpin> {
pub server_name: Str,
pub writer: &'a mut T,
pub player_connection: &'a mut PlayerConnection,
pub user: &'a RegisteredUser,
}
/// Represents a client-to-server IRC message that can be handled by the server.
pub trait IrcCommand {
fn handle_with(&self, conn: &mut IrcConnection<impl AsyncWrite + Unpin>) -> impl Future<Output = Result<()>>;
}

View File

@ -0,0 +1,57 @@
use anyhow::Result;
use chrono::SecondsFormat;
use tokio::io::{AsyncWrite, AsyncWriteExt};
use crate::cap::Capabilities;
use crate::handler::{IrcCommand, IrcConnection};
use lavina_core::player::RoomHistoryResult;
use lavina_core::room::RoomId;
use proto_irc::client::ChatHistory;
use proto_irc::server::{ServerMessage, ServerMessageBody};
use proto_irc::{Chan, Recipient, Tag};
impl IrcCommand for ChatHistory {
async fn handle_with(&self, conn: &mut IrcConnection<'_, impl AsyncWrite + Unpin>) -> Result<()> {
if !conn.user.enabled_capabilities.contains(Capabilities::ChatHistory) {
tracing::debug!(
"Requested chat history for user {:?} even though the capability was not negotiated",
conn.user.nickname
);
return Ok(());
}
let channel_name = match self.chan.clone() {
Chan::Global(chan) => chan,
// TODO Respond with an error when a local channel is requested
Chan::Local(chan) => chan,
};
let room_id = &RoomId::try_from(channel_name.clone())?;
let res = conn.player_connection.get_room_message_history(room_id, self.limit).await?;
match res {
RoomHistoryResult::Success(messages) => {
for message in messages {
let mut tags = vec![];
if conn.user.enabled_capabilities.contains(Capabilities::ServerTime) {
let tag = Tag {
key: "time".into(),
value: Some(message.created_at.to_rfc3339_opts(SecondsFormat::Millis, true).into()),
};
tags.push(tag);
}
ServerMessage {
tags,
sender: Some(message.author_name.into()),
body: ServerMessageBody::PrivateMessage {
target: Recipient::Chan(self.chan.clone()),
body: message.content.into(),
},
}
.write_async(conn.writer)
.await?;
}
conn.writer.flush().await?;
}
RoomHistoryResult::NoSuchRoom => {}
}
Ok(())
}
}

File diff suppressed because it is too large Load Diff

View File

@ -0,0 +1,51 @@
use anyhow::Result;
use tokio::io::AsyncWrite;
use crate::handler::{IrcCommand, IrcConnection};
use lavina_core::player::GetInfoResult;
use lavina_core::player::PlayerId;
use lavina_core::prelude::Str;
use proto_irc::client::Whois;
use proto_irc::commands::whois::error::{ErrNoNicknameGiven431, ErrNoSuchNick401};
use proto_irc::commands::whois::response::RplEndOfWhois318;
use proto_irc::response::IrcResponseMessage;
use proto_irc::response::WriteResponse;
impl IrcCommand for Whois {
async fn handle_with(&self, conn: &mut IrcConnection<'_, impl AsyncWrite + Unpin>) -> Result<()> {
match self {
Whois::Nick(nick) => handle_nick_target(nick.clone(), conn).await?,
Whois::TargetNick(_, nick) => handle_nick_target(nick.clone(), conn).await?,
Whois::EmptyArgs => {
IrcResponseMessage::empty_tags(
Some(conn.server_name.clone()),
ErrNoNicknameGiven431::new(conn.server_name.clone()),
)
.write_response(conn.writer)
.await?
}
}
Ok(())
}
}
async fn handle_nick_target(nick: Str, conn: &mut IrcConnection<'_, impl AsyncWrite + Unpin>) -> Result<()> {
match conn.player_connection.check_user_existence(PlayerId::from(nick.clone())?).await? {
GetInfoResult::UserExists => {}
GetInfoResult::UserDoesntExist => {
IrcResponseMessage::empty_tags(
Some(conn.server_name.clone()),
ErrNoSuchNick401::new(conn.user.nickname.clone(), nick.clone()),
)
.write_response(conn.writer)
.await?
}
}
IrcResponseMessage::empty_tags(
Some(conn.server_name.clone()),
RplEndOfWhois318::new(conn.user.nickname.clone(), nick.clone()),
)
.write_response(conn.writer)
.await?;
Ok(())
}

View File

@ -1,20 +1,27 @@
use std::io::ErrorKind;
use std::time::Duration; use std::time::Duration;
use anyhow::{anyhow, Result}; use anyhow::{anyhow, Result};
use chrono::{DateTime, SecondsFormat};
use prometheus::Registry as MetricsRegistry; use prometheus::Registry as MetricsRegistry;
use tokio::io::{AsyncReadExt, AsyncWriteExt, BufReader}; use tokio::io::{AsyncReadExt, AsyncWriteExt, BufReader};
use tokio::net::tcp::{ReadHalf, WriteHalf}; use tokio::net::tcp::{ReadHalf, WriteHalf};
use tokio::net::TcpStream; use tokio::net::TcpStream;
use lavina_core::clustering::{ClusterConfig, ClusterMetadata};
use lavina_core::player::{JoinResult, PlayerConnectionResult, PlayerId, SendMessageResult};
use lavina_core::repo::{Storage, StorageConfig}; use lavina_core::repo::{Storage, StorageConfig};
use lavina_core::{player::PlayerRegistry, room::RoomRegistry}; use lavina_core::room::RoomId;
use lavina_core::LavinaCore;
use projection_irc::APP_VERSION;
use projection_irc::{launch, read_irc_message, RunningServer, ServerConfig}; use projection_irc::{launch, read_irc_message, RunningServer, ServerConfig};
struct TestScope<'a> { struct TestScope<'a> {
reader: BufReader<ReadHalf<'a>>, reader: BufReader<ReadHalf<'a>>,
writer: WriteHalf<'a>, writer: WriteHalf<'a>,
buffer: Vec<u8>, buffer: Vec<u8>,
pub timeout: Duration, pub timeout_optimistic: Duration,
pub timeout_pessimistic: Duration,
} }
impl<'a> TestScope<'a> { impl<'a> TestScope<'a> {
@ -22,12 +29,14 @@ impl<'a> TestScope<'a> {
let (reader, writer) = stream.split(); let (reader, writer) = stream.split();
let reader = BufReader::new(reader); let reader = BufReader::new(reader);
let buffer = vec![]; let buffer = vec![];
let timeout = Duration::from_millis(100); let timeout_optimistic = Duration::from_millis(50);
let timeout_pessimistic = Duration::from_millis(2000);
TestScope { TestScope {
reader, reader,
writer, writer,
buffer, buffer,
timeout, timeout_optimistic,
timeout_pessimistic,
} }
} }
@ -40,15 +49,48 @@ impl<'a> TestScope<'a> {
async fn expect(&mut self, str: &str) -> Result<()> { async fn expect(&mut self, str: &str) -> Result<()> {
tracing::debug!("Expecting {}", str); tracing::debug!("Expecting {}", str);
let len = tokio::time::timeout(self.timeout, read_irc_message(&mut self.reader, &mut self.buffer)).await??; let len = tokio::time::timeout(
self.timeout_pessimistic,
read_irc_message(&mut self.reader, &mut self.buffer),
)
.await??;
assert_eq!(std::str::from_utf8(&self.buffer[..len - 2])?, str); assert_eq!(std::str::from_utf8(&self.buffer[..len - 2])?, str);
self.buffer.clear(); self.buffer.clear();
Ok(()) Ok(())
} }
async fn expect_that(&mut self, validate: impl FnOnce(&str) -> bool) -> Result<()> {
let len = tokio::time::timeout(
self.timeout_pessimistic,
read_irc_message(&mut self.reader, &mut self.buffer),
)
.await??;
let msg = std::str::from_utf8(&self.buffer[..len - 2])?;
if !validate(msg) {
return Err(anyhow!("unexpected message: {:?}", msg));
}
self.buffer.clear();
Ok(())
}
async fn expect_server_introduction(&mut self, nick: &str) -> Result<()> {
self.expect(&format!(":testserver 001 {nick} :Welcome to testserver Server")).await?;
self.expect(&format!(":testserver 002 {nick} :Welcome to testserver Server")).await?;
self.expect(&format!(":testserver 003 {nick} :Welcome to testserver Server")).await?;
self.expect(&format!(
":testserver 004 {nick} testserver {APP_VERSION} r CFILPQbcefgijklmnopqrstvz"
))
.await?;
self.expect(&format!(
":testserver 005 {nick} CHANTYPES=# :are supported by this server"
))
.await?;
Ok(())
}
async fn expect_eof(&mut self) -> Result<()> { async fn expect_eof(&mut self) -> Result<()> {
let mut buf = [0; 1]; let mut buf = [0; 1];
let len = tokio::time::timeout(self.timeout, self.reader.read(&mut buf)).await??; let len = tokio::time::timeout(self.timeout_pessimistic, self.reader.read(&mut buf)).await??;
if len != 0 { if len != 0 {
return Err(anyhow!("not a eof")); return Err(anyhow!("not a eof"));
} }
@ -57,53 +99,85 @@ impl<'a> TestScope<'a> {
async fn expect_nothing(&mut self) -> Result<()> { async fn expect_nothing(&mut self) -> Result<()> {
let mut buf = [0; 1]; let mut buf = [0; 1];
match tokio::time::timeout(self.timeout, self.reader.read(&mut buf)).await { match tokio::time::timeout(self.timeout_optimistic, self.reader.read(&mut buf)).await {
Ok(res) => Err(anyhow!("received something: {:?}", res)), Ok(res) => Err(anyhow!("received something: {:?}", res)),
Err(_) => Ok(()), Err(_) => Ok(()),
} }
} }
async fn expect_cap_ls(&mut self) -> Result<()> {
self.expect(":testserver CAP * LS :sasl=PLAIN server-time draft/chathistory").await?;
Ok(())
}
} }
struct TestServer { struct TestServer {
metrics: MetricsRegistry, core: LavinaCore,
storage: Storage,
rooms: RoomRegistry,
players: PlayerRegistry,
server: RunningServer, server: RunningServer,
} }
impl TestServer { impl TestServer {
async fn start() -> Result<TestServer> { async fn start() -> Result<TestServer> {
let _ = tracing_subscriber::fmt::try_init(); let _ = tracing_subscriber::fmt::try_init();
let config = ServerConfig { let config = ServerConfig {
listen_on: "127.0.0.1:0".parse().unwrap(), listen_on: "127.0.0.1:0".parse()?,
server_name: "testserver".into(), server_name: "testserver".into(),
}; };
let mut metrics = MetricsRegistry::new(); let mut metrics = MetricsRegistry::new();
let mut storage = Storage::open(StorageConfig { let storage = Storage::open(StorageConfig {
db_path: ":memory:".into(), db_path: ":memory:".into(),
}) })
.await?; .await?;
let rooms = RoomRegistry::new(&mut metrics, storage.clone()).unwrap(); let cluster_config = ClusterConfig {
let players = PlayerRegistry::empty(rooms.clone(), &mut metrics).unwrap(); addresses: vec![],
let server = launch(config, players.clone(), rooms.clone(), metrics.clone(), storage.clone()).await.unwrap(); metadata: ClusterMetadata {
Ok(TestServer { node_id: 0,
metrics, main_owner: 0,
storage, rooms: Default::default(),
rooms, },
players, };
server, let core = LavinaCore::new(&mut metrics, cluster_config, storage).await?;
}) let server = launch(config, core.clone(), metrics.clone()).await?;
Ok(TestServer { core, server })
}
async fn reboot(self) -> Result<TestServer> {
let config = ServerConfig {
listen_on: "127.0.0.1:0".parse()?,
server_name: "testserver".into(),
};
let cluster_config = ClusterConfig {
addresses: vec![],
metadata: ClusterMetadata {
node_id: 0,
main_owner: 0,
rooms: Default::default(),
},
};
let TestServer { core, server } = self;
server.terminate().await?;
let storage = core.shutdown().await;
let mut metrics = MetricsRegistry::new();
let core = LavinaCore::new(&mut metrics, cluster_config, storage).await?;
let server = launch(config, core.clone(), metrics.clone()).await?;
Ok(TestServer { core, server })
}
async fn shutdown(self) {
let _ = self.server.terminate().await;
let storage = self.core.shutdown().await;
let _ = storage.close().await;
} }
} }
#[tokio::test] #[tokio::test]
async fn scenario_basic() -> Result<()> { async fn scenario_basic() -> Result<()> {
let mut server = TestServer::start().await?; let server = TestServer::start().await?;
// test scenario // test scenario
server.storage.create_user("tester").await?; server.core.create_player(&PlayerId::from("tester")?).await?;
server.storage.set_password("tester", "password").await?; server.core.set_password("tester", "password").await?;
let mut stream = TcpStream::connect(server.server.addr).await?; let mut stream = TcpStream::connect(server.server.addr).await?;
let mut s = TestScope::new(&mut stream); let mut s = TestScope::new(&mut stream);
@ -111,11 +185,7 @@ async fn scenario_basic() -> Result<()> {
s.send("PASS password").await?; s.send("PASS password").await?;
s.send("NICK tester").await?; s.send("NICK tester").await?;
s.send("USER UserName 0 * :Real Name").await?; s.send("USER UserName 0 * :Real Name").await?;
s.expect(":testserver 001 tester :Welcome to Kek Server").await?; s.expect_server_introduction("tester").await?;
s.expect(":testserver 002 tester :Welcome to Kek Server").await?;
s.expect(":testserver 003 tester :Welcome to Kek Server").await?;
s.expect(":testserver 004 tester testserver kek-0.1.alpha.3 r CFILPQbcefgijklmnopqrstvz").await?;
s.expect(":testserver 005 tester CHANTYPES=# :are supported by this server").await?;
s.expect_nothing().await?; s.expect_nothing().await?;
s.send("QUIT :Leaving").await?; s.send("QUIT :Leaving").await?;
s.expect(":testserver ERROR :Leaving the server").await?; s.expect(":testserver ERROR :Leaving the server").await?;
@ -125,7 +195,362 @@ async fn scenario_basic() -> Result<()> {
// wrap up // wrap up
server.server.terminate().await?; server.shutdown().await;
Ok(())
}
#[tokio::test]
async fn scenario_basic_with_chathistory() -> Result<()> {
let server = TestServer::start().await?;
// test scenario
server.core.create_player(&PlayerId::from("tester")?).await?;
server.core.set_password("tester", "password").await?;
let mut stream = TcpStream::connect(server.server.addr).await?;
let mut s = TestScope::new(&mut stream);
s.send("NICK tester").await?;
s.send("CAP REQ :draft/chathistory").await?;
s.send("USER UserName 0 * :Real Name").await?;
s.send("PASS password").await?;
s.send("USER UserName 0 * :Real Name").await?;
s.expect(":testserver CAP tester ACK :draft/chathistory").await?;
s.send("CAP END").await?;
s.expect_server_introduction("tester").await?;
s.expect_nothing().await?;
s.send("JOIN #test").await?;
s.expect(":tester JOIN #test").await?;
s.expect(":testserver 332 tester #test :New room").await?;
s.expect(":testserver 353 tester = #test :tester").await?;
s.expect(":testserver 366 tester #test :End of /NAMES list").await?;
s.send("PRIVMSG #test :Message1").await?;
s.send("PRIVMSG #test :Message2").await?;
s.send("PRIVMSG #test :Message3").await?;
s.send("PRIVMSG #test :Message4").await?;
s.send("CHATHISTORY LATEST #test * 1").await?;
s.expect(":tester PRIVMSG #test :Message4").await?;
s.send("CHATHISTORY LATEST #test * 3").await?;
s.expect(":tester PRIVMSG #test :Message2").await?;
s.expect(":tester PRIVMSG #test :Message3").await?;
s.expect(":tester PRIVMSG #test :Message4").await?;
s.expect_nothing().await?;
s.send("QUIT :Leaving").await?;
s.expect(":testserver ERROR :Leaving the server").await?;
s.expect_eof().await?;
stream.shutdown().await?;
// wrap up
server.shutdown().await;
Ok(())
}
#[tokio::test]
async fn scenario_join_and_reboot() -> Result<()> {
let server = TestServer::start().await?;
// test scenario
server.core.create_player(&PlayerId::from("tester")?).await?;
server.core.set_password("tester", "password").await?;
let mut stream = TcpStream::connect(server.server.addr).await?;
let mut s = TestScope::new(&mut stream);
// Open a connection and join a channel
s.send("PASS password").await?;
s.send("NICK tester").await?;
s.send("USER UserName 0 * :Real Name").await?;
s.expect_server_introduction("tester").await?;
s.expect_nothing().await?;
s.send("JOIN #test").await?;
s.expect(":tester JOIN #test").await?;
s.expect(":testserver 332 tester #test :New room").await?;
s.expect(":testserver 353 tester = #test :tester").await?;
s.expect(":testserver 366 tester #test :End of /NAMES list").await?;
s.send("PRIVMSG #test :Hello").await?;
s.send("QUIT :Leaving").await?;
s.expect(":testserver ERROR :Leaving the server").await?;
s.expect_eof().await?;
stream.shutdown().await?;
// Open a new connection and expect to be force-joined to the channel
let mut stream = TcpStream::connect(server.server.addr).await?;
let mut s = TestScope::new(&mut stream);
async fn test(s: &mut TestScope<'_>) -> Result<()> {
s.send("PASS password").await?;
s.send("NICK tester").await?;
s.send("USER UserName 0 * :Real Name").await?;
s.expect_server_introduction("tester").await?;
s.expect(":tester JOIN #test").await?;
s.expect(":testserver 332 tester #test :New room").await?;
s.expect(":testserver 353 tester = #test :tester").await?;
s.expect(":testserver 366 tester #test :End of /NAMES list").await?;
s.send("QUIT :Leaving").await?;
s.expect(":testserver ERROR :Leaving the server").await?;
s.expect_eof().await?;
Ok(())
}
test(&mut s).await?;
stream.shutdown().await?;
// Reboot the server
let server = server.reboot().await?;
// Open a new connection and expect to be force-joined to the channel
let mut stream = TcpStream::connect(server.server.addr).await?;
let mut s = TestScope::new(&mut stream);
test(&mut s).await?;
stream.shutdown().await?;
// wrap up
server.shutdown().await;
Ok(())
}
#[tokio::test]
async fn scenario_subscribe_after_reboot() -> Result<()> {
let server = TestServer::start().await?;
server.core.create_player(&PlayerId::from("tester1")?).await?;
server.core.set_password("tester1", "password").await?;
server.core.create_player(&PlayerId::from("tester2")?).await?;
server.core.set_password("tester2", "password").await?;
let mut stream1 = TcpStream::connect(server.server.addr).await?;
let mut s1 = TestScope::new(&mut stream1);
// Open a connection, join a channel, close the connection
s1.send("PASS password").await?;
s1.send("NICK tester1").await?;
s1.send("USER UserName 0 * :Real Name").await?;
s1.expect_server_introduction("tester1").await?;
s1.expect_nothing().await?;
s1.send("JOIN #test").await?;
s1.expect(":tester1 JOIN #test").await?;
s1.expect(":testserver 332 tester1 #test :New room").await?;
s1.expect(":testserver 353 tester1 = #test :tester1").await?;
s1.expect(":testserver 366 tester1 #test :End of /NAMES list").await?;
s1.send("PRIVMSG #test :Hello").await?;
s1.send("QUIT :Leaving").await?;
s1.expect(":testserver ERROR :Leaving the server").await?;
s1.expect_eof().await?;
stream1.shutdown().await?;
let server = server.reboot().await?;
// Open a new connection and expect to be force-joined to the channel
let mut stream1 = TcpStream::connect(server.server.addr).await?;
let mut s1 = TestScope::new(&mut stream1);
s1.send("PASS password").await?;
s1.send("NICK tester1").await?;
s1.send("USER UserName 0 * :Real Name").await?;
s1.expect_server_introduction("tester1").await?;
s1.expect(":tester1 JOIN #test").await?;
s1.expect(":testserver 332 tester1 #test :New room").await?;
s1.expect(":testserver 353 tester1 = #test :tester1").await?;
s1.expect(":testserver 366 tester1 #test :End of /NAMES list").await?;
// Open a connection from the second player and join the channel
let mut stream2 = TcpStream::connect(server.server.addr).await?;
let mut s2 = TestScope::new(&mut stream2);
s2.send("PASS password").await?;
s2.send("NICK tester2").await?;
s2.send("USER UserName 0 * :Real Name").await?;
s2.expect_server_introduction("tester2").await?;
s2.send("JOIN #test").await?;
s2.expect(":tester2 JOIN #test").await?;
s2.expect(":testserver 332 tester2 #test :New room").await?;
s2.expect_that(|msg| {
msg == ":testserver 353 tester2 = #test :tester1 tester2"
|| msg == ":testserver 353 tester2 = #test :tester2 tester1"
})
.await?;
s2.expect(":testserver 366 tester2 #test :End of /NAMES list").await?;
s2.expect_nothing().await?;
// The first player should receive the joining message
s1.expect(":tester2 JOIN #test").await?;
s1.expect_nothing().await?;
// Also send a message for good measure
s2.send("PRIVMSG #test :Hello").await?;
s1.expect(":tester2 PRIVMSG #test :Hello").await?;
s1.expect_nothing().await?;
// Wrap up
s1.send("QUIT :Leaving").await?;
s1.expect(":testserver ERROR :Leaving the server").await?;
s1.expect_eof().await?;
s2.send("QUIT :Leaving").await?;
s2.expect(":testserver ERROR :Leaving the server").await?;
s2.expect_eof().await?;
stream1.shutdown().await?;
stream2.shutdown().await?;
Ok(())
}
#[tokio::test]
async fn scenario_force_join_msg() -> Result<()> {
let server = TestServer::start().await?;
// test scenario
server.core.create_player(&PlayerId::from("tester")?).await?;
server.core.set_password("tester", "password").await?;
let mut stream1 = TcpStream::connect(server.server.addr).await?;
let mut s1 = TestScope::new(&mut stream1);
let mut stream2 = TcpStream::connect(server.server.addr).await?;
let mut s2 = TestScope::new(&mut stream2);
s1.send("PASS password").await?;
s1.send("NICK tester").await?;
s1.send("USER UserName 0 * :Real Name").await?;
s1.expect_server_introduction("tester").await?;
s1.expect_nothing().await?;
s2.send("PASS password").await?;
s2.send("NICK tester").await?;
s2.send("USER UserName 0 * :Real Name").await?;
s2.expect_server_introduction("tester").await?;
s2.expect_nothing().await?;
// We join the channel from the first connection
s1.send("JOIN #test").await?;
s1.expect(":tester JOIN #test").await?;
s1.expect(":testserver 332 tester #test :New room").await?;
s1.expect(":testserver 353 tester = #test :tester").await?;
s1.expect(":testserver 366 tester #test :End of /NAMES list").await?;
// And the second connection should receive the JOIN message (forced JOIN)
s2.expect(":tester JOIN #test").await?;
s2.expect(":testserver 332 tester #test :New room").await?;
s2.expect(":testserver 353 tester = #test :tester").await?;
s2.expect(":testserver 366 tester #test :End of /NAMES list").await?;
// We send a message to the channel from the second connection
s2.send("PRIVMSG #test :Hello").await?;
// We should not receive an acknowledgement from the server
s2.expect_nothing().await?;
// But we should receive this message from the first connection
s1.expect(":tester PRIVMSG #test :Hello").await?;
s1.send("QUIT :Leaving").await?;
s1.expect(":testserver ERROR :Leaving the server").await?;
s1.expect_eof().await?;
// Closing a connection does not kick you from the channel on a different connection
s2.expect_nothing().await?;
s2.send("QUIT :Leaving").await?;
s2.expect(":testserver ERROR :Leaving the server").await?;
s2.expect_eof().await?;
stream1.shutdown().await?;
stream2.shutdown().await?;
// wrap up
server.shutdown().await;
Ok(())
}
#[tokio::test]
async fn scenario_two_users() -> Result<()> {
let server = TestServer::start().await?;
// test scenario
server.core.create_player(&PlayerId::from("tester1")?).await?;
server.core.set_password("tester1", "password").await?;
server.core.create_player(&PlayerId::from("tester2")?).await?;
server.core.set_password("tester2", "password").await?;
let mut stream1 = TcpStream::connect(server.server.addr).await?;
let mut s1 = TestScope::new(&mut stream1);
let mut stream2 = TcpStream::connect(server.server.addr).await?;
let mut s2 = TestScope::new(&mut stream2);
s1.send("PASS password").await?;
s1.send("NICK tester1").await?;
s1.send("USER UserName 0 * :Real Name").await?;
s1.expect_server_introduction("tester1").await?;
s1.expect_nothing().await?;
s2.send("PASS password").await?;
s2.send("NICK tester2").await?;
s2.send("USER UserName 0 * :Real Name").await?;
s2.expect_server_introduction("tester2").await?;
s2.expect_nothing().await?;
// Join the channel from the first user
s1.send("JOIN #test").await?;
s1.expect(":tester1 JOIN #test").await?;
s1.expect(":testserver 332 tester1 #test :New room").await?;
s1.expect(":testserver 353 tester1 = #test :tester1").await?;
s1.expect(":testserver 366 tester1 #test :End of /NAMES list").await?;
// Then join the channel from the second user
s2.send("JOIN #test").await?;
s2.expect(":tester2 JOIN #test").await?;
s2.expect(":testserver 332 tester2 #test :New room").await?;
s2.expect_that(|msg| {
msg == ":testserver 353 tester2 = #test :tester1 tester2"
|| msg == ":testserver 353 tester2 = #test :tester2 tester1"
})
.await?;
s2.expect(":testserver 366 tester2 #test :End of /NAMES list").await?;
// The first user should receive the JOIN message from the second user
s1.expect(":tester2 JOIN #test").await?;
s1.expect_nothing().await?;
s2.expect_nothing().await?;
// Send a message from the second user
s2.send("PRIVMSG #test :Hello").await?;
// The first user should receive the message
s1.expect(":tester2 PRIVMSG #test :Hello").await?;
// Leave the channel from the first user
s1.send("PART #test").await?;
s1.expect(":tester1 PART #test").await?;
// The second user should receive the PART message
s2.expect(":tester1 PART #test").await?;
s1.send("WHOIS tester2").await?;
s1.expect(":testserver 318 tester1 tester2 :End of /WHOIS list").await?;
stream1.shutdown().await?;
s2.send("WHOIS tester3").await?;
s2.expect(":testserver 401 tester2 tester3 :No such nick/channel").await?;
s2.expect(":testserver 318 tester2 tester3 :End of /WHOIS list").await?;
stream2.shutdown().await?;
server.shutdown().await;
Ok(()) Ok(())
} }
@ -135,12 +560,12 @@ AUTHENTICATE doc: https://modern.ircdocs.horse/#authenticate-message
*/ */
#[tokio::test] #[tokio::test]
async fn scenario_cap_full_negotiation() -> Result<()> { async fn scenario_cap_full_negotiation() -> Result<()> {
let mut server = TestServer::start().await?; let server = TestServer::start().await?;
// test scenario // test scenario
server.storage.create_user("tester").await?; server.core.create_player(&PlayerId::from("tester")?).await?;
server.storage.set_password("tester", "password").await?; server.core.set_password("tester", "password").await?;
let mut stream = TcpStream::connect(server.server.addr).await?; let mut stream = TcpStream::connect(server.server.addr).await?;
let mut s = TestScope::new(&mut stream); let mut s = TestScope::new(&mut stream);
@ -148,7 +573,7 @@ async fn scenario_cap_full_negotiation() -> Result<()> {
s.send("CAP LS 302").await?; s.send("CAP LS 302").await?;
s.send("NICK tester").await?; s.send("NICK tester").await?;
s.send("USER UserName 0 * :Real Name").await?; s.send("USER UserName 0 * :Real Name").await?;
s.expect(":testserver CAP * LS :sasl=PLAIN").await?; s.expect_cap_ls().await?;
s.send("CAP REQ :sasl").await?; s.send("CAP REQ :sasl").await?;
s.expect(":testserver CAP tester ACK :sasl").await?; s.expect(":testserver CAP tester ACK :sasl").await?;
s.send("AUTHENTICATE PLAIN").await?; s.send("AUTHENTICATE PLAIN").await?;
@ -159,11 +584,7 @@ async fn scenario_cap_full_negotiation() -> Result<()> {
s.send("CAP END").await?; s.send("CAP END").await?;
s.expect(":testserver 001 tester :Welcome to Kek Server").await?; s.expect_server_introduction("tester").await?;
s.expect(":testserver 002 tester :Welcome to Kek Server").await?;
s.expect(":testserver 003 tester :Welcome to Kek Server").await?;
s.expect(":testserver 004 tester testserver kek-0.1.alpha.3 r CFILPQbcefgijklmnopqrstvz").await?;
s.expect(":testserver 005 tester CHANTYPES=# :are supported by this server").await?;
s.expect_nothing().await?; s.expect_nothing().await?;
s.send("QUIT :Leaving").await?; s.send("QUIT :Leaving").await?;
s.expect(":testserver ERROR :Leaving the server").await?; s.expect(":testserver ERROR :Leaving the server").await?;
@ -173,18 +594,57 @@ async fn scenario_cap_full_negotiation() -> Result<()> {
// wrap up // wrap up
server.server.terminate().await?; server.shutdown().await;
Ok(())
}
#[tokio::test]
async fn scenario_cap_full_negotiation_nick_last() -> Result<()> {
let server = TestServer::start().await?;
// test scenario
server.core.create_player(&PlayerId::from("tester")?).await?;
server.core.set_password("tester", "password").await?;
let mut stream = TcpStream::connect(server.server.addr).await?;
let mut s = TestScope::new(&mut stream);
s.send("CAP LS 302").await?;
s.expect_cap_ls().await?;
s.send("CAP REQ :sasl").await?;
s.expect(":testserver CAP * ACK :sasl").await?;
s.send("AUTHENTICATE PLAIN").await?;
s.expect(":testserver AUTHENTICATE +").await?;
s.send("AUTHENTICATE dGVzdGVyAHRlc3RlcgBwYXNzd29yZA==").await?; // base64-encoded 'tester\x00tester\x00password'
s.expect(":testserver 900 tester tester tester :You are now logged in as tester").await?;
s.expect(":testserver 903 tester :SASL authentication successful").await?;
s.send("CAP END").await?;
s.send("USER UserName 0 * :Real Name").await?;
s.send("NICK tester").await?;
s.expect_server_introduction("tester").await?;
s.expect_nothing().await?;
s.send("QUIT :Leaving").await?;
s.expect(":testserver ERROR :Leaving the server").await?;
s.expect_eof().await?;
stream.shutdown().await?;
// wrap up
server.shutdown().await;
Ok(()) Ok(())
} }
#[tokio::test] #[tokio::test]
async fn scenario_cap_short_negotiation() -> Result<()> { async fn scenario_cap_short_negotiation() -> Result<()> {
let mut server = TestServer::start().await?; let server = TestServer::start().await?;
// test scenario // test scenario
server.storage.create_user("tester").await?; server.core.create_player(&PlayerId::from("tester")?).await?;
server.storage.set_password("tester", "password").await?; server.core.set_password("tester", "password").await?;
let mut stream = TcpStream::connect(server.server.addr).await?; let mut stream = TcpStream::connect(server.server.addr).await?;
let mut s = TestScope::new(&mut stream); let mut s = TestScope::new(&mut stream);
@ -201,11 +661,7 @@ async fn scenario_cap_short_negotiation() -> Result<()> {
s.send("CAP END").await?; s.send("CAP END").await?;
s.expect(":testserver 001 tester :Welcome to Kek Server").await?; s.expect_server_introduction("tester").await?;
s.expect(":testserver 002 tester :Welcome to Kek Server").await?;
s.expect(":testserver 003 tester :Welcome to Kek Server").await?;
s.expect(":testserver 004 tester testserver kek-0.1.alpha.3 r CFILPQbcefgijklmnopqrstvz").await?;
s.expect(":testserver 005 tester CHANTYPES=# :are supported by this server").await?;
s.expect_nothing().await?; s.expect_nothing().await?;
s.send("QUIT :Leaving").await?; s.send("QUIT :Leaving").await?;
s.expect(":testserver ERROR :Leaving the server").await?; s.expect(":testserver ERROR :Leaving the server").await?;
@ -215,18 +671,18 @@ async fn scenario_cap_short_negotiation() -> Result<()> {
// wrap up // wrap up
server.server.terminate().await?; server.shutdown().await;
Ok(()) Ok(())
} }
#[tokio::test] #[tokio::test]
async fn scenario_cap_sasl_fail() -> Result<()> { async fn scenario_cap_sasl_fail() -> Result<()> {
let mut server = TestServer::start().await?; let server = TestServer::start().await?;
// test scenario // test scenario
server.storage.create_user("tester").await?; server.core.create_player(&PlayerId::from("tester")?).await?;
server.storage.set_password("tester", "password").await?; server.core.set_password("tester", "password").await?;
let mut stream = TcpStream::connect(server.server.addr).await?; let mut stream = TcpStream::connect(server.server.addr).await?;
let mut s = TestScope::new(&mut stream); let mut s = TestScope::new(&mut stream);
@ -234,7 +690,7 @@ async fn scenario_cap_sasl_fail() -> Result<()> {
s.send("CAP LS 302").await?; s.send("CAP LS 302").await?;
s.send("NICK tester").await?; s.send("NICK tester").await?;
s.send("USER UserName 0 * :Real Name").await?; s.send("USER UserName 0 * :Real Name").await?;
s.expect(":testserver CAP * LS :sasl=PLAIN").await?; s.expect_cap_ls().await?;
s.send("CAP REQ :sasl").await?; s.send("CAP REQ :sasl").await?;
s.expect(":testserver CAP tester ACK :sasl").await?; s.expect(":testserver CAP tester ACK :sasl").await?;
s.send("AUTHENTICATE SHA256").await?; s.send("AUTHENTICATE SHA256").await?;
@ -249,11 +705,7 @@ async fn scenario_cap_sasl_fail() -> Result<()> {
s.send("CAP END").await?; s.send("CAP END").await?;
s.expect(":testserver 001 tester :Welcome to Kek Server").await?; s.expect_server_introduction("tester").await?;
s.expect(":testserver 002 tester :Welcome to Kek Server").await?;
s.expect(":testserver 003 tester :Welcome to Kek Server").await?;
s.expect(":testserver 004 tester testserver kek-0.1.alpha.3 r CFILPQbcefgijklmnopqrstvz").await?;
s.expect(":testserver 005 tester CHANTYPES=# :are supported by this server").await?;
s.expect_nothing().await?; s.expect_nothing().await?;
s.send("QUIT :Leaving").await?; s.send("QUIT :Leaving").await?;
s.expect(":testserver ERROR :Leaving the server").await?; s.expect(":testserver ERROR :Leaving the server").await?;
@ -263,6 +715,167 @@ async fn scenario_cap_sasl_fail() -> Result<()> {
// wrap up // wrap up
server.server.terminate().await?; server.shutdown().await;
Ok(())
}
#[tokio::test]
async fn terminate_socket_scenario() -> Result<()> {
let server = TestServer::start().await?;
// test scenario
server.core.create_player(&PlayerId::from("tester")?).await?;
server.core.set_password("tester", "password").await?;
let mut stream = TcpStream::connect(server.server.addr).await?;
let mut s = TestScope::new(&mut stream);
s.send("NICK tester").await?;
s.send("CAP REQ :sasl").await?;
s.send("USER UserName 0 * :Real Name").await?;
s.expect(":testserver CAP tester ACK :sasl").await?;
s.send("AUTHENTICATE PLAIN").await?;
s.expect(":testserver AUTHENTICATE +").await?;
server.shutdown().await;
assert_eq!(stream.read_u8().await.unwrap_err().kind(), ErrorKind::UnexpectedEof);
Ok(())
}
#[tokio::test]
async fn server_time_capability() -> Result<()> {
let server = TestServer::start().await?;
// test scenario
server.core.create_player(&PlayerId::from("tester")?).await?;
server.core.set_password("tester", "password").await?;
let mut stream = TcpStream::connect(server.server.addr).await?;
let mut s = TestScope::new(&mut stream);
s.send("CAP LS 302").await?;
s.send("NICK tester").await?;
s.send("USER UserName 0 * :Real Name").await?;
s.expect_cap_ls().await?;
s.send("CAP REQ :sasl server-time").await?;
s.expect(":testserver CAP tester ACK :sasl server-time").await?;
s.send("AUTHENTICATE PLAIN").await?;
s.expect(":testserver AUTHENTICATE +").await?;
s.send("AUTHENTICATE dGVzdGVyAHRlc3RlcgBwYXNzd29yZA==").await?; // base64-encoded 'tester\x00tester\x00password'
s.expect(":testserver 900 tester tester tester :You are now logged in as tester").await?;
s.expect(":testserver 903 tester :SASL authentication successful").await?;
s.send("CAP END").await?;
s.expect_server_introduction("tester").await?;
s.expect_nothing().await?;
s.send("JOIN #test").await?;
s.expect(":tester JOIN #test").await?;
s.expect(":testserver 332 tester #test :New room").await?;
s.expect(":testserver 353 tester = #test :tester").await?;
s.expect(":testserver 366 tester #test :End of /NAMES list").await?;
server.core.create_player(&PlayerId::from("some_guy")?).await?;
let mut conn = match server.core.connect_to_player(&PlayerId::from("some_guy")?).await? {
PlayerConnectionResult::Success(conn) => conn,
PlayerConnectionResult::PlayerNotFound => panic!("user was created, but not returned"),
};
let res = conn.join_room(RoomId::try_from("test")?).await?;
let JoinResult::Success(_) = res else {
panic!("Failed to join room");
};
s.expect(":some_guy JOIN #test").await?;
let SendMessageResult::Success(res) = conn.send_message(RoomId::try_from("test")?, "Hello".into()).await? else {
panic!("Failed to send message");
};
s.expect(&format!(
"@time={} :some_guy PRIVMSG #test :Hello",
res.to_rfc3339_opts(SecondsFormat::Millis, true)
))
.await?;
// formatting check
assert_eq!(
DateTime::parse_from_rfc3339(&"2024-01-01T10:00:32.123Z")?.to_rfc3339_opts(SecondsFormat::Millis, true),
"2024-01-01T10:00:32.123Z"
);
s.send("QUIT :Leaving").await?;
s.expect(":testserver ERROR :Leaving the server").await?;
s.expect_eof().await?;
stream.shutdown().await?;
// wrap up
server.shutdown().await;
Ok(())
}
#[tokio::test]
async fn scenario_two_players_dialog() -> Result<()> {
let server = TestServer::start().await?;
// test scenario
server.core.create_player(&PlayerId::from("tester1")?).await?;
server.core.set_password("tester1", "password").await?;
server.core.create_player(&PlayerId::from("tester2")?).await?;
server.core.set_password("tester2", "password").await?;
let mut stream1 = TcpStream::connect(server.server.addr).await?;
let mut s1 = TestScope::new(&mut stream1);
let mut stream2 = TcpStream::connect(server.server.addr).await?;
let mut s2 = TestScope::new(&mut stream2);
s1.send("CAP LS 302").await?;
s1.send("NICK tester1").await?;
s1.send("USER UserName 0 * :Real Name").await?;
s1.expect_cap_ls().await?;
s1.send("CAP REQ :sasl").await?;
s1.expect(":testserver CAP tester1 ACK :sasl").await?;
s1.send("AUTHENTICATE PLAIN").await?;
s1.expect(":testserver AUTHENTICATE +").await?;
s1.send("AUTHENTICATE dGVzdGVyMQB0ZXN0ZXIxAHBhc3N3b3Jk").await?; // base64-encoded 'tester1\x00tester1\x00password'
s1.expect(":testserver 900 tester1 tester1 tester1 :You are now logged in as tester1").await?;
s1.expect(":testserver 903 tester1 :SASL authentication successful").await?;
s1.send("CAP END").await?;
s1.expect_server_introduction("tester1").await?;
s1.expect_nothing().await?;
s2.send("CAP LS 302").await?;
s2.send("NICK tester2").await?;
s2.send("USER UserName 0 * :Real Name").await?;
s2.expect_cap_ls().await?;
s2.send("CAP REQ :sasl").await?;
s2.expect(":testserver CAP tester2 ACK :sasl").await?;
s2.send("AUTHENTICATE PLAIN").await?;
s2.expect(":testserver AUTHENTICATE +").await?;
s2.send("AUTHENTICATE dGVzdGVyMgB0ZXN0ZXIyAHBhc3N3b3Jk").await?; // base64-encoded 'tester2\x00tester2\x00password'
s2.expect(":testserver 900 tester2 tester2 tester2 :You are now logged in as tester2").await?;
s2.expect(":testserver 903 tester2 :SASL authentication successful").await?;
s2.send("CAP END").await?;
s2.expect_server_introduction("tester2").await?;
s2.expect_nothing().await?;
s1.send("PRIVMSG tester2 :Henlo! How are you?").await?;
s1.expect_nothing().await?;
s2.expect(":tester1 PRIVMSG tester2 :Henlo! How are you?").await?;
s2.expect_nothing().await?;
s2.send("PRIVMSG tester1 good").await?;
s2.expect_nothing().await?;
s1.expect(":tester2 PRIVMSG tester1 :good").await?;
s1.expect_nothing().await?;
stream1.shutdown().await?;
stream2.shutdown().await?;
server.shutdown().await;
Ok(()) Ok(())
} }

View File

@ -1,2 +0,0 @@
max_width = 120
chain_width = 120

View File

@ -2,32 +2,30 @@
use quick_xml::events::Event; use quick_xml::events::Event;
use lavina_core::room::RoomRegistry; use lavina_core::room::RoomId;
use proto_xmpp::bind::{BindResponse, Jid, Name, Resource, Server}; use lavina_core::LavinaCore;
use proto_xmpp::client::{Iq, IqType}; use proto_xmpp::bind::{BindRequest, BindResponse, Jid, Name, Server};
use proto_xmpp::client::{Iq, IqError, IqErrorCondition, IqErrorType, IqType};
use proto_xmpp::disco::{Feature, Identity, InfoQuery, Item, ItemQuery}; use proto_xmpp::disco::{Feature, Identity, InfoQuery, Item, ItemQuery};
use proto_xmpp::mam::{Fin, Set};
use proto_xmpp::roster::RosterQuery; use proto_xmpp::roster::RosterQuery;
use proto_xmpp::session::Session; use proto_xmpp::session::Session;
use proto_xmpp::xml::ToXml;
use crate::proto::IqClientBody; use crate::proto::IqClientBody;
use crate::XmppConnection; use crate::XmppConnection;
use proto_xmpp::xml::ToXml;
impl<'a> XmppConnection<'a> { impl<'a> XmppConnection<'a> {
#[tracing::instrument(skip(self, output, iq), name = "XmppConnection::handle_iq")]
pub async fn handle_iq(&self, output: &mut Vec<Event<'static>>, iq: Iq<IqClientBody>) { pub async fn handle_iq(&self, output: &mut Vec<Event<'static>>, iq: Iq<IqClientBody>) {
match iq.body { match iq.body {
IqClientBody::Bind(b) => { IqClientBody::Bind(req) => {
let req = Iq { let req = Iq {
from: None, from: None,
id: iq.id, id: iq.id,
to: None, to: None,
r#type: IqType::Result, r#type: IqType::Result,
body: BindResponse(Jid { body: self.bind(&req).await,
name: Some(Name("darova".into())),
server: Server("localhost".into()),
resource: Some(Resource("kek".into())),
}),
}; };
req.serialize(output); req.serialize(output);
} }
@ -52,7 +50,32 @@ impl<'a> XmppConnection<'a> {
req.serialize(output); req.serialize(output);
} }
IqClientBody::DiscoInfo(info) => { IqClientBody::DiscoInfo(info) => {
let response = disco_info(iq.to.as_deref(), &info); let response = self.disco_info(iq.to.as_ref(), &info).await;
match response {
Ok(response) => {
let req = Iq {
from: iq.to,
id: iq.id,
to: None,
r#type: IqType::Result,
body: response,
};
req.serialize(output);
}
Err(response) => {
let req = Iq {
from: iq.to,
id: iq.id,
to: None,
r#type: IqType::Error,
body: response,
};
req.serialize(output);
}
}
}
IqClientBody::DiscoItem(item) => {
let response = self.disco_items(iq.to.as_ref(), &item, self.core).await;
let req = Iq { let req = Iq {
from: iq.to, from: iq.to,
id: iq.id, id: iq.id,
@ -62,16 +85,17 @@ impl<'a> XmppConnection<'a> {
}; };
req.serialize(output); req.serialize(output);
} }
IqClientBody::DiscoItem(item) => { IqClientBody::MessageArchiveRequest(_) => {
let response = disco_items(iq.to.as_deref(), &item, self.rooms).await; let response = Iq {
let req = Iq {
from: iq.to, from: iq.to,
id: iq.id, id: iq.id,
to: None, to: None,
r#type: IqType::Result, r#type: IqType::Result,
body: response, body: Fin {
set: Set { count: Some(0) },
},
}; };
req.serialize(output); response.serialize(output);
} }
_ => { _ => {
let req = Iq { let req = Iq {
@ -79,84 +103,144 @@ impl<'a> XmppConnection<'a> {
id: iq.id, id: iq.id,
to: None, to: None,
r#type: IqType::Error, r#type: IqType::Error,
body: (), body: IqError {
r#type: IqErrorType::Cancel,
condition: None,
},
}; };
req.serialize(output); req.serialize(output);
} }
} }
} }
}
fn disco_info(to: Option<&str>, req: &InfoQuery) -> InfoQuery { #[tracing::instrument(skip(self), name = "XmppConnection::bind")]
let identity; pub(crate) async fn bind(&self, req: &BindRequest) -> BindResponse {
let feature; BindResponse(Jid {
match to { name: Some(self.user.xmpp_name.clone()),
Some("localhost") => { server: Server(self.hostname.clone()),
identity = vec![Identity { resource: Some(self.user.xmpp_resource.clone()),
category: "server".into(), })
name: None,
r#type: "im".into(),
}];
feature = vec![
Feature::new("http://jabber.org/protocol/disco#info"),
Feature::new("http://jabber.org/protocol/disco#items"),
Feature::new("iq"),
Feature::new("presence"),
]
}
Some("rooms.localhost") => {
identity = vec![Identity {
category: "conference".into(),
name: Some("Chat rooms".into()),
r#type: "text".into(),
}];
feature = vec![
Feature::new("http://jabber.org/protocol/disco#info"),
Feature::new("http://jabber.org/protocol/disco#items"),
Feature::new("http://jabber.org/protocol/muc"),
]
}
_ => {
identity = vec![];
feature = vec![];
}
};
InfoQuery {
node: None,
identity,
feature,
} }
}
async fn disco_items(to: Option<&str>, req: &ItemQuery, rooms: &RoomRegistry) -> ItemQuery { #[tracing::instrument(skip(self), name = "XmppConnection::disco_info")]
let item = match to { async fn disco_info(&self, to: Option<&Jid>, req: &InfoQuery) -> Result<InfoQuery, IqError> {
Some("localhost") => { let identity;
vec![Item { let feature;
jid: Jid {
name: None, match to {
server: Server("rooms.localhost".into()), Some(Jid {
resource: None,
},
name: None, name: None,
node: None, server,
}] resource: None,
} }) if server.0 == self.hostname => {
Some("rooms.localhost") => { identity = vec![Identity {
let room_list = rooms.get_all_rooms().await; category: "server".into(),
room_list name: None,
.into_iter() r#type: "im".into(),
.map(|room_info| Item { }];
feature = vec![
Feature::new("http://jabber.org/protocol/disco#info"),
Feature::new("http://jabber.org/protocol/disco#items"),
Feature::new("iq"),
Feature::new("presence"),
]
}
Some(Jid {
name: None,
server,
resource: None,
}) if server.0 == self.hostname_rooms => {
identity = vec![Identity {
category: "conference".into(),
name: Some("Chat rooms".into()),
r#type: "text".into(),
}];
feature = vec![
Feature::new("http://jabber.org/protocol/disco#info"),
Feature::new("http://jabber.org/protocol/disco#items"),
Feature::new("http://jabber.org/protocol/muc"),
]
}
Some(Jid {
name: Some(room_name),
server,
resource: None,
}) if server.0 == self.hostname_rooms => {
let room_id = RoomId::try_from(room_name.0.clone()).unwrap();
let Some(_) = self.core.get_room(&room_id).await.unwrap() else {
// TODO should return item-not-found
// example:
// <error type="cancel">
// <item-not-found xmlns="urn:ietf:params:xml:ns:xmpp-stanzas"/>
// <text xmlns="urn:ietf:params:xml:ns:xmpp-stanzas" xml:lang="en">Conference room does not exist</text>
// </error>
return Err(IqError {
r#type: IqErrorType::Cancel,
condition: Some(IqErrorCondition::ItemNotFound),
});
};
identity = vec![Identity {
category: "conference".into(),
name: Some(room_id.into_inner().to_string()),
r#type: "text".into(),
}];
feature = vec![
Feature::new("http://jabber.org/protocol/disco#info"),
Feature::new("http://jabber.org/protocol/disco#items"),
Feature::new("http://jabber.org/protocol/muc"),
]
}
_ => {
identity = vec![];
feature = vec![];
}
};
Ok(InfoQuery {
node: None,
identity,
feature,
})
}
#[tracing::instrument(skip(self, core), name = "XmppConnection::disco_items")]
async fn disco_items(&self, to: Option<&Jid>, req: &ItemQuery, core: &LavinaCore) -> ItemQuery {
let item = match to {
Some(Jid {
name: None,
server,
resource: None,
}) if server.0 == self.hostname => {
vec![Item {
jid: Jid { jid: Jid {
name: Some(Name(room_info.id.into_inner())), name: None,
server: Server("rooms.localhost".into()), server: Server(self.hostname_rooms.clone()),
resource: None, resource: None,
}, },
name: None, name: None,
node: None, node: None,
}) }]
.collect() }
} Some(Jid {
_ => vec![], name: None,
}; server,
ItemQuery { item } resource: None,
}) if server.0 == self.hostname_rooms => {
let room_list = core.get_all_rooms().await;
room_list
.into_iter()
.map(|room_info| Item {
jid: Jid {
name: Some(Name(room_info.id.into_inner())),
server: Server(self.hostname_rooms.clone()),
resource: None,
},
name: None,
node: None,
})
.collect()
}
_ => vec![],
};
ItemQuery { item }
}
} }

View File

@ -22,13 +22,14 @@ use tokio::sync::mpsc::channel;
use tokio_rustls::rustls::{Certificate, PrivateKey}; use tokio_rustls::rustls::{Certificate, PrivateKey};
use tokio_rustls::TlsAcceptor; use tokio_rustls::TlsAcceptor;
use lavina_core::player::{PlayerConnection, PlayerId, PlayerRegistry}; use lavina_core::auth::Verdict;
use lavina_core::player::{ConnectionMessage, PlayerConnection, PlayerConnectionResult, PlayerId, StopReason};
use lavina_core::prelude::*; use lavina_core::prelude::*;
use lavina_core::repo::Storage;
use lavina_core::room::RoomRegistry;
use lavina_core::terminator::Terminator; use lavina_core::terminator::Terminator;
use lavina_core::LavinaCore;
use proto_xmpp::bind::{Name, Resource}; use proto_xmpp::bind::{Name, Resource};
use proto_xmpp::stream::*; use proto_xmpp::stream::*;
use proto_xmpp::streamerror::{StreamError, StreamErrorKind};
use proto_xmpp::xml::{Continuation, FromXml, Parser, ToXml}; use proto_xmpp::xml::{Continuation, FromXml, Parser, ToXml};
use sasl::AuthBody; use sasl::AuthBody;
@ -39,11 +40,15 @@ mod message;
mod presence; mod presence;
mod updates; mod updates;
#[cfg(test)]
mod testkit;
#[derive(Deserialize, Debug, Clone)] #[derive(Deserialize, Debug, Clone)]
pub struct ServerConfig { pub struct ServerConfig {
pub listen_on: SocketAddr, pub listen_on: SocketAddr,
pub cert: PathBuf, pub cert: PathBuf,
pub key: PathBuf, pub key: PathBuf,
pub hostname: Str,
} }
struct LoadedConfig { struct LoadedConfig {
@ -52,9 +57,17 @@ struct LoadedConfig {
} }
struct Authenticated { struct Authenticated {
/// Identifier of the authenticated player.
///
/// Used when communicating with lavina-core on behalf of the player.
player_id: PlayerId, player_id: PlayerId,
/// The user's XMPP name.
///
/// Used in `to` and `from` fields of XMPP messages.
xmpp_name: Name, xmpp_name: Name,
/// The resource given to this user by the server.
xmpp_resource: Resource, xmpp_resource: Resource,
/// The resource used by this user when joining MUCs.
xmpp_muc_name: Resource, xmpp_muc_name: Resource,
} }
@ -69,13 +82,7 @@ impl RunningServer {
} }
} }
pub async fn launch( pub async fn launch(config: ServerConfig, core: LavinaCore, metrics: MetricsRegistry) -> Result<RunningServer> {
config: ServerConfig,
players: PlayerRegistry,
rooms: RoomRegistry,
metrics: MetricsRegistry,
storage: Storage,
) -> Result<RunningServer> {
log::info!("Starting XMPP projection"); log::info!("Starting XMPP projection");
let certs = certs(&mut SyncBufReader::new(File::open(config.cert)?))?; let certs = certs(&mut SyncBufReader::new(File::open(config.cert)?))?;
@ -114,14 +121,13 @@ pub async fn launch(
// TODO kill the older connection and restart it // TODO kill the older connection and restart it
continue; continue;
} }
let players = players.clone(); let core = core.clone();
let rooms = rooms.clone(); let hostname = config.hostname.clone();
let storage = storage.clone();
let terminator = Terminator::spawn(|termination| { let terminator = Terminator::spawn(|termination| {
let stopped_tx = stopped_tx.clone(); let stopped_tx = stopped_tx.clone();
let loaded_config = loaded_config.clone(); let loaded_config = loaded_config.clone();
async move { async move {
match handle_socket(loaded_config, stream, &socket_addr, players, rooms, storage, termination).await { match handle_socket(loaded_config, stream, &socket_addr, core, hostname, termination).await {
Ok(_) => log::info!("Connection terminated"), Ok(_) => log::info!("Connection terminated"),
Err(err) => log::warn!("Connection failed: {err}"), Err(err) => log::warn!("Connection failed: {err}"),
} }
@ -156,12 +162,11 @@ pub async fn launch(
} }
async fn handle_socket( async fn handle_socket(
config: Arc<LoadedConfig>, cert_config: Arc<LoadedConfig>,
mut stream: TcpStream, mut stream: TcpStream,
socket_addr: &SocketAddr, socket_addr: &SocketAddr,
mut players: PlayerRegistry, core: LavinaCore,
rooms: RoomRegistry, hostname: Str,
mut storage: Storage,
termination: Deferred<()>, // TODO use it to stop the connection gracefully termination: Deferred<()>, // TODO use it to stop the connection gracefully
) -> Result<()> { ) -> Result<()> {
log::info!("Received an XMPP connection from {socket_addr}"); log::info!("Received an XMPP connection from {socket_addr}");
@ -170,12 +175,12 @@ async fn handle_socket(
let mut buf_reader = BufReader::new(reader); let mut buf_reader = BufReader::new(reader);
let mut buf_writer = BufWriter::new(writer); let mut buf_writer = BufWriter::new(writer);
socket_force_tls(&mut buf_reader, &mut buf_writer, &mut reader_buf).await?; socket_force_tls(&mut buf_reader, &mut buf_writer, &mut reader_buf, &hostname).await?;
let mut config = tokio_rustls::rustls::ServerConfig::builder() let mut config = tokio_rustls::rustls::ServerConfig::builder()
.with_safe_defaults() .with_safe_defaults()
.with_no_client_auth() .with_no_client_auth()
.with_single_cert(vec![config.cert.clone()], config.key.clone())?; .with_single_cert(vec![cert_config.cert.clone()], cert_config.key.clone())?;
config.key_log = Arc::new(tokio_rustls::rustls::KeyLogFile::new()); config.key_log = Arc::new(tokio_rustls::rustls::KeyLogFile::new());
log::debug!("Accepting TLS connection..."); log::debug!("Accepting TLS connection...");
@ -185,23 +190,45 @@ async fn handle_socket(
let (a, b) = tokio::io::split(new_stream); let (a, b) = tokio::io::split(new_stream);
let mut xml_reader = NsReader::from_reader(BufReader::new(a)); let mut xml_reader = NsReader::from_reader(BufReader::new(a));
let mut xml_writer = Writer::new(b); let mut xml_writer = Writer::new(BufWriter::new(b));
let authenticated = socket_auth(&mut xml_reader, &mut xml_writer, &mut reader_buf, &mut storage).await?; pin!(termination);
log::debug!("User authenticated"); select! {
let mut connection = players.connect_to_player(authenticated.player_id.clone()).await; biased;
socket_final( _ = &mut termination => {
&mut xml_reader, log::info!("Socket handling was terminated");
&mut xml_writer, return Ok(())
&mut reader_buf, },
&authenticated, authenticated = socket_auth(&mut xml_reader, &mut xml_writer, &mut reader_buf, &core, &hostname) => {
&mut connection, match authenticated {
&rooms, Ok(authenticated) => {
) let mut connection = match core.connect_to_player(&authenticated.player_id).await? {
.await?; PlayerConnectionResult::Success(connection) => connection,
PlayerConnectionResult::PlayerNotFound => {
tracing::error!("Authorized user unexpectedly not found in the database");
return Err(anyhow!("no such user"));
}
};
socket_final(
&mut xml_reader,
&mut xml_writer,
&mut reader_buf,
&authenticated,
&mut connection,
&core,
&hostname,
)
.await?;
},
Err(err) => {
log::error!("Authentication error: {:?}", err);
}
}
},
}
let a = xml_reader.into_inner().into_inner(); let a = xml_reader.into_inner().into_inner();
let b = xml_writer.into_inner(); let b = xml_writer.into_inner().into_inner();
a.unsplit(b).shutdown().await?; a.unsplit(b).shutdown().await?;
Ok(()) Ok(())
} }
@ -210,17 +237,18 @@ async fn socket_force_tls(
reader: &mut (impl AsyncBufRead + Unpin), reader: &mut (impl AsyncBufRead + Unpin),
writer: &mut (impl AsyncWrite + Unpin), writer: &mut (impl AsyncWrite + Unpin),
reader_buf: &mut Vec<u8>, reader_buf: &mut Vec<u8>,
hostname: &Str,
) -> Result<()> { ) -> Result<()> {
use proto_xmpp::tls::*; use proto_xmpp::tls::*;
let xml_reader = &mut NsReader::from_reader(reader); let xml_reader = &mut NsReader::from_reader(reader);
let xml_writer = &mut Writer::new(writer); let xml_writer = &mut Writer::new(writer);
read_xml_header(xml_reader, reader_buf).await?; // TODO validate the server hostname received in the stream start
let _ = ClientStreamStart::parse(xml_reader, reader_buf).await?; let _ = ClientStreamStart::parse(xml_reader, reader_buf).await?;
let event = Event::Decl(BytesDecl::new("1.0", None, None)); let event = Event::Decl(BytesDecl::new("1.0", None, None));
xml_writer.write_event_async(event).await?; xml_writer.write_event_async(event).await?;
let msg = ServerStreamStart { let msg = ServerStreamStart {
from: "localhost".into(), from: hostname.to_string(),
lang: "en".into(), lang: "en".into(),
id: uuid::Uuid::new_v4().to_string(), id: uuid::Uuid::new_v4().to_string(),
version: "1.0".into(), version: "1.0".into(),
@ -244,14 +272,15 @@ async fn socket_auth(
xml_reader: &mut NsReader<(impl AsyncBufRead + Unpin)>, xml_reader: &mut NsReader<(impl AsyncBufRead + Unpin)>,
xml_writer: &mut Writer<(impl AsyncWrite + Unpin)>, xml_writer: &mut Writer<(impl AsyncWrite + Unpin)>,
reader_buf: &mut Vec<u8>, reader_buf: &mut Vec<u8>,
storage: &mut Storage, core: &LavinaCore,
hostname: &Str,
) -> Result<Authenticated> { ) -> Result<Authenticated> {
read_xml_header(xml_reader, reader_buf).await?; // TODO validate the server hostname received in the stream start
let _ = ClientStreamStart::parse(xml_reader, reader_buf).await?; let _ = ClientStreamStart::parse(xml_reader, reader_buf).await?;
xml_writer.write_event_async(Event::Decl(BytesDecl::new("1.0", None, None))).await?; xml_writer.write_event_async(Event::Decl(BytesDecl::new("1.0", None, None))).await?;
ServerStreamStart { ServerStreamStart {
from: "localhost".into(), from: hostname.to_string(),
lang: "en".into(), lang: "en".into(),
id: uuid::Uuid::new_v4().to_string(), id: uuid::Uuid::new_v4().to_string(),
version: "1.0".into(), version: "1.0".into(),
@ -268,36 +297,28 @@ async fn socket_auth(
xml_writer.get_mut().flush().await?; xml_writer.get_mut().flush().await?;
let auth: proto_xmpp::sasl::Auth = proto_xmpp::sasl::Auth::parse(xml_reader, reader_buf).await?; let auth: proto_xmpp::sasl::Auth = proto_xmpp::sasl::Auth::parse(xml_reader, reader_buf).await?;
proto_xmpp::sasl::Success.write_xml(xml_writer).await?;
match AuthBody::from_str(&auth.body) { match AuthBody::from_str(&auth.body) {
Ok(logopass) => { Ok(logopass) => {
let name = &logopass.login; let name = &logopass.login;
let stored_user = storage.retrieve_user_by_name(name).await?; let verdict = core.authenticate(name, &logopass.password).await?;
match verdict {
let stored_user = match stored_user { Verdict::Authenticated => {
Some(u) => u, proto_xmpp::sasl::Success.write_xml(xml_writer).await?;
None => { xml_writer.get_mut().flush().await?;
log::info!("User '{}' not found", name); }
return Err(fail("no user found")); Verdict::UserNotFound | Verdict::InvalidPassword => {
proto_xmpp::sasl::Failure.write_xml(xml_writer).await?;
xml_writer.get_mut().flush().await?;
return Err(anyhow!("incorrect credentials"));
} }
};
// TODO return proper XML errors to the client
if stored_user.password.is_none() {
log::info!("Password not defined for user '{}'", name);
return Err(fail("password is not defined"));
} }
if stored_user.password.as_deref() != Some(&logopass.password) { let name: Str = name.as_str().into();
log::info!("Incorrect password supplied for user '{}'", name);
return Err(fail("passwords do not match"));
}
Ok(Authenticated { Ok(Authenticated {
player_id: PlayerId::from(name.as_str())?, player_id: PlayerId::from(name.clone())?,
xmpp_name: Name(name.to_string().into()), xmpp_name: Name(name.clone()),
xmpp_resource: Resource(name.to_string().into()), xmpp_resource: Resource(name.clone()),
xmpp_muc_name: Resource(name.to_string().into()), xmpp_muc_name: Resource(name.clone()),
}) })
} }
Err(e) => return Err(e), Err(e) => return Err(e),
@ -310,14 +331,15 @@ async fn socket_final(
reader_buf: &mut Vec<u8>, reader_buf: &mut Vec<u8>,
authenticated: &Authenticated, authenticated: &Authenticated,
user_handle: &mut PlayerConnection, user_handle: &mut PlayerConnection,
rooms: &RoomRegistry, core: &LavinaCore,
hostname: &Str,
) -> Result<()> { ) -> Result<()> {
read_xml_header(xml_reader, reader_buf).await?; // TODO validate the server hostname received in the stream start
let _ = ClientStreamStart::parse(xml_reader, reader_buf).await?; let _ = ClientStreamStart::parse(xml_reader, reader_buf).await?;
xml_writer.write_event_async(Event::Decl(BytesDecl::new("1.0", None, None))).await?; xml_writer.write_event_async(Event::Decl(BytesDecl::new("1.0", None, None))).await?;
ServerStreamStart { ServerStreamStart {
from: "localhost".into(), from: hostname.to_string(),
lang: "en".into(), lang: "en".into(),
id: uuid::Uuid::new_v4().to_string(), id: uuid::Uuid::new_v4().to_string(),
version: "1.0".into(), version: "1.0".into(),
@ -342,14 +364,16 @@ async fn socket_final(
let mut conn = XmppConnection { let mut conn = XmppConnection {
user: authenticated, user: authenticated,
user_handle, user_handle,
rooms, core,
hostname: hostname.clone(),
hostname_rooms: format!("rooms.{}", hostname).into(),
}; };
let should_recreate_xml_future = select! { let should_recreate_xml_future = select! {
biased; biased;
res = &mut next_xml_event => 's: { res = &mut next_xml_event => 's: {
let (ns, event) = res?; let (ns, event) = res?;
if let Event::Text(ref e) = event { if let Event::Text(ref e) = event {
if e.iter().all(|x| *x == 0xA) { if e.iter().all(|x| *x == b'\n' || *x == b' ') {
break 's true; break 's true;
} }
} }
@ -372,16 +396,41 @@ async fn socket_final(
true true
}, },
update = conn.user_handle.receiver.recv() => { update = conn.user_handle.receiver.recv() => {
if let Some(update) = update { match update {
conn.handle_update(&mut events, update).await?; Some(ConnectionMessage::Update(update)) => {
for i in &events { conn.handle_update(&mut events, update).await?;
xml_writer.write_event_async(i).await?; for i in &events {
xml_writer.write_event_async(i).await?;
}
events.clear();
xml_writer.get_mut().flush().await?;
}
Some(ConnectionMessage::Stop(reason)) => {
tracing::debug!("Connection is being terminated: {reason:?}");
let kind = match reason {
StopReason::ServerShutdown => StreamErrorKind::SystemShutdown,
StopReason::InternalError => StreamErrorKind::InternalServerError,
};
StreamError { kind }.serialize(&mut events);
ServerStreamEnd.serialize(&mut events);
for i in &events {
xml_writer.write_event_async(i).await?;
}
events.clear();
xml_writer.get_mut().flush().await?;
break;
}
None => {
log::error!("Player is terminated, must terminate the connection");
StreamError { kind: StreamErrorKind::SystemShutdown }.serialize(&mut events);
ServerStreamEnd.serialize(&mut events);
for i in &events {
xml_writer.write_event_async(i).await?;
}
events.clear();
xml_writer.get_mut().flush().await?;
break;
} }
events.clear();
xml_writer.get_mut().flush().await?;
} else {
log::warn!("Player is terminated, must terminate the connection");
break;
} }
false false
} }
@ -398,13 +447,16 @@ async fn socket_final(
struct XmppConnection<'a> { struct XmppConnection<'a> {
user: &'a Authenticated, user: &'a Authenticated,
user_handle: &'a mut PlayerConnection, user_handle: &'a mut PlayerConnection,
rooms: &'a RoomRegistry, core: &'a LavinaCore,
hostname: Str,
hostname_rooms: Str,
} }
impl<'a> XmppConnection<'a> { impl<'a> XmppConnection<'a> {
#[tracing::instrument(skip(self, output, packet), name = "XmppConnection::handle_packet")]
async fn handle_packet(&mut self, output: &mut Vec<Event<'static>>, packet: ClientPacket) -> Result<bool> { async fn handle_packet(&mut self, output: &mut Vec<Event<'static>>, packet: ClientPacket) -> Result<bool> {
let res = match packet { let res = match packet {
proto::ClientPacket::Iq(iq) => { ClientPacket::Iq(iq) => {
self.handle_iq(output, iq).await; self.handle_iq(output, iq).await;
false false
} }
@ -412,37 +464,16 @@ impl<'a> XmppConnection<'a> {
self.handle_message(output, m).await?; self.handle_message(output, m).await?;
false false
} }
proto::ClientPacket::Presence(p) => { ClientPacket::Presence(p) => {
self.handle_presence(output, p).await?; self.handle_presence(output, p).await?;
false false
} }
proto::ClientPacket::StreamEnd => { ClientPacket::StreamEnd => {
ServerStreamEnd.serialize(output); ServerStreamEnd.serialize(output);
true true
} }
ClientPacket::Eos => true,
}; };
Ok(res) Ok(res)
} }
} }
async fn read_xml_header(
xml_reader: &mut NsReader<(impl AsyncBufRead + Unpin)>,
reader_buf: &mut Vec<u8>,
) -> Result<()> {
if let Event::Decl(bytes) = xml_reader.read_event_into_async(reader_buf).await? {
// this is <?xml ...> header
if let Some(encoding) = bytes.encoding() {
let encoding = encoding?;
if &*encoding == b"UTF-8" {
Ok(())
} else {
Err(anyhow!("Unsupported encoding: {encoding:?}"))
}
} else {
// Err(fail("No XML encoding provided"))
Ok(())
}
} else {
Err(anyhow!("Expected XML header"))
}
}

View File

@ -1,5 +1,6 @@
//! Handling of all client2server message stanzas //! Handling of all client2server message stanzas
use lavina_core::player::PlayerId;
use quick_xml::events::Event; use quick_xml::events::Event;
use lavina_core::prelude::*; use lavina_core::prelude::*;
@ -11,6 +12,7 @@ use proto_xmpp::xml::{Ignore, ToXml};
use crate::XmppConnection; use crate::XmppConnection;
impl<'a> XmppConnection<'a> { impl<'a> XmppConnection<'a> {
#[tracing::instrument(skip(self, output, m), name = "XmppConnection::message")]
pub async fn handle_message(&mut self, output: &mut Vec<Event<'static>>, m: Message<Ignore>) -> Result<()> { pub async fn handle_message(&mut self, output: &mut Vec<Event<'static>>, m: Message<Ignore>) -> Result<()> {
if let Some(Jid { if let Some(Jid {
name: Some(name), name: Some(name),
@ -18,17 +20,18 @@ impl<'a> XmppConnection<'a> {
resource: _, resource: _,
}) = m.to }) = m.to
{ {
if server.0.as_ref() == "rooms.localhost" && m.r#type == MessageType::Groupchat { if server.0.as_ref() == &*self.hostname_rooms && m.r#type == MessageType::Groupchat {
self.user_handle.send_message(RoomId::from(name.0.clone())?, m.body.clone().into()).await?; let Some(body) = &m.body else { return Ok(()) };
self.user_handle.send_message(RoomId::try_from(name.0.clone())?, body.clone()).await?;
Message::<()> { Message::<()> {
to: Some(Jid { to: Some(Jid {
name: Some(self.user.xmpp_name.clone()), name: Some(self.user.xmpp_name.clone()),
server: Server("localhost".into()), server: Server(self.hostname.clone()),
resource: Some(self.user.xmpp_resource.clone()), resource: Some(self.user.xmpp_resource.clone()),
}), }),
from: Some(Jid { from: Some(Jid {
name: Some(name), name: Some(name),
server: Server("rooms.localhost".into()), server: Server(self.hostname_rooms.clone()),
resource: Some(self.user.xmpp_muc_name.clone()), resource: Some(self.user.xmpp_muc_name.clone()),
}), }),
id: m.id, id: m.id,
@ -40,6 +43,10 @@ impl<'a> XmppConnection<'a> {
} }
.serialize(output); .serialize(output);
Ok(()) Ok(())
} else if server.0.as_ref() == &*self.hostname && m.r#type == MessageType::Chat {
let Some(body) = &m.body else { return Ok(()) };
self.user_handle.send_dialog_message(PlayerId::from(name.0.clone())?, body.clone()).await?;
Ok(())
} else { } else {
todo!() todo!()
} }

View File

@ -1,55 +1,358 @@
//! Handling of all client2server presence stanzas //! Handling of all client2server presence stanzas
use anyhow::Result;
use quick_xml::events::Event; use quick_xml::events::Event;
use lavina_core::prelude::*; use lavina_core::player::RoomHistoryResult;
use lavina_core::room::RoomId; use lavina_core::room::RoomId;
use proto_xmpp::bind::{Jid, Server}; use proto_xmpp::bind::{Jid, Name, Resource, Server};
use proto_xmpp::client::Presence; use proto_xmpp::client::{Message, MessageType, Presence, Subject};
use proto_xmpp::muc::{Affiliation, Delay, Role, XUser, XUserItem, XmppHistoryMessage};
use proto_xmpp::xml::{Ignore, ToXml}; use proto_xmpp::xml::{Ignore, ToXml};
use crate::XmppConnection; use crate::XmppConnection;
impl<'a> XmppConnection<'a> { impl<'a> XmppConnection<'a> {
#[tracing::instrument(skip(self, output, p), name = "XmppConnection::handle_presence")]
pub async fn handle_presence(&mut self, output: &mut Vec<Event<'static>>, p: Presence<Ignore>) -> Result<()> { pub async fn handle_presence(&mut self, output: &mut Vec<Event<'static>>, p: Presence<Ignore>) -> Result<()> {
let response = if p.to.is_none() { match p.to {
Presence::<()> { None => {
to: Some(Jid { self.self_presence(output, p.r#type.as_deref()).await;
name: Some(self.user.xmpp_name.clone()),
server: Server("localhost".into()),
resource: Some(self.user.xmpp_resource.clone()),
}),
from: Some(Jid {
name: Some(self.user.xmpp_name.clone()),
server: Server("localhost".into()),
resource: Some(self.user.xmpp_resource.clone()),
}),
..Default::default()
} }
} else if let Some(Jid { Some(Jid {
name: Some(name), name: Some(name),
server, server,
resource: Some(resource), // resources in MUCs are members' personas not implemented (yet?)
}) = p.to resource: Some(_),
{ }) if server.0 == self.hostname_rooms => match p.r#type.as_deref() {
let a = self.user_handle.join_room(RoomId::from(name.0.clone())?).await?; None => {
Presence::<()> { self.join_muc(output, p.id, name).await?;
to: Some(Jid { }
name: Some(self.user.xmpp_name.clone()), Some("unavailable") => {
server: Server("localhost".into()), self.leave_muc(output, p.id, name).await?;
resource: Some(self.user.xmpp_resource.clone()), }
}), _ => {
from: Some(Jid { tracing::error!("Unimplemented case")
name: Some(name.clone()), }
server: Server("rooms.localhost".into()), },
resource: Some(self.user.xmpp_muc_name.clone()), _ => {
}), // TODO other presence cases
..Default::default() let response = Presence::<()>::default();
response.serialize(output);
} }
} else { }
Presence::<()>::default() Ok(())
}
async fn join_muc(&mut self, output: &mut Vec<Event<'static>>, id: Option<String>, name: Name) -> Result<()> {
// Response presence
let mut muc_presence = self.retrieve_muc_presence(&name).await?;
muc_presence.id = id;
muc_presence.serialize(output);
// N last messages from the room history before the user joined
let messages = self.retrieve_message_history(&name).await?;
for message in messages {
message.serialize(output)
}
// The subject is the last stanza sent during a MUC join process.
let subject = Message::<()> {
from: Some(Jid {
name: Some(name.clone()),
server: Server(self.hostname_rooms.clone()),
resource: None,
}),
id: None,
to: Some(Jid {
name: Some(self.user.xmpp_name.clone()),
server: Server(self.hostname.clone()),
resource: Some(self.user.xmpp_resource.clone()),
}),
r#type: MessageType::Groupchat,
lang: None,
subject: Some(Subject(None)),
body: None,
custom: vec![],
};
subject.serialize(output);
Ok(())
}
async fn leave_muc(&mut self, output: &mut Vec<Event<'static>>, id: Option<String>, name: Name) -> Result<()> {
self.user_handle.leave_room(RoomId::try_from(name.0.clone())?).await?;
let response = Presence {
id,
to: Some(Jid {
name: Some(self.user.xmpp_name.clone()),
server: Server(self.hostname.clone()),
resource: Some(self.user.xmpp_resource.clone()),
}),
from: Some(Jid {
name: Some(name.clone()),
server: Server(self.hostname_rooms.clone()),
resource: Some(self.user.xmpp_muc_name.clone()),
}),
r#type: Some("unavailable".into()),
custom: vec![XUser {
item: XUserItem {
affiliation: Affiliation::Member,
role: Role::None,
jid: Jid {
name: Some(self.user.xmpp_name.clone()),
server: Server(self.hostname.clone()),
resource: Some(self.user.xmpp_resource.clone()),
},
},
self_presence: true,
just_created: false,
}],
..Default::default()
}; };
response.serialize(output); response.serialize(output);
Ok(()) Ok(())
} }
#[tracing::instrument(skip(self, output, r#type), name = "XmppConnection::self_presence")]
async fn self_presence(&mut self, output: &mut Vec<Event<'static>>, r#type: Option<&str>) {
match r#type {
Some("unavailable") => {
// do not print anything
}
None => {
let response = Presence::<()> {
to: Some(Jid {
name: Some(self.user.xmpp_name.clone()),
server: Server(self.hostname.clone()),
resource: Some(self.user.xmpp_resource.clone()),
}),
from: Some(Jid {
name: Some(self.user.xmpp_name.clone()),
server: Server(self.hostname.clone()),
resource: Some(self.user.xmpp_resource.clone()),
}),
..Default::default()
};
response.serialize(output);
}
e => {
tracing::error!("TODO: unknown presence type: {e:?}");
}
}
}
#[tracing::instrument(skip(self), name = "XmppConnection::retrieve_muc_presence")]
async fn retrieve_muc_presence(&mut self, name: &Name) -> Result<Presence<XUser>> {
let _ = self.user_handle.join_room(RoomId::try_from(name.0.clone())?).await?;
// TODO handle bans
let response = Presence {
to: Some(Jid {
name: Some(self.user.xmpp_name.clone()),
server: Server(self.hostname.clone()),
resource: Some(self.user.xmpp_resource.clone()),
}),
from: Some(Jid {
name: Some(name.clone()),
server: Server(self.hostname_rooms.clone()),
resource: Some(self.user.xmpp_muc_name.clone()),
}),
custom: vec![XUser {
item: XUserItem {
affiliation: Affiliation::Member,
role: Role::Participant,
jid: Jid {
name: Some(self.user.xmpp_name.clone()),
server: Server(self.hostname.clone()),
resource: Some(self.user.xmpp_resource.clone()),
},
},
self_presence: true,
just_created: false, // TODO we don't know this for sure at this point
}],
..Default::default()
};
Ok(response)
}
/// Retrieve a room's message history. The output can be serialized into a stream of XML stanzas.
///
/// Example in [XmppHistoryMessage]'s docs.
#[tracing::instrument(skip(self), name = "XmppConnection::retrieve_message_history")]
async fn retrieve_message_history(&self, room_name: &Name) -> Result<Vec<XmppHistoryMessage>> {
let room_id = RoomId::try_from(room_name.0.clone())?;
let history_messages = self.user_handle.get_room_message_history(&room_id, 50).await?;
let history_messages = match history_messages {
RoomHistoryResult::Success(messages) => messages,
RoomHistoryResult::NoSuchRoom => {
tracing::warn!("No room found during history retrieval on join");
return Ok(vec![]);
}
};
let mut response = vec![];
for history_message in history_messages.into_iter() {
response.push(XmppHistoryMessage {
id: history_message.id.to_string(),
to: Jid {
name: Option::from(Name(self.user.xmpp_muc_name.0.clone().into())),
server: Server(self.hostname.clone()),
resource: None,
},
from: Jid {
name: Option::from(room_name.clone()),
server: Server(self.hostname_rooms.clone()),
resource: Option::from(Resource(history_message.author_name.clone().into())),
},
delay: Delay {
from: Jid {
name: Option::from(Name(history_message.author_name.clone().into())),
server: Server(self.hostname_rooms.clone()),
resource: None,
},
stamp: history_message.created_at.to_rfc3339(),
},
body: history_message.content.clone(),
});
tracing::info!(
"Retrieved message: {:?} {:?}",
history_message.author_name,
history_message.content.clone()
);
}
return Ok(response);
}
}
#[cfg(test)]
mod tests {
use anyhow::Result;
use lavina_core::player::{PlayerConnectionResult, PlayerId};
use proto_xmpp::bind::{Jid, Name, Resource, Server};
use proto_xmpp::client::Presence;
use proto_xmpp::muc::{Affiliation, Role, XUser, XUserItem};
use crate::testkit::{expect_user_authenticated, TestServer};
use crate::Authenticated;
#[tokio::test]
async fn test_muc_joining() -> Result<()> {
let server = TestServer::start().await?;
server.core.create_player(&PlayerId::from("tester")?).await?;
let player_id = PlayerId::from("tester")?;
let user = Authenticated {
player_id,
xmpp_name: Name("tester".into()),
xmpp_resource: Resource("tester".into()),
xmpp_muc_name: Resource("tester".into()),
};
let mut player_conn = match server.core.connect_to_player(&user.player_id).await? {
PlayerConnectionResult::Success(conn) => conn,
PlayerConnectionResult::PlayerNotFound => panic!("user was created, but not returned"),
};
let mut conn = expect_user_authenticated(&server, &user, &mut player_conn).await?;
let muc_presence = conn.retrieve_muc_presence(&user.xmpp_name).await?;
let expected = Presence {
to: Some(Jid {
name: Some(conn.user.xmpp_name.clone()),
server: Server(conn.hostname.clone()),
resource: Some(conn.user.xmpp_resource.clone()),
}),
from: Some(Jid {
name: Some(user.xmpp_name.clone()),
server: Server(conn.hostname_rooms.clone()),
resource: Some(conn.user.xmpp_muc_name.clone()),
}),
custom: vec![XUser {
item: XUserItem {
affiliation: Affiliation::Member,
role: Role::Participant,
jid: Jid {
name: Some(conn.user.xmpp_name.clone()),
server: Server(conn.hostname.clone()),
resource: Some(conn.user.xmpp_resource.clone()),
},
},
self_presence: true,
just_created: false,
}],
..Default::default()
};
assert_eq!(expected, muc_presence);
server.shutdown().await;
Ok(())
}
// Test that joining a room second time after a server restart,
// i.e. in-memory cache of memberships is cleaned, does not cause any issues.
#[tokio::test]
async fn test_muc_joining_twice() -> Result<()> {
let server = TestServer::start().await?;
server.core.create_player(&PlayerId::from("tester")?).await?;
let player_id = PlayerId::from("tester")?;
let user = Authenticated {
player_id,
xmpp_name: Name("tester".into()),
xmpp_resource: Resource("tester".into()),
xmpp_muc_name: Resource("tester".into()),
};
let mut player_conn = match server.core.connect_to_player(&user.player_id).await? {
PlayerConnectionResult::Success(conn) => conn,
PlayerConnectionResult::PlayerNotFound => panic!("user was created, but not returned"),
};
let mut conn = expect_user_authenticated(&server, &user, &mut player_conn).await?;
let response = conn.retrieve_muc_presence(&user.xmpp_name).await?;
let expected = Presence {
to: Some(Jid {
name: Some(conn.user.xmpp_name.clone()),
server: Server(conn.hostname.clone()),
resource: Some(conn.user.xmpp_resource.clone()),
}),
from: Some(Jid {
name: Some(user.xmpp_name.clone()),
server: Server(conn.hostname_rooms.clone()),
resource: Some(conn.user.xmpp_muc_name.clone()),
}),
custom: vec![XUser {
item: XUserItem {
affiliation: Affiliation::Member,
role: Role::Participant,
jid: Jid {
name: Some(conn.user.xmpp_name.clone()),
server: Server(conn.hostname.clone()),
resource: Some(conn.user.xmpp_resource.clone()),
},
},
self_presence: true,
just_created: false,
}],
..Default::default()
};
assert_eq!(expected, response);
drop(conn);
let server = server.reboot().await.unwrap();
let mut player_conn = match server.core.connect_to_player(&user.player_id).await? {
PlayerConnectionResult::Success(conn) => conn,
PlayerConnectionResult::PlayerNotFound => panic!("user was created, but not returned"),
};
let mut conn = expect_user_authenticated(&server, &user, &mut player_conn).await?;
let response = conn.retrieve_muc_presence(&user.xmpp_name).await?;
assert_eq!(expected, response);
server.shutdown().await;
Ok(())
}
} }

View File

@ -7,6 +7,7 @@ use lavina_core::prelude::*;
use proto_xmpp::bind::BindRequest; use proto_xmpp::bind::BindRequest;
use proto_xmpp::client::{Iq, Message, Presence}; use proto_xmpp::client::{Iq, Message, Presence};
use proto_xmpp::disco::{InfoQuery, ItemQuery}; use proto_xmpp::disco::{InfoQuery, ItemQuery};
use proto_xmpp::mam::MessageArchiveRequest;
use proto_xmpp::roster::RosterQuery; use proto_xmpp::roster::RosterQuery;
use proto_xmpp::session::Session; use proto_xmpp::session::Session;
use proto_xmpp::xml::*; use proto_xmpp::xml::*;
@ -18,6 +19,7 @@ pub enum IqClientBody {
Roster(RosterQuery), Roster(RosterQuery),
DiscoInfo(InfoQuery), DiscoInfo(InfoQuery),
DiscoItem(ItemQuery), DiscoItem(ItemQuery),
MessageArchiveRequest(MessageArchiveRequest),
Unknown(Ignore), Unknown(Ignore),
} }
@ -25,7 +27,8 @@ impl FromXml for IqClientBody {
type P = impl Parser<Output = Result<Self>>; type P = impl Parser<Output = Result<Self>>;
fn parse() -> Self::P { fn parse() -> Self::P {
|(namespace, event): (ResolveResult<'static>, &'static Event<'static>)| -> Result<Self> { #[coroutine]
|(mut namespace, mut event): (ResolveResult<'static>, &'static Event<'static>)| -> Result<Self> {
let bytes = match event { let bytes = match event {
Event::Start(bytes) => bytes, Event::Start(bytes) => bytes,
Event::Empty(bytes) => bytes, Event::Empty(bytes) => bytes,
@ -38,6 +41,7 @@ impl FromXml for IqClientBody {
RosterQuery, RosterQuery,
InfoQuery, InfoQuery,
ItemQuery, ItemQuery,
MessageArchiveRequest,
{ {
delegate_parsing!(Ignore, namespace, event).into() delegate_parsing!(Ignore, namespace, event).into()
} }
@ -52,13 +56,15 @@ pub enum ClientPacket {
Message(Message<Ignore>), Message(Message<Ignore>),
Presence(Presence<Ignore>), Presence(Presence<Ignore>),
StreamEnd, StreamEnd,
Eos,
} }
impl FromXml for ClientPacket { impl FromXml for ClientPacket {
type P = impl Parser<Output = Result<Self>>; type P = impl Parser<Output = Result<Self>>;
fn parse() -> Self::P { fn parse() -> Self::P {
|(namespace, event): (ResolveResult<'static>, &'static Event<'static>)| -> Result<Self> { #[coroutine]
|(mut namespace, mut event): (ResolveResult<'static>, &'static Event<'static>)| -> Result<Self> {
match event { match event {
Event::Start(bytes) | Event::Empty(bytes) => { Event::Start(bytes) | Event::Empty(bytes) => {
let name = bytes.name(); let name = bytes.name();
@ -83,6 +89,7 @@ impl FromXml for ClientPacket {
return Err(anyhow!("Unexpected XML event: {event:?}")); return Err(anyhow!("Unexpected XML event: {event:?}"));
} }
} }
Event::Eof => Ok(ClientPacket::Eos),
_ => { _ => {
return Err(anyhow!("Unexpected XML event: {event:?}")); return Err(anyhow!("Unexpected XML event: {event:?}"));
} }

View File

@ -0,0 +1,77 @@
use prometheus::Registry as MetricsRegistry;
use crate::{Authenticated, XmppConnection};
use lavina_core::clustering::{ClusterConfig, ClusterMetadata};
use lavina_core::player::PlayerConnection;
use lavina_core::repo::{Storage, StorageConfig};
use lavina_core::LavinaCore;
use proto_xmpp::bind::{BindRequest, BindResponse, Jid, Name, Resource, Server};
pub(crate) struct TestServer {
pub core: LavinaCore,
}
impl TestServer {
pub async fn start() -> anyhow::Result<TestServer> {
let _ = tracing_subscriber::fmt::try_init();
let mut metrics = MetricsRegistry::new();
let storage = Storage::open(StorageConfig {
db_path: ":memory:".into(),
})
.await?;
let cluster_config = ClusterConfig {
metadata: ClusterMetadata {
node_id: 0,
main_owner: 0,
rooms: Default::default(),
},
addresses: vec![],
};
let core = LavinaCore::new(&mut metrics, cluster_config, storage).await?;
Ok(TestServer { core })
}
pub async fn reboot(self) -> anyhow::Result<TestServer> {
let storage = self.core.shutdown().await;
let mut metrics = MetricsRegistry::new();
let cluster_config = ClusterConfig {
metadata: ClusterMetadata {
node_id: 0,
main_owner: 0,
rooms: Default::default(),
},
addresses: vec![],
};
let core = LavinaCore::new(&mut metrics, cluster_config, storage).await?;
Ok(TestServer { core })
}
pub async fn shutdown(self) {
let storage = self.core.shutdown().await;
storage.close().await;
}
}
pub async fn expect_user_authenticated<'a>(
server: &'a TestServer,
user: &'a Authenticated,
conn: &'a mut PlayerConnection,
) -> anyhow::Result<XmppConnection<'a>> {
let conn = XmppConnection {
user: &user,
user_handle: conn,
core: &server.core,
hostname: "localhost".into(),
hostname_rooms: "rooms.localhost".into(),
};
let result = conn.bind(&BindRequest(Resource("whatever".into()))).await;
let expected = BindResponse(Jid {
name: Some(Name("tester".into())),
server: Server("localhost".into()),
resource: Some(Resource("tester".into())),
});
assert_eq!(expected, result);
Ok(conn)
}

View File

@ -17,16 +17,17 @@ impl<'a> XmppConnection<'a> {
room_id, room_id,
author_id, author_id,
body, body,
created_at: _,
} => { } => {
Message::<()> { Message::<()> {
to: Some(Jid { to: Some(Jid {
name: Some(self.user.xmpp_name.clone()), name: Some(self.user.xmpp_name.clone()),
server: Server("localhost".into()), server: Server(self.hostname.clone()),
resource: Some(self.user.xmpp_resource.clone()), resource: Some(self.user.xmpp_resource.clone()),
}), }),
from: Some(Jid { from: Some(Jid {
name: Some(Name(room_id.into_inner().into())), name: Some(Name(room_id.into_inner().into())),
server: Server("rooms.localhost".into()), server: Server(self.hostname_rooms.clone()),
resource: Some(Resource(author_id.into_inner().into())), resource: Some(Resource(author_id.into_inner().into())),
}), }),
id: None, id: None,
@ -38,6 +39,34 @@ impl<'a> XmppConnection<'a> {
} }
.serialize(output); .serialize(output);
} }
Updates::NewDialogMessage {
sender,
receiver,
body,
created_at: _,
} => {
if receiver == self.user.player_id {
Message::<()> {
to: Some(Jid {
name: Some(self.user.xmpp_name.clone()),
server: Server(self.hostname.clone()),
resource: Some(self.user.xmpp_resource.clone()),
}),
from: Some(Jid {
name: Some(Name(sender.as_inner().clone())),
server: Server(self.hostname.clone()),
resource: Some(Resource(sender.into_inner())),
}),
id: None,
r#type: MessageType::Chat,
lang: None,
subject: None,
body: body.into(),
custom: vec![],
}
.serialize(output);
}
}
_ => {} _ => {}
} }
Ok(()) Ok(())

View File

@ -1,3 +1,5 @@
use std::io::ErrorKind;
use std::str::from_utf8;
use std::sync::Arc; use std::sync::Arc;
use std::time::Duration; use std::time::Duration;
@ -5,8 +7,9 @@ use anyhow::Result;
use assert_matches::*; use assert_matches::*;
use prometheus::Registry as MetricsRegistry; use prometheus::Registry as MetricsRegistry;
use quick_xml::events::Event; use quick_xml::events::Event;
use quick_xml::name::LocalName;
use quick_xml::NsReader; use quick_xml::NsReader;
use tokio::io::{AsyncBufReadExt, AsyncWriteExt, BufReader}; use tokio::io::{AsyncBufReadExt, AsyncReadExt, AsyncWriteExt, BufReader};
use tokio::io::{ReadHalf as GenericReadHalf, WriteHalf as GenericWriteHalf}; use tokio::io::{ReadHalf as GenericReadHalf, WriteHalf as GenericWriteHalf};
use tokio::net::tcp::{ReadHalf, WriteHalf}; use tokio::net::tcp::{ReadHalf, WriteHalf};
use tokio::net::TcpStream; use tokio::net::TcpStream;
@ -15,11 +18,15 @@ use tokio_rustls::rustls::client::ServerCertVerifier;
use tokio_rustls::rustls::{ClientConfig, ServerName}; use tokio_rustls::rustls::{ClientConfig, ServerName};
use tokio_rustls::TlsConnector; use tokio_rustls::TlsConnector;
use lavina_core::player::PlayerRegistry; use lavina_core::clustering::{ClusterConfig, ClusterMetadata};
use lavina_core::player::PlayerId;
use lavina_core::repo::{Storage, StorageConfig}; use lavina_core::repo::{Storage, StorageConfig};
use lavina_core::room::RoomRegistry; use lavina_core::LavinaCore;
use projection_xmpp::{launch, ServerConfig}; use projection_xmpp::{launch, RunningServer, ServerConfig};
use proto_xmpp::xml::{Continuation, FromXml, Parser};
fn element_name<'a>(local_name: &LocalName<'a>) -> &'a str {
from_utf8(local_name.into_inner()).unwrap()
}
pub async fn read_irc_message(reader: &mut BufReader<ReadHalf<'_>>, buf: &mut Vec<u8>) -> Result<usize> { pub async fn read_irc_message(reader: &mut BufReader<ReadHalf<'_>>, buf: &mut Vec<u8>) -> Result<usize> {
let mut size = 0; let mut size = 0;
@ -54,19 +61,13 @@ impl<'a> TestScope<'a> {
Ok(event) Ok(event)
} }
async fn read<T: FromXml>(&mut self) -> Result<T> { async fn expect_starttls_required(&mut self) -> Result<()> {
self.buffer.clear(); assert_matches!(self.next_xml_event().await?, Event::Start(b) => assert_eq!(element_name(&b.local_name()), "features"));
let (ns, event) = self.reader.read_resolved_event_into_async(&mut self.buffer).await?; assert_matches!(self.next_xml_event().await?, Event::Start(b) => assert_eq!(element_name(&b.local_name()), "starttls"));
let mut parser: Continuation<_, std::result::Result<T, anyhow::Error>> = T::parse().consume(ns, &event); assert_matches!(self.next_xml_event().await?, Event::Empty(b) => assert_eq!(element_name(&b.local_name()), "required"));
loop { assert_matches!(self.next_xml_event().await?, Event::End(b) => assert_eq!(element_name(&b.local_name()), "starttls"));
match parser { assert_matches!(self.next_xml_event().await?, Event::End(b) => assert_eq!(element_name(&b.local_name()), "features"));
Continuation::Final(res) => return Ok(res?), Ok(())
Continuation::Continue(next) => {
let (ns, event) = self.reader.read_resolved_event_into_async(&mut self.buffer).await?;
parser = next.consume(ns, &event);
}
}
}
} }
} }
@ -81,7 +82,7 @@ impl<'a> TestScopeTls<'a> {
fn new(stream: &'a mut TlsStream<TcpStream>, buffer: Vec<u8>) -> TestScopeTls<'a> { fn new(stream: &'a mut TlsStream<TcpStream>, buffer: Vec<u8>) -> TestScopeTls<'a> {
let (reader, writer) = tokio::io::split(stream); let (reader, writer) = tokio::io::split(stream);
let reader = NsReader::from_reader(BufReader::new(reader)); let reader = NsReader::from_reader(BufReader::new(reader));
let timeout = Duration::from_millis(100); let timeout = Duration::from_millis(500);
TestScopeTls { TestScopeTls {
reader, reader,
@ -97,6 +98,24 @@ impl<'a> TestScopeTls<'a> {
Ok(()) Ok(())
} }
async fn expect_auth_mechanisms(&mut self) -> Result<()> {
assert_matches!(self.next_xml_event().await?, Event::Start(b) => assert_eq!(element_name(&b.local_name()), "features"));
assert_matches!(self.next_xml_event().await?, Event::Start(b) => assert_eq!(element_name(&b.local_name()), "mechanisms"));
assert_matches!(self.next_xml_event().await?, Event::Start(b) => assert_eq!(element_name(&b.local_name()), "mechanism"));
assert_matches!(self.next_xml_event().await?, Event::Text(b) => assert_eq!(&*b, b"PLAIN"));
assert_matches!(self.next_xml_event().await?, Event::End(b) => assert_eq!(element_name(&b.local_name()), "mechanism"));
assert_matches!(self.next_xml_event().await?, Event::End(b) => assert_eq!(element_name(&b.local_name()), "mechanisms"));
assert_matches!(self.next_xml_event().await?, Event::End(b) => assert_eq!(element_name(&b.local_name()), "features"));
Ok(())
}
async fn expect_bind_feature(&mut self) -> Result<()> {
assert_matches!(self.next_xml_event().await?, Event::Start(b) => assert_eq!(element_name(&b.local_name()), "features"));
assert_matches!(self.next_xml_event().await?, Event::Empty(b) => assert_eq!(element_name(&b.local_name()), "bind"));
assert_matches!(self.next_xml_event().await?, Event::End(b) => assert_eq!(element_name(&b.local_name()), "features"));
Ok(())
}
async fn next_xml_event(&mut self) -> Result<Event<'_>> { async fn next_xml_event(&mut self) -> Result<Event<'_>> {
self.buffer.clear(); self.buffer.clear();
let event = self.reader.read_event_into_async(&mut self.buffer); let event = self.reader.read_event_into_async(&mut self.buffer);
@ -106,6 +125,7 @@ impl<'a> TestScopeTls<'a> {
} }
struct IgnoreCertVerification; struct IgnoreCertVerification;
impl ServerCertVerifier for IgnoreCertVerification { impl ServerCertVerifier for IgnoreCertVerification {
fn verify_server_cert( fn verify_server_cert(
&self, &self,
@ -120,43 +140,66 @@ impl ServerCertVerifier for IgnoreCertVerification {
} }
} }
struct TestServer {
core: LavinaCore,
server: RunningServer,
}
impl TestServer {
async fn start() -> Result<TestServer> {
let _ = tracing_subscriber::fmt::try_init();
let config = ServerConfig {
listen_on: "127.0.0.1:0".parse()?,
cert: "tests/certs/xmpp.pem".parse()?,
key: "tests/certs/xmpp.key".parse()?,
hostname: "localhost".into(),
};
let mut metrics = MetricsRegistry::new();
let storage = Storage::open(StorageConfig {
db_path: ":memory:".into(),
})
.await?;
let cluster_config = ClusterConfig {
addresses: vec![],
metadata: ClusterMetadata {
node_id: 0,
main_owner: 0,
rooms: Default::default(),
},
};
let core = LavinaCore::new(&mut metrics, cluster_config, storage).await?;
let server = launch(config, core.clone(), metrics.clone()).await.unwrap();
Ok(TestServer { core, server })
}
async fn shutdown(self) -> Result<()> {
self.server.terminate().await?;
let storage = self.core.shutdown().await;
storage.close().await;
Ok(())
}
}
#[tokio::test] #[tokio::test]
async fn scenario_basic() -> Result<()> { async fn scenario_basic() -> Result<()> {
tracing_subscriber::fmt::init(); let server = TestServer::start().await?;
let config = ServerConfig {
listen_on: "127.0.0.1:0".parse().unwrap(),
cert: "tests/certs/xmpp.pem".parse().unwrap(),
key: "tests/certs/xmpp.key".parse().unwrap(),
};
let mut metrics = MetricsRegistry::new();
let mut storage = Storage::open(StorageConfig {
db_path: ":memory:".into(),
})
.await?;
let rooms = RoomRegistry::new(&mut metrics, storage.clone()).unwrap();
let players = PlayerRegistry::empty(rooms.clone(), &mut metrics).unwrap();
let server = launch(config, players, rooms, metrics, storage.clone()).await.unwrap();
// test scenario // test scenario
storage.create_user("tester").await?; server.core.create_player(&PlayerId::from("tester")?).await?;
storage.set_password("tester", "password").await?; server.core.set_password("tester", "password").await?;
let mut stream = TcpStream::connect(server.addr).await?; let mut stream = TcpStream::connect(server.server.addr).await?;
let mut s = TestScope::new(&mut stream); let mut s = TestScope::new(&mut stream);
tracing::info!("TCP connection established"); tracing::info!("TCP connection established");
s.send(r#"<?xml version="1.0"?>"#).await?; s.send(r#"<?xml version="1.0"?>"#).await?;
s.send(r#"<stream:stream xmlns:stream="http://etherx.jabber.org/streams" to="127.0.0.1" xml:lang="en" xmlns:xml="http://www.w3.org/XML/1998/namespace" xmlns="jabber:client" version="1.0">"#).await?; s.send(r#"<stream:stream xmlns:stream="http://etherx.jabber.org/streams" to="127.0.0.1" xml:lang="en" xmlns:xml="http://www.w3.org/XML/1998/namespace" xmlns="jabber:client" version="1.0">"#).await?;
assert_matches!(s.next_xml_event().await?, Event::Decl(_) => {}); assert_matches!(s.next_xml_event().await?, Event::Decl(_) => {});
assert_matches!(s.next_xml_event().await?, Event::Start(b) => assert_eq!(b.local_name().into_inner(), b"stream")); assert_matches!(s.next_xml_event().await?, Event::Start(b) => assert_eq!(element_name(&b.local_name()), "stream"));
assert_matches!(s.next_xml_event().await?, Event::Start(b) => assert_eq!(b.local_name().into_inner(), b"features")); s.expect_starttls_required().await?;
assert_matches!(s.next_xml_event().await?, Event::Start(b) => assert_eq!(b.local_name().into_inner(), b"starttls"));
assert_matches!(s.next_xml_event().await?, Event::Empty(b) => assert_eq!(b.local_name().into_inner(), b"required"));
assert_matches!(s.next_xml_event().await?, Event::End(b) => assert_eq!(b.local_name().into_inner(), b"starttls"));
assert_matches!(s.next_xml_event().await?, Event::End(b) => assert_eq!(b.local_name().into_inner(), b"features"));
s.send(r#"<starttls/>"#).await?; s.send(r#"<starttls/>"#).await?;
assert_matches!(s.next_xml_event().await?, Event::Empty(b) => assert_eq!(b.local_name().into_inner(), b"proceed")); assert_matches!(s.next_xml_event().await?, Event::Empty(b) => assert_eq!(element_name(&b.local_name()), "proceed"));
let buffer = s.buffer; let buffer = s.buffer;
tracing::info!("TLS feature negotiation complete"); tracing::info!("TLS feature negotiation complete");
@ -167,7 +210,7 @@ async fn scenario_basic() -> Result<()> {
.with_no_client_auth(), .with_no_client_auth(),
)); ));
tracing::info!("Initiating TLS connection..."); tracing::info!("Initiating TLS connection...");
let mut stream = connector.connect(ServerName::IpAddress(server.addr.ip()), stream).await?; let mut stream = connector.connect(ServerName::IpAddress(server.server.addr.ip()), stream).await?;
tracing::info!("TLS connection established"); tracing::info!("TLS connection established");
let mut s = TestScopeTls::new(&mut stream, buffer); let mut s = TestScopeTls::new(&mut stream, buffer);
@ -175,12 +218,259 @@ async fn scenario_basic() -> Result<()> {
s.send(r#"<?xml version="1.0"?>"#).await?; s.send(r#"<?xml version="1.0"?>"#).await?;
s.send(r#"<stream:stream xmlns:stream="http://etherx.jabber.org/streams" to="127.0.0.1" xml:lang="en" xmlns:xml="http://www.w3.org/XML/1998/namespace" xmlns="jabber:client" version="1.0">"#).await?; s.send(r#"<stream:stream xmlns:stream="http://etherx.jabber.org/streams" to="127.0.0.1" xml:lang="en" xmlns:xml="http://www.w3.org/XML/1998/namespace" xmlns="jabber:client" version="1.0">"#).await?;
assert_matches!(s.next_xml_event().await?, Event::Decl(_) => {}); assert_matches!(s.next_xml_event().await?, Event::Decl(_) => {});
assert_matches!(s.next_xml_event().await?, Event::Start(b) => assert_eq!(b.local_name().into_inner(), b"stream")); assert_matches!(s.next_xml_event().await?, Event::Start(b) => assert_eq!(element_name(&b.local_name()), "stream"));
s.expect_auth_mechanisms().await?;
// base64-encoded "\x00tester\x00password"
s.send(r#"<auth xmlns="urn:ietf:params:xml:ns:xmpp-sasl" mechanism="PLAIN">AHRlc3RlcgBwYXNzd29yZA==</auth>"#)
.await?;
assert_matches!(s.next_xml_event().await?, Event::Empty(b) => assert_eq!(element_name(&b.local_name()), "success"));
s.send(r#"<?xml version="1.0"?>"#).await?;
s.send(r#"<stream:stream xmlns:stream="http://etherx.jabber.org/streams" to="127.0.0.1" xml:lang="en" xmlns:xml="http://www.w3.org/XML/1998/namespace" xmlns="jabber:client" version="1.0">"#).await?;
assert_matches!(s.next_xml_event().await?, Event::Decl(_) => {});
assert_matches!(s.next_xml_event().await?, Event::Start(b) => assert_eq!(element_name(&b.local_name()), "stream"));
s.expect_bind_feature().await?;
s.send(r#"<iq id="bind_1" type="set"><bind xmlns="urn:ietf:params:xml:ns:xmpp-bind"><resource>kek</resource></bind></iq>"#).await?;
assert_matches!(s.next_xml_event().await?, Event::Start(b) => assert_eq!(element_name(&b.local_name()), "iq"));
assert_matches!(s.next_xml_event().await?, Event::Start(b) => assert_eq!(element_name(&b.local_name()), "bind"));
assert_matches!(s.next_xml_event().await?, Event::Start(b) => assert_eq!(element_name(&b.local_name()), "jid"));
assert_matches!(s.next_xml_event().await?, Event::Text(b) => assert_eq!(&*b, b"tester@localhost/tester"));
assert_matches!(s.next_xml_event().await?, Event::End(b) => assert_eq!(element_name(&b.local_name()), "jid"));
assert_matches!(s.next_xml_event().await?, Event::End(b) => assert_eq!(element_name(&b.local_name()), "bind"));
assert_matches!(s.next_xml_event().await?, Event::End(b) => assert_eq!(element_name(&b.local_name()), "iq"));
s.send(r#"<presence xmlns="jabber:client" type="unavailable"><status>Logged out</status></presence>"#).await?;
stream.shutdown().await?; stream.shutdown().await?;
// wrap up // wrap up
server.terminate().await?; server.shutdown().await?;
Ok(())
}
#[tokio::test]
async fn scenario_wrong_password() -> Result<()> {
let server = TestServer::start().await?;
// test scenario
server.core.create_player(&PlayerId::from("tester")?).await?;
server.core.set_password("tester", "password").await?;
let mut stream = TcpStream::connect(server.server.addr).await?;
let mut s = TestScope::new(&mut stream);
tracing::info!("TCP connection established");
s.send(r#"<?xml version="1.0"?>"#).await?;
s.send(r#"<stream:stream xmlns:stream="http://etherx.jabber.org/streams" to="127.0.0.1" xml:lang="en" xmlns:xml="http://www.w3.org/XML/1998/namespace" xmlns="jabber:client" version="1.0">"#).await?;
assert_matches!(s.next_xml_event().await?, Event::Decl(_) => {});
assert_matches!(s.next_xml_event().await?, Event::Start(b) => assert_eq!(element_name(&b.local_name()), "stream"));
s.expect_starttls_required().await?;
s.send(r#"<starttls/>"#).await?;
assert_matches!(s.next_xml_event().await?, Event::Empty(b) => assert_eq!(element_name(&b.local_name()), "proceed"));
let buffer = s.buffer;
tracing::info!("TLS feature negotiation complete");
let connector = TlsConnector::from(Arc::new(
ClientConfig::builder()
.with_safe_defaults()
.with_custom_certificate_verifier(Arc::new(IgnoreCertVerification))
.with_no_client_auth(),
));
tracing::info!("Initiating TLS connection...");
let mut stream = connector.connect(ServerName::IpAddress(server.server.addr.ip()), stream).await?;
tracing::info!("TLS connection established");
let mut s = TestScopeTls::new(&mut stream, buffer);
s.send(r#"<?xml version="1.0"?>"#).await?;
s.send(r#"<stream:stream xmlns:stream="http://etherx.jabber.org/streams" to="127.0.0.1" xml:lang="en" xmlns:xml="http://www.w3.org/XML/1998/namespace" xmlns="jabber:client" version="1.0">"#).await?;
assert_matches!(s.next_xml_event().await?, Event::Decl(_) => {});
assert_matches!(s.next_xml_event().await?, Event::Start(b) => assert_eq!(element_name(&b.local_name()), "stream"));
s.expect_auth_mechanisms().await?;
// base64-encoded "\x00tester\x00password2"
s.send(r#"<auth xmlns="urn:ietf:params:xml:ns:xmpp-sasl" mechanism="PLAIN">AHRlc3RlcgBwYXNzd29yZDI=</auth>"#)
.await?;
assert_matches!(s.next_xml_event().await?, Event::Start(b) => assert_eq!(element_name(&b.local_name()), "failure"));
assert_matches!(s.next_xml_event().await?, Event::Empty(b) => assert_eq!(element_name(&b.local_name()), "not-authorized"));
assert_matches!(s.next_xml_event().await?, Event::End(b) => assert_eq!(element_name(&b.local_name()), "failure"));
let _ = stream.shutdown().await;
// wrap up
server.shutdown().await?;
Ok(())
}
#[tokio::test]
async fn scenario_basic_without_headers() -> Result<()> {
let server = TestServer::start().await?;
// test scenario
server.core.create_player(&PlayerId::from("tester")?).await?;
server.core.set_password("tester", "password").await?;
let mut stream = TcpStream::connect(server.server.addr).await?;
let mut s = TestScope::new(&mut stream);
tracing::info!("TCP connection established");
s.send(r#"<stream:stream xmlns:stream="http://etherx.jabber.org/streams" to="127.0.0.1" xml:lang="en" xmlns:xml="http://www.w3.org/XML/1998/namespace" xmlns="jabber:client" version="1.0">"#).await?;
assert_matches!(s.next_xml_event().await?, Event::Decl(_) => {});
assert_matches!(s.next_xml_event().await?, Event::Start(b) => assert_eq!(element_name(&b.local_name()), "stream"));
s.expect_starttls_required().await?;
s.send(r#"<starttls/>"#).await?;
assert_matches!(s.next_xml_event().await?, Event::Empty(b) => assert_eq!(element_name(&b.local_name()), "proceed"));
let buffer = s.buffer;
tracing::info!("TLS feature negotiation complete");
let connector = TlsConnector::from(Arc::new(
ClientConfig::builder()
.with_safe_defaults()
.with_custom_certificate_verifier(Arc::new(IgnoreCertVerification))
.with_no_client_auth(),
));
tracing::info!("Initiating TLS connection...");
let mut stream = connector.connect(ServerName::IpAddress(server.server.addr.ip()), stream).await?;
tracing::info!("TLS connection established");
let mut s = TestScopeTls::new(&mut stream, buffer);
s.send(r#"<stream:stream xmlns:stream="http://etherx.jabber.org/streams" to="127.0.0.1" xml:lang="en" xmlns:xml="http://www.w3.org/XML/1998/namespace" xmlns="jabber:client" version="1.0">"#).await?;
assert_matches!(s.next_xml_event().await?, Event::Decl(_) => {});
assert_matches!(s.next_xml_event().await?, Event::Start(b) => assert_eq!(element_name(&b.local_name()), "stream"));
stream.shutdown().await?;
// wrap up
server.shutdown().await?;
Ok(())
}
#[tokio::test]
async fn terminate_socket() -> Result<()> {
let server = TestServer::start().await?;
// test scenario
server.core.create_player(&PlayerId::from("tester")?).await?;
server.core.set_password("tester", "password").await?;
let mut stream = TcpStream::connect(server.server.addr).await?;
let mut s = TestScope::new(&mut stream);
tracing::info!("TCP connection established");
s.send(r#"<?xml version="1.0"?>"#).await?;
s.send(r#"<stream:stream xmlns:stream="http://etherx.jabber.org/streams" to="127.0.0.1" xml:lang="en" xmlns:xml="http://www.w3.org/XML/1998/namespace" xmlns="jabber:client" version="1.0">"#).await?;
assert_matches!(s.next_xml_event().await?, Event::Decl(_) => {});
assert_matches!(s.next_xml_event().await?, Event::Start(b) => assert_eq!(element_name(&b.local_name()), "stream"));
s.expect_starttls_required().await?;
s.send(r#"<starttls/>"#).await?;
assert_matches!(s.next_xml_event().await?, Event::Empty(b) => assert_eq!(element_name(&b.local_name()), "proceed"));
let connector = TlsConnector::from(Arc::new(
ClientConfig::builder()
.with_safe_defaults()
.with_custom_certificate_verifier(Arc::new(IgnoreCertVerification))
.with_no_client_auth(),
));
tracing::info!("Initiating TLS connection...");
let mut stream = connector.connect(ServerName::IpAddress(server.server.addr.ip()), stream).await?;
tracing::info!("TLS connection established");
server.shutdown().await?;
assert_eq!(stream.read_u8().await.unwrap_err().kind(), ErrorKind::UnexpectedEof);
Ok(())
}
#[tokio::test]
async fn test_message_archive_request() -> Result<()> {
let server = TestServer::start().await?;
// test scenario
server.core.create_player(&PlayerId::from("tester")?).await?;
server.core.set_password("tester", "password").await?;
let mut stream = TcpStream::connect(server.server.addr).await?;
let mut s = TestScope::new(&mut stream);
tracing::info!("TCP connection established");
s.send(r#"<?xml version="1.0"?>"#).await?;
s.send(r#"<stream:stream xmlns:stream="http://etherx.jabber.org/streams" to="127.0.0.1" xml:lang="en" xmlns:xml="http://www.w3.org/XML/1998/namespace" xmlns="jabber:client" version="1.0">"#).await?;
assert_matches!(s.next_xml_event().await?, Event::Decl(_) => {});
assert_matches!(s.next_xml_event().await?, Event::Start(b) => assert_eq!(element_name(&b.local_name()), "stream"));
s.expect_starttls_required().await?;
s.send(r#"<starttls/>"#).await?;
assert_matches!(s.next_xml_event().await?, Event::Empty(b) => assert_eq!(element_name(&b.local_name()), "proceed"));
let buffer = s.buffer;
tracing::info!("TLS feature negotiation complete");
let connector = TlsConnector::from(Arc::new(
ClientConfig::builder()
.with_safe_defaults()
.with_custom_certificate_verifier(Arc::new(IgnoreCertVerification))
.with_no_client_auth(),
));
tracing::info!("Initiating TLS connection...");
let mut stream = connector.connect(ServerName::IpAddress(server.server.addr.ip()), stream).await?;
tracing::info!("TLS connection established");
let mut s = TestScopeTls::new(&mut stream, buffer);
s.send(r#"<?xml version="1.0"?>"#).await?;
s.send(r#"<stream:stream xmlns:stream="http://etherx.jabber.org/streams" to="127.0.0.1" xml:lang="en" xmlns:xml="http://www.w3.org/XML/1998/namespace" xmlns="jabber:client" version="1.0">"#).await?;
assert_matches!(s.next_xml_event().await?, Event::Decl(_) => {});
assert_matches!(s.next_xml_event().await?, Event::Start(b) => assert_eq!(element_name(&b.local_name()), "stream"));
s.expect_auth_mechanisms().await?;
// base64-encoded "\x00tester\x00password"
s.send(r#"<auth xmlns="urn:ietf:params:xml:ns:xmpp-sasl" mechanism="PLAIN">AHRlc3RlcgBwYXNzd29yZA==</auth>"#)
.await?;
assert_matches!(s.next_xml_event().await?, Event::Empty(b) => assert_eq!(element_name(&b.local_name()), "success"));
s.send(r#"<?xml version="1.0"?>"#).await?;
s.send(r#"<stream:stream xmlns:stream="http://etherx.jabber.org/streams" to="127.0.0.1" xml:lang="en" xmlns:xml="http://www.w3.org/XML/1998/namespace" xmlns="jabber:client" version="1.0">"#).await?;
assert_matches!(s.next_xml_event().await?, Event::Decl(_) => {});
assert_matches!(s.next_xml_event().await?, Event::Start(b) => assert_eq!(element_name(&b.local_name()), "stream"));
s.expect_bind_feature().await?;
s.send(r#"<iq id="bind_1" type="set"><bind xmlns="urn:ietf:params:xml:ns:xmpp-bind"><resource>kek</resource></bind></iq>"#).await?;
assert_matches!(s.next_xml_event().await?, Event::Start(b) => assert_eq!(element_name(&b.local_name()), "iq"));
assert_matches!(s.next_xml_event().await?, Event::Start(b) => assert_eq!(element_name(&b.local_name()), "bind"));
assert_matches!(s.next_xml_event().await?, Event::Start(b) => assert_eq!(element_name(&b.local_name()), "jid"));
assert_matches!(s.next_xml_event().await?, Event::Text(b) => assert_eq!(&*b, b"tester@localhost/tester"));
assert_matches!(s.next_xml_event().await?, Event::End(b) => assert_eq!(element_name(&b.local_name()), "jid"));
assert_matches!(s.next_xml_event().await?, Event::End(b) => assert_eq!(element_name(&b.local_name()), "bind"));
assert_matches!(s.next_xml_event().await?, Event::End(b) => assert_eq!(element_name(&b.local_name()), "iq"));
s.send(r#"<iq type='get' id='juliet1'><query xmlns='urn:xmpp:mam:2' queryid='f27'/></iq>"#).await?;
assert_matches!(s.next_xml_event().await?, Event::Start(b) => {
assert_eq!(element_name(&b.local_name()), "iq")
});
assert_matches!(s.next_xml_event().await?, Event::Start(b) => {
assert_eq!(element_name(&b.local_name()), "fin")
});
assert_matches!(s.next_xml_event().await?, Event::Start(b) => {
assert_eq!(element_name(&b.local_name()), "set")
});
assert_matches!(s.next_xml_event().await?, Event::Start(b) => {
assert_eq!(element_name(&b.local_name()), "count")
});
assert_matches!(s.next_xml_event().await?, Event::Text(b) => {
assert_eq!(&*b, b"0")
});
s.send(r#"<presence xmlns="jabber:client" type="unavailable"><status>Logged out</status></presence>"#).await?;
stream.shutdown().await?;
// wrap up
server.shutdown().await?;
Ok(()) Ok(())
} }

View File

@ -1,66 +1,82 @@
use super::*;
use anyhow::{anyhow, Result}; use anyhow::{anyhow, Result};
use nom::combinator::{all_consuming, opt}; use nom::combinator::{all_consuming, opt};
use nonempty::NonEmpty; use nonempty::NonEmpty;
use super::*;
/// Client-to-server command. /// Client-to-server command.
#[derive(Clone, Debug, PartialEq, Eq)] #[derive(Clone, Debug, PartialEq, Eq)]
pub enum ClientMessage { pub enum ClientMessage {
/// CAP. Capability-related commands. /// `CAP`. Capability-related commands.
Capability { Capability {
subcommand: CapabilitySubcommand, subcommand: CapabilitySubcommand,
}, },
/// PING <token> /// `PING <token>`
Ping { Ping {
token: Str, token: Str,
}, },
/// PONG <token> /// `PONG <token>`
Pong { Pong {
token: Str, token: Str,
}, },
/// NICK <nickname> /// `NICK <nickname>`
Nick { Nick {
nickname: Str, nickname: Str,
}, },
/// PASS <password> /// `PASS <password>`
Pass { Pass {
password: Str, password: Str,
}, },
/// USER <username> 0 * :<realname> /// `USER <username> 0 * :<realname>`
User { User {
username: Str, username: Str,
realname: Str, realname: Str,
}, },
/// JOIN <chan> /// `JOIN <chan>`
Join(Chan), Join(Chan),
/// MODE <target> /// `MODE <target>`
Mode { Mode {
target: Recipient, target: Recipient,
}, },
/// WHO <target> /// `WHO <target>`
Who { Who {
target: Recipient, // aka mask target: Recipient, // aka mask
}, },
/// TOPIC <chan> :<topic> /// WHOIS [<target>] <nick>
Whois(Whois),
/// `TOPIC <chan> :<topic>`
Topic { Topic {
chan: Chan, chan: Chan,
topic: Str, topic: Str,
}, },
Part { Part {
chan: Chan, chan: Chan,
message: Str, message: Option<Str>,
}, },
/// PRIVMSG <target> :<msg> /// `PRIVMSG <target> :<msg>`
PrivateMessage { PrivateMessage {
recipient: Recipient, recipient: Recipient,
body: Str, body: Str,
}, },
/// QUIT :<reason> /// `QUIT :<reason>`
Quit { Quit {
reason: Str, reason: Str,
}, },
Authenticate(Str), Authenticate(Str),
ChatHistory(ChatHistory),
}
#[derive(Clone, Debug, PartialEq, Eq)]
pub enum Whois {
Nick(Str),
TargetNick(Str, Str),
EmptyArgs,
}
#[derive(Clone, Debug, PartialEq, Eq)]
pub struct ChatHistory {
pub chan: Chan,
pub limit: u32,
} }
pub fn client_message(input: &str) -> Result<ClientMessage> { pub fn client_message(input: &str) -> Result<ClientMessage> {
@ -74,11 +90,13 @@ pub fn client_message(input: &str) -> Result<ClientMessage> {
client_message_join, client_message_join,
client_message_mode, client_message_mode,
client_message_who, client_message_who,
client_message_whois,
client_message_topic, client_message_topic,
client_message_part, client_message_part,
client_message_privmsg, client_message_privmsg,
client_message_quit, client_message_quit,
client_message_authenticate, client_message_authenticate,
client_message_chathistory,
)))(input); )))(input);
match res { match res {
Ok((_, e)) => Ok(e), Ok((_, e)) => Ok(e),
@ -118,6 +136,7 @@ fn client_message_nick(input: &str) -> IResult<&str, ClientMessage> {
}, },
)) ))
} }
fn client_message_pass(input: &str) -> IResult<&str, ClientMessage> { fn client_message_pass(input: &str) -> IResult<&str, ClientMessage> {
let (input, _) = tag("PASS ")(input)?; let (input, _) = tag("PASS ")(input)?;
let (input, r) = opt(tag(":"))(input)?; let (input, r) = opt(tag(":"))(input)?;
@ -156,6 +175,7 @@ fn client_message_user(input: &str) -> IResult<&str, ClientMessage> {
}, },
)) ))
} }
fn client_message_join(input: &str) -> IResult<&str, ClientMessage> { fn client_message_join(input: &str) -> IResult<&str, ClientMessage> {
let (input, _) = tag("JOIN ")(input)?; let (input, _) = tag("JOIN ")(input)?;
let (input, chan) = chan(input)?; let (input, chan) = chan(input)?;
@ -177,6 +197,16 @@ fn client_message_who(input: &str) -> IResult<&str, ClientMessage> {
Ok((input, ClientMessage::Who { target })) Ok((input, ClientMessage::Who { target }))
} }
fn client_message_whois(input: &str) -> IResult<&str, ClientMessage> {
let (input, _) = tag("WHOIS ")(input)?;
let args: Vec<_> = input.split_whitespace().collect();
match args.as_slice()[..] {
[nick] => Ok(("", ClientMessage::Whois(Whois::Nick(nick.into())))),
[target, nick, ..] => Ok(("", ClientMessage::Whois(Whois::TargetNick(target.into(), nick.into())))),
[] => Ok(("", ClientMessage::Whois(Whois::EmptyArgs))),
}
}
fn client_message_topic(input: &str) -> IResult<&str, ClientMessage> { fn client_message_topic(input: &str) -> IResult<&str, ClientMessage> {
let (input, _) = tag("TOPIC ")(input)?; let (input, _) = tag("TOPIC ")(input)?;
let (input, chan) = chan(input)?; let (input, chan) = chan(input)?;
@ -194,14 +224,20 @@ fn client_message_topic(input: &str) -> IResult<&str, ClientMessage> {
fn client_message_part(input: &str) -> IResult<&str, ClientMessage> { fn client_message_part(input: &str) -> IResult<&str, ClientMessage> {
let (input, _) = tag("PART ")(input)?; let (input, _) = tag("PART ")(input)?;
let (input, chan) = chan(input)?; let (input, chan) = chan(input)?;
let (input, _) = tag(" ")(input)?; let (input, t) = opt(tag(" "))(input)?;
match t {
Some(_) => (),
None => {
return Ok((input, ClientMessage::Part { chan, message: None }));
}
}
let (input, r) = opt(tag(":"))(input)?; let (input, r) = opt(tag(":"))(input)?;
let (input, message) = match r { let (input, message) = match r {
Some(_) => token(input)?, Some(_) => token(input)?,
None => receiver(input)?, None => receiver(input)?,
}; };
let message = message.into(); let message = Some(message.into());
Ok((input, ClientMessage::Part { chan, message })) Ok((input, ClientMessage::Part { chan, message }))
} }
@ -233,6 +269,22 @@ fn client_message_authenticate(input: &str) -> IResult<&str, ClientMessage> {
Ok((input, ClientMessage::Authenticate(body.into()))) Ok((input, ClientMessage::Authenticate(body.into())))
} }
fn client_message_chathistory(input: &str) -> IResult<&str, ClientMessage> {
let (input, _) = tag("CHATHISTORY LATEST ")(input)?;
let (input, chan) = chan(input)?;
let (input, _) = tag(" * ")(input)?;
let (input, limit) = limit(input)?;
Ok((input, ClientMessage::ChatHistory(ChatHistory { chan, limit })))
}
fn limit(input: &str) -> IResult<&str, u32> {
let (input, limit) = receiver(input)?;
let limit = limit.parse().unwrap();
Ok((input, limit))
}
#[derive(Clone, Debug, PartialEq, Eq)] #[derive(Clone, Debug, PartialEq, Eq)]
pub enum CapabilitySubcommand { pub enum CapabilitySubcommand {
/// CAP LS {code} /// CAP LS {code}
@ -305,6 +357,7 @@ mod test {
use nonempty::nonempty; use nonempty::nonempty;
use super::*; use super::*;
#[test] #[test]
fn test_client_message_cap_ls() { fn test_client_message_cap_ls() {
let input = "CAP LS 302"; let input = "CAP LS 302";
@ -335,6 +388,7 @@ mod test {
let result = client_message(input); let result = client_message(input);
assert_matches!(result, Ok(result) => assert_eq!(expected, result)); assert_matches!(result, Ok(result) => assert_eq!(expected, result));
} }
#[test] #[test]
fn test_client_message_pong() { fn test_client_message_pong() {
let input = "PONG 1337"; let input = "PONG 1337";
@ -343,6 +397,7 @@ mod test {
let result = client_message(input); let result = client_message(input);
assert_matches!(result, Ok(result) => assert_eq!(expected, result)); assert_matches!(result, Ok(result) => assert_eq!(expected, result));
} }
#[test] #[test]
fn test_client_message_nick() { fn test_client_message_nick() {
let input = "NICK SomeNick"; let input = "NICK SomeNick";
@ -353,6 +408,56 @@ mod test {
let result = client_message(input); let result = client_message(input);
assert_matches!(result, Ok(result) => assert_eq!(expected, result)); assert_matches!(result, Ok(result) => assert_eq!(expected, result));
} }
#[test]
fn test_client_message_whois() {
let test_user = "WHOIS val";
let test_user_user = "WHOIS val val";
let test_server_user = "WHOIS com.test.server user";
let test_user_server = "WHOIS user com.test.server";
let test_users_list = "WHOIS user_1,user_2,user_3";
let test_server_users_list = "WHOIS com.test.server user_1,user_2,user_3";
let test_more_than_two_params = "WHOIS test.server user_1,user_2,user_3 whatever spam";
let test_none_none_params = "WHOIS ";
let res_one_arg = client_message(test_user);
let res_user_user = client_message(test_user_user);
let res_server_user = client_message(test_server_user);
let res_user_server = client_message(test_user_server);
let res_users_list = client_message(test_users_list);
let res_server_users_list = client_message(test_server_users_list);
let res_more_than_two_params = client_message(test_more_than_two_params);
let res_none_none_params = client_message(test_none_none_params);
let expected_arg = ClientMessage::Whois(Whois::Nick("val".into()));
let expected_user_user = ClientMessage::Whois(Whois::TargetNick("val".into(), "val".into()));
let expected_server_user = ClientMessage::Whois(Whois::TargetNick("com.test.server".into(), "user".into()));
let expected_user_server = ClientMessage::Whois(Whois::TargetNick("user".into(), "com.test.server".into()));
let expected_user_list = ClientMessage::Whois(Whois::Nick("user_1,user_2,user_3".into()));
let expected_server_user_list = ClientMessage::Whois(Whois::TargetNick(
"com.test.server".into(),
"user_1,user_2,user_3".into(),
));
let expected_more_than_two_params =
ClientMessage::Whois(Whois::TargetNick("test.server".into(), "user_1,user_2,user_3".into()));
let expected_none_none_params = ClientMessage::Whois(Whois::EmptyArgs);
assert_matches!(res_one_arg, Ok(result) => assert_eq!(expected_arg, result));
assert_matches!(res_user_user, Ok(result) => assert_eq!(expected_user_user, result));
assert_matches!(res_server_user, Ok(result) => assert_eq!(expected_server_user, result));
assert_matches!(res_user_server, Ok(result) => assert_eq!(expected_user_server, result));
assert_matches!(res_users_list, Ok(result) => assert_eq!(expected_user_list, result));
assert_matches!(res_server_users_list, Ok(result) => assert_eq!(expected_server_user_list, result));
assert_matches!(res_more_than_two_params, Ok(result) => assert_eq!(expected_more_than_two_params, result));
assert_matches!(res_none_none_params, Ok(result) => assert_eq!(expected_none_none_params, result))
}
#[test] #[test]
fn test_client_message_user() { fn test_client_message_user() {
let input = "USER SomeNick 8 * :Real Name"; let input = "USER SomeNick 8 * :Real Name";
@ -364,17 +469,31 @@ mod test {
let result = client_message(input); let result = client_message(input);
assert_matches!(result, Ok(result) => assert_eq!(expected, result)); assert_matches!(result, Ok(result) => assert_eq!(expected, result));
} }
#[test] #[test]
fn test_client_message_part() { fn test_client_message_part() {
let input = "PART #chan :Pokasiki !!!"; let input = "PART #chan :Pokasiki !!!";
let expected = ClientMessage::Part { let expected = ClientMessage::Part {
chan: Chan::Global("chan".into()), chan: Chan::Global("chan".into()),
message: "Pokasiki !!!".into(), message: Some("Pokasiki !!!".into()),
}; };
let result = client_message(input); let result = client_message(input);
assert_matches!(result, Ok(result) => assert_eq!(expected, result)); assert_matches!(result, Ok(result) => assert_eq!(expected, result));
} }
#[test]
fn test_client_message_part_empty() {
let input = "PART #chan";
let expected = ClientMessage::Part {
chan: Chan::Global("chan".into()),
message: None,
};
let result = client_message(input);
assert_matches!(result, Ok(result) => assert_eq!(expected, result));
}
#[test] #[test]
fn test_client_cap_req() { fn test_client_cap_req() {
let input = "CAP REQ :multi-prefix -sasl"; let input = "CAP REQ :multi-prefix -sasl";
@ -394,4 +513,16 @@ mod test {
let result = client_message(input); let result = client_message(input);
assert_matches!(result, Ok(result) => assert_eq!(expected, result)); assert_matches!(result, Ok(result) => assert_eq!(expected, result));
} }
#[test]
fn test_client_chat_history_latest() {
let input = "CHATHISTORY LATEST #chan * 10";
let expected = ClientMessage::ChatHistory(ChatHistory {
chan: Chan::Global("chan".into()),
limit: 10,
});
let result = client_message(input);
assert_matches!(result, Ok(result) => assert_eq!(expected, result));
}
} }

View File

@ -0,0 +1 @@
pub mod whois;

View File

@ -0,0 +1,67 @@
use tokio::io::{AsyncWrite, AsyncWriteExt};
use crate::{prelude::Str, response::WriteResponse};
/// ErrNoSuchNick401
pub struct ErrNoSuchNick401 {
client: Str,
nick: Str,
}
impl ErrNoSuchNick401 {
pub fn new(client: Str, nick: Str) -> Self {
ErrNoSuchNick401 { client, nick }
}
}
/// ErrNoSuchServer402
struct ErrNoSuchServer402 {
client: Str,
/// target parameter in WHOIS
/// example: `/whois <target> <nick>`
server_name: Str,
}
/// ErrNoNicknameGiven431
pub struct ErrNoNicknameGiven431 {
client: Str,
}
impl ErrNoNicknameGiven431 {
pub fn new(client: Str) -> Self {
ErrNoNicknameGiven431 { client }
}
}
impl WriteResponse for ErrNoSuchNick401 {
async fn write_response(&self, writer: &mut (impl AsyncWrite + Unpin)) -> std::io::Result<()> {
writer.write_all(b"401 ").await?;
writer.write_all(self.client.as_bytes()).await?;
writer.write_all(b" ").await?;
writer.write_all(self.nick.as_bytes()).await?;
writer.write_all(b" :").await?;
writer.write_all("No such nick/channel".as_bytes()).await?;
Ok(())
}
}
impl WriteResponse for ErrNoNicknameGiven431 {
async fn write_response(&self, writer: &mut (impl AsyncWrite + Unpin)) -> std::io::Result<()> {
writer.write_all(b"431").await?;
writer.write_all(self.client.as_bytes()).await?;
writer.write_all(b" :").await?;
writer.write_all("No nickname given".as_bytes()).await?;
Ok(())
}
}
impl WriteResponse for ErrNoSuchServer402 {
async fn write_response(&self, writer: &mut (impl AsyncWrite + Unpin)) -> std::io::Result<()> {
writer.write_all(b"402 ").await?;
writer.write_all(self.client.as_bytes()).await?;
writer.write_all(b" ").await?;
writer.write_all(self.server_name.as_bytes()).await?;
writer.write_all(b" :").await?;
writer.write_all("No such server".as_bytes()).await?;
Ok(())
}
}

View File

@ -0,0 +1,2 @@
pub mod error;
pub mod response;

View File

@ -0,0 +1,24 @@
use tokio::io::{AsyncWrite, AsyncWriteExt};
use crate::{prelude::Str, response::WriteResponse};
pub struct RplEndOfWhois318 {
client: Str,
nick: Str,
}
impl RplEndOfWhois318 {
pub fn new(client: Str, nick: Str) -> Self {
RplEndOfWhois318 { client, nick }
}
}
impl WriteResponse for RplEndOfWhois318 {
async fn write_response(&self, writer: &mut (impl AsyncWrite + Unpin)) -> std::io::Result<()> {
writer.write_all(b"318 ").await?;
writer.write_all(self.client.as_bytes()).await?;
writer.write_all(b" ").await?;
writer.write_all(self.nick.as_bytes()).await?;
writer.write_all(b" :").await?;
writer.write_all("End of /WHOIS list".as_bytes()).await?;
Ok(())
}
}

View File

@ -1,6 +1,8 @@
//! Client-to-Server IRC protocol. //! Client-to-Server IRC protocol.
pub mod client; pub mod client;
pub mod commands;
mod prelude; mod prelude;
pub mod response;
pub mod server; pub mod server;
#[cfg(test)] #[cfg(test)]
mod testkit; mod testkit;
@ -18,8 +20,19 @@ use tokio::io::{AsyncWrite, AsyncWriteExt};
/// Single message tag value. /// Single message tag value.
#[derive(Clone, Debug, PartialEq, Eq)] #[derive(Clone, Debug, PartialEq, Eq)]
pub struct Tag { pub struct Tag {
key: Str, pub key: Str,
value: Option<u8>, pub value: Option<Str>,
}
impl Tag {
pub async fn write_async(&self, writer: &mut (impl AsyncWrite + Unpin)) -> std::io::Result<()> {
writer.write_all(self.key.as_bytes()).await?;
if let Some(value) = &self.value {
writer.write_all(b"=").await?;
writer.write_all(value.as_bytes()).await?;
}
Ok(())
}
} }
fn receiver(input: &str) -> IResult<&str, &str> { fn receiver(input: &str) -> IResult<&str, &str> {
@ -36,9 +49,9 @@ fn params(input: &str) -> IResult<&str, &str> {
#[derive(Clone, Debug, PartialEq, Eq)] #[derive(Clone, Debug, PartialEq, Eq)]
pub enum Chan { pub enum Chan {
/// #<name> — network-global channel, available from any server in the network. /// `#<name>` — network-global channel, available from any server in the network.
Global(Str), Global(Str),
/// &<name> — server-local channel, available only to connections to the same server. Rarely used in practice. /// `&<name>` — server-local channel, available only to connections to the same server. Rarely used in practice.
Local(Str), Local(Str),
} }
impl Chan { impl Chan {
@ -118,9 +131,7 @@ mod test {
assert_matches!(result, Ok((_, result)) => assert_eq!(expected, result)); assert_matches!(result, Ok((_, result)) => assert_eq!(expected, result));
let mut bytes = vec![]; let mut bytes = vec![];
sync_future(expected.write_async(&mut bytes)) sync_future(expected.write_async(&mut bytes)).unwrap().unwrap();
.unwrap()
.unwrap();
assert_eq!(bytes.as_slice(), input.as_bytes()); assert_eq!(bytes.as_slice(), input.as_bytes());
} }
@ -134,9 +145,7 @@ mod test {
assert_matches!(result, Ok((_, result)) => assert_eq!(expected, result)); assert_matches!(result, Ok((_, result)) => assert_eq!(expected, result));
let mut bytes = vec![]; let mut bytes = vec![];
sync_future(expected.write_async(&mut bytes)) sync_future(expected.write_async(&mut bytes)).unwrap().unwrap();
.unwrap()
.unwrap();
assert_eq!(bytes.as_slice(), input.as_bytes()); assert_eq!(bytes.as_slice(), input.as_bytes());
} }
@ -150,9 +159,7 @@ mod test {
assert_matches!(result, Ok((_, result)) => assert_eq!(expected, result)); assert_matches!(result, Ok((_, result)) => assert_eq!(expected, result));
let mut bytes = vec![]; let mut bytes = vec![];
sync_future(expected.write_async(&mut bytes)) sync_future(expected.write_async(&mut bytes)).unwrap().unwrap();
.unwrap()
.unwrap();
assert_eq!(bytes.as_slice(), input.as_bytes()); assert_eq!(bytes.as_slice(), input.as_bytes());
} }

View File

@ -0,0 +1,47 @@
use std::future::Future;
use tokio::io::{AsyncWrite, AsyncWriteExt};
use crate::prelude::Str;
use crate::Tag;
pub trait WriteResponse {
fn write_response(&self, writer: &mut (impl AsyncWrite + Unpin)) -> impl Future<Output = std::io::Result<()>>;
}
/// Server-to-client enum agnostic message
#[derive(Clone, Debug, PartialEq, Eq)]
pub struct IrcResponseMessage<T> {
/// Optional tags section, prefixed with `@`
pub tags: Vec<Tag>,
/// Optional server name, prefixed with `:`.
pub sender: Option<Str>,
pub body: T,
}
impl<T> IrcResponseMessage<T> {
pub fn empty_tags(sender: Option<Str>, body: T) -> Self {
IrcResponseMessage {
tags: vec![],
sender,
body,
}
}
pub fn new(tags: Vec<Tag>, sender: Option<Str>, body: T) -> Self {
IrcResponseMessage { tags, sender, body }
}
}
impl<T: WriteResponse> WriteResponse for IrcResponseMessage<T> {
async fn write_response(&self, writer: &mut (impl AsyncWrite + Unpin)) -> std::io::Result<()> {
if let Some(sender) = &self.sender {
writer.write_all(b":").await?;
writer.write_all(sender.as_bytes()).await?;
writer.write_all(b" ").await?;
}
self.body.write_response(writer).await?;
writer.write_all(b"\r\n").await?;
Ok(())
}
}

View File

@ -1,12 +1,11 @@
use std::sync::Arc;
use nonempty::NonEmpty; use nonempty::NonEmpty;
use tokio::io::AsyncWrite; use tokio::io::AsyncWrite;
use tokio::io::AsyncWriteExt; use tokio::io::AsyncWriteExt;
use super::*;
use crate::user::PrefixedNick; use crate::user::PrefixedNick;
use super::*;
/// Server-to-client message. /// Server-to-client message.
#[derive(Clone, Debug, PartialEq, Eq)] #[derive(Clone, Debug, PartialEq, Eq)]
pub struct ServerMessage { pub struct ServerMessage {
@ -19,6 +18,13 @@ pub struct ServerMessage {
impl ServerMessage { impl ServerMessage {
pub async fn write_async(&self, writer: &mut (impl AsyncWrite + Unpin)) -> std::io::Result<()> { pub async fn write_async(&self, writer: &mut (impl AsyncWrite + Unpin)) -> std::io::Result<()> {
if !self.tags.is_empty() {
for tag in &self.tags {
writer.write_all(b"@").await?;
tag.write_async(writer).await?;
writer.write_all(b" ").await?;
}
}
match &self.sender { match &self.sender {
Some(ref sender) => { Some(ref sender) => {
writer.write_all(b":").await?; writer.write_all(b":").await?;
@ -107,6 +113,12 @@ pub enum ServerMessageBody {
/// Usually `b"End of WHO list"` /// Usually `b"End of WHO list"`
msg: Str, msg: Str,
}, },
N318EndOfWhois {
client: Str,
nick: Str,
/// Usually `b"End of /WHOIS list"`
msg: Str,
},
N332Topic { N332Topic {
client: Str, client: Str,
chat: Chan, chat: Chan,
@ -136,6 +148,10 @@ pub enum ServerMessageBody {
client: Str, client: Str,
chan: Chan, chan: Chan,
}, },
N431ErrNoNicknameGiven {
client: Str,
message: Str,
},
N474BannedFromChan { N474BannedFromChan {
client: Str, client: Str,
chan: Chan, chan: Chan,
@ -158,7 +174,7 @@ pub enum ServerMessageBody {
N904SaslFail { N904SaslFail {
nick: Str, nick: Str,
text: Str, text: Str,
} },
} }
impl ServerMessageBody { impl ServerMessageBody {
@ -273,11 +289,15 @@ impl ServerMessageBody {
writer.write_all(b" :").await?; writer.write_all(b" :").await?;
writer.write_all(msg.as_bytes()).await?; writer.write_all(msg.as_bytes()).await?;
} }
ServerMessageBody::N332Topic { ServerMessageBody::N318EndOfWhois { client, nick, msg } => {
client, writer.write_all(b"318 ").await?;
chat, writer.write_all(client.as_bytes()).await?;
topic, writer.write_all(b" ").await?;
} => { writer.write_all(nick.as_bytes()).await?;
writer.write_all(b" :").await?;
writer.write_all(msg.as_bytes()).await?;
}
ServerMessageBody::N332Topic { client, chat, topic } => {
writer.write_all(b"332 ").await?; writer.write_all(b"332 ").await?;
writer.write_all(client.as_bytes()).await?; writer.write_all(client.as_bytes()).await?;
writer.write_all(b" ").await?; writer.write_all(b" ").await?;
@ -315,22 +335,21 @@ impl ServerMessageBody {
writer.write_all(b" ").await?; writer.write_all(b" ").await?;
writer.write_all(realname.as_bytes()).await?; writer.write_all(realname.as_bytes()).await?;
} }
ServerMessageBody::N353NamesReply { ServerMessageBody::N353NamesReply { client, chan, members } => {
client,
chan,
members,
} => {
writer.write_all(b"353 ").await?; writer.write_all(b"353 ").await?;
writer.write_all(client.as_bytes()).await?; writer.write_all(client.as_bytes()).await?;
writer.write_all(b" = ").await?; writer.write_all(b" = ").await?;
chan.write_async(writer).await?; chan.write_async(writer).await?;
writer.write_all(b" :").await?; writer.write_all(b" :").await?;
for member in members { {
writer let member = &members.head;
.write_all(member.prefix.to_string().as_bytes()) writer.write_all(member.prefix.to_string().as_bytes()).await?;
.await?;
writer.write_all(member.nick.as_bytes()).await?; writer.write_all(member.nick.as_bytes()).await?;
}
for member in &members.tail {
writer.write_all(b" ").await?; writer.write_all(b" ").await?;
writer.write_all(member.prefix.to_string().as_bytes()).await?;
writer.write_all(member.nick.as_bytes()).await?;
} }
} }
ServerMessageBody::N366NamesReplyEnd { client, chan } => { ServerMessageBody::N366NamesReplyEnd { client, chan } => {
@ -340,11 +359,13 @@ impl ServerMessageBody {
chan.write_async(writer).await?; chan.write_async(writer).await?;
writer.write_all(b" :End of /NAMES list").await?; writer.write_all(b" :End of /NAMES list").await?;
} }
ServerMessageBody::N474BannedFromChan { ServerMessageBody::N431ErrNoNicknameGiven { client, message } => {
client, writer.write_all(b"431").await?;
chan, writer.write_all(client.as_bytes()).await?;
message, writer.write_all(b" :").await?;
} => { writer.write_all(message.as_bytes()).await?;
}
ServerMessageBody::N474BannedFromChan { client, chan, message } => {
writer.write_all(b"474 ").await?; writer.write_all(b"474 ").await?;
writer.write_all(client.as_bytes()).await?; writer.write_all(client.as_bytes()).await?;
writer.write_all(b" ").await?; writer.write_all(b" ").await?;
@ -359,7 +380,12 @@ impl ServerMessageBody {
writer.write_all(b" :").await?; writer.write_all(b" :").await?;
writer.write_all(message.as_bytes()).await?; writer.write_all(message.as_bytes()).await?;
} }
ServerMessageBody::N900LoggedIn { nick, address, account, message } => { ServerMessageBody::N900LoggedIn {
nick,
address,
account,
message,
} => {
writer.write_all(b"900 ").await?; writer.write_all(b"900 ").await?;
writer.write_all(nick.as_bytes()).await?; writer.write_all(nick.as_bytes()).await?;
writer.write_all(b" ").await?; writer.write_all(b" ").await?;
@ -404,7 +430,7 @@ fn server_message_body(input: &str) -> IResult<&str, ServerMessageBody> {
server_message_body_notice, server_message_body_notice,
server_message_body_ping, server_message_body_ping,
server_message_body_pong, server_message_body_pong,
server_message_body_cap server_message_body_cap,
))(input) ))(input)
} }
@ -467,9 +493,10 @@ fn server_message_body_cap(input: &str) -> IResult<&str, ServerMessageBody> {
mod test { mod test {
use assert_matches::*; use assert_matches::*;
use super::*;
use crate::testkit::*; use crate::testkit::*;
use super::*;
#[test] #[test]
fn test_server_message_notice() { fn test_server_message_notice() {
let input = "NOTICE * :*** Looking up your hostname...\r\n"; let input = "NOTICE * :*** Looking up your hostname...\r\n";

View File

@ -11,12 +11,15 @@ pub const XMLNS: &'static str = "urn:ietf:params:xml:ns:xmpp-bind";
// TODO remove `pub` in newtypes, introduce validation // TODO remove `pub` in newtypes, introduce validation
/// Name (node identifier) of an XMPP entity. Placed before the `@` in a JID.
#[derive(PartialEq, Eq, Debug, Clone)] #[derive(PartialEq, Eq, Debug, Clone)]
pub struct Name(pub Str); pub struct Name(pub Str);
/// Server name of an XMPP entity. Placed after the `@` and before the `/` in a JID.
#[derive(PartialEq, Eq, Debug, Clone)] #[derive(PartialEq, Eq, Debug, Clone)]
pub struct Server(pub Str); pub struct Server(pub Str);
/// Resource of an XMPP entity. Placed after the `/` in a JID.
#[derive(PartialEq, Eq, Debug, Clone)] #[derive(PartialEq, Eq, Debug, Clone)]
pub struct Resource(pub Str); pub struct Resource(pub Str);
@ -42,27 +45,20 @@ impl Display for Jid {
impl Jid { impl Jid {
pub fn from_string(i: &str) -> Result<Jid> { pub fn from_string(i: &str) -> Result<Jid> {
use regex::Regex;
use lazy_static::lazy_static; use lazy_static::lazy_static;
use regex::Regex;
lazy_static! { lazy_static! {
static ref RE: Regex = Regex::new(r"^(([a-zA-Z]+)@)?([a-zA-Z.]+)(/([a-zA-Z\-]+))?$").unwrap(); static ref RE: Regex =
Regex::new(r"^(([a-zA-Z0-9]+)@)?([^@/]+)(/([a-zA-Z0-9\-]+))?$").expect("this is a correct regex");
} }
let m = RE let m = RE.captures(i).ok_or(anyhow!("Incorrectly format jid: {i}"))?;
.captures(i)
.ok_or(anyhow!("Incorrectly format jid: {i}"))?;
let name = m.get(2).map(|name| Name(name.as_str().into())); let name = m.get(2).map(|name| Name(name.as_str().into()));
let server = m.get(3).unwrap(); let server = m.get(3).unwrap();
let server = Server(server.as_str().into()); let server = Server(server.as_str().into());
let resource = m let resource = m.get(5).map(|resource| Resource(resource.as_str().into()));
.get(5)
.map(|resource| Resource(resource.as_str().into()));
Ok(Jid { Ok(Jid { name, server, resource })
name,
server,
resource,
})
} }
} }
@ -79,15 +75,16 @@ impl Jid {
pub struct BindRequest(pub Resource); pub struct BindRequest(pub Resource);
impl FromXmlTag for BindRequest { impl FromXmlTag for BindRequest {
const NS: &'static str = XMLNS;
const NAME: &'static str = "bind"; const NAME: &'static str = "bind";
const NS: &'static str = XMLNS;
} }
impl FromXml for BindRequest { impl FromXml for BindRequest {
type P = impl Parser<Output = Result<Self>>; type P = impl Parser<Output = Result<Self>>;
fn parse() -> Self::P { fn parse() -> Self::P {
|(namespace, event): (ResolveResult<'static>, &'static Event<'static>)| -> Result<Self> { #[coroutine]
|(mut namespace, mut event): (ResolveResult<'static>, &'static Event<'static>)| -> Result<Self> {
let mut resource: Option<Str> = None; let mut resource: Option<Str> = None;
let Event::Start(bytes) = event else { let Event::Start(bytes) = event else {
return Err(anyhow!("Unexpected XML event: {event:?}")); return Err(anyhow!("Unexpected XML event: {event:?}"));
@ -102,15 +99,15 @@ impl FromXml for BindRequest {
return Err(anyhow!("Incorrect namespace")); return Err(anyhow!("Incorrect namespace"));
} }
loop { loop {
let (namespace, event) = yield; (namespace, event) = yield;
match event { match event {
Event::Start(bytes) if bytes.name().0 == b"resource" => { Event::Start(bytes) if bytes.name().0 == b"resource" => {
let (namespace, event) = yield; (namespace, event) = yield;
let Event::Text(text) = event else { let Event::Text(text) = event else {
return Err(anyhow!("Unexpected XML event: {event:?}")); return Err(anyhow!("Unexpected XML event: {event:?}"));
}; };
resource = Some(std::str::from_utf8(&*text)?.into()); resource = Some(std::str::from_utf8(&*text)?.into());
let (namespace, event) = yield; (namespace, event) = yield;
let Event::End(bytes) = event else { let Event::End(bytes) = event else {
return Err(anyhow!("Unexpected XML event: {event:?}")); return Err(anyhow!("Unexpected XML event: {event:?}"));
}; };
@ -132,13 +129,15 @@ impl FromXml for BindRequest {
} }
} }
#[derive(PartialEq, Eq, Debug)]
pub struct BindResponse(pub Jid); pub struct BindResponse(pub Jid);
impl ToXml for BindResponse { impl ToXml for BindResponse {
fn serialize(&self, events: &mut Vec<Event<'static>>) { fn serialize(&self, events: &mut Vec<Event<'static>>) {
events.extend_from_slice(&[ events.extend_from_slice(&[
Event::Start(BytesStart::new( Event::Start(BytesStart::from_content(
r#"bind xmlns="urn:ietf:params:xml:ns:xmpp-bind""#, r#"bind xmlns="urn:ietf:params:xml:ns:xmpp-bind""#,
4,
)), )),
Event::Start(BytesStart::new(r#"jid"#)), Event::Start(BytesStart::new(r#"jid"#)),
Event::Text(BytesText::new(self.0.to_string().as_str()).into_owned()), Event::Text(BytesText::new(self.0.to_string().as_str()).into_owned()),
@ -155,77 +154,74 @@ mod tests {
use super::*; use super::*;
#[tokio::test] #[tokio::test]
async fn parse_message() { async fn parse_message() -> Result<()> {
let input = let input = r#"<bind xmlns="urn:ietf:params:xml:ns:xmpp-bind"><resource>mobile</resource></bind>"#;
r#"<bind xmlns="urn:ietf:params:xml:ns:xmpp-bind"><resource>mobile</resource></bind>"#;
let mut reader = NsReader::from_reader(input.as_bytes()); let mut reader = NsReader::from_reader(input.as_bytes());
let mut buf = vec![]; let mut buf = vec![];
let (ns, event) = reader let (ns, event) = reader.read_resolved_event_into_async(&mut buf).await?;
.read_resolved_event_into_async(&mut buf)
.await
.unwrap();
let mut parser = BindRequest::parse().consume(ns, &event); let mut parser = BindRequest::parse().consume(ns, &event);
let result = loop { let result = loop {
match parser { match parser {
Continuation::Final(res) => break res, Continuation::Final(res) => break res,
Continuation::Continue(next) => { Continuation::Continue(next) => {
let (ns, event) = reader let (ns, event) = reader.read_resolved_event_into_async(&mut buf).await?;
.read_resolved_event_into_async(&mut buf)
.await
.unwrap();
parser = next.consume(ns, &event); parser = next.consume(ns, &event);
} }
} }
} }?;
.unwrap(); assert_eq!(result, BindRequest(Resource("mobile".into())));
assert_eq!(result, BindRequest(Resource("mobile".into())),) Ok(())
} }
#[test] #[test]
fn jid_parse_full() { fn jid_parse_full() -> Result<()> {
let input = "chelik@server.example/kek"; let input = "chelik@server.example/kek";
let expected = Jid { let expected = Jid {
name: Some(Name("chelik".into())), name: Some(Name("chelik".into())),
server: Server("server.example".into()), server: Server("server.example".into()),
resource: Some(Resource("kek".into())), resource: Some(Resource("kek".into())),
}; };
let res = Jid::from_string(input).unwrap(); let res = Jid::from_string(input)?;
assert_eq!(res, expected); assert_eq!(res, expected);
Ok(())
} }
#[test] #[test]
fn jid_parse_user() { fn jid_parse_user() -> Result<()> {
let input = "chelik@server.example"; let input = "chelik@server.example";
let expected = Jid { let expected = Jid {
name: Some(Name("chelik".into())), name: Some(Name("chelik".into())),
server: Server("server.example".into()), server: Server("server.example".into()),
resource: None, resource: None,
}; };
let res = Jid::from_string(input).unwrap(); let res = Jid::from_string(input)?;
assert_eq!(res, expected); assert_eq!(res, expected);
Ok(())
} }
#[test] #[test]
fn jid_parse_server() { fn jid_parse_server() -> Result<()> {
let input = "server.example"; let input = "server.example";
let expected = Jid { let expected = Jid {
name: None, name: None,
server: Server("server.example".into()), server: Server("server.example".into()),
resource: None, resource: None,
}; };
let res = Jid::from_string(input).unwrap(); let res = Jid::from_string(input)?;
assert_eq!(res, expected); assert_eq!(res, expected);
Ok(())
} }
#[test] #[test]
fn jid_parse_server_resource() { fn jid_parse_server_resource() -> Result<()> {
let input = "server.example/kek"; let input = "server.example/kek";
let expected = Jid { let expected = Jid {
name: None, name: None,
server: Server("server.example".into()), server: Server("server.example".into()),
resource: Some(Resource("kek".into())), resource: Some(Resource("kek".into())),
}; };
let res = Jid::from_string(input).unwrap(); let res = Jid::from_string(input)?;
assert_eq!(res, expected); assert_eq!(res, expected);
Ok(())
} }
} }

View File

@ -20,8 +20,8 @@ pub struct Message<T> {
// default is Normal // default is Normal
pub r#type: MessageType, pub r#type: MessageType,
pub lang: Option<Str>, pub lang: Option<Str>,
pub subject: Option<Str>, pub subject: Option<Subject>,
pub body: Str, pub body: Option<Str>,
pub custom: Vec<T>, pub custom: Vec<T>,
} }
@ -38,6 +38,20 @@ impl<T: FromXml> FromXml for Message<T> {
} }
} }
#[derive(PartialEq, Eq, Debug)]
pub struct Subject(pub Option<Str>);
impl ToXml for Subject {
fn serialize(&self, events: &mut Vec<Event<'static>>) {
if let Some(ref s) = self.0 {
events.push(Event::Start(BytesStart::new("subject")));
events.push(Event::Text(BytesText::new(s).into_owned()));
events.push(Event::End(BytesEnd::new("subject")));
} else {
events.push(Event::Empty(BytesStart::new("subject")));
}
}
}
#[derive(From)] #[derive(From)]
struct MessageParser<T: FromXml>(MessageParserInner<T>); struct MessageParser<T: FromXml>(MessageParserInner<T>);
@ -57,7 +71,7 @@ struct MessageParserState<T> {
to: Option<Jid>, to: Option<Jid>,
r#type: MessageType, r#type: MessageType,
lang: Option<Str>, lang: Option<Str>,
subject: Option<Str>, subject: Option<Subject>,
body: Option<Str>, body: Option<Str>,
custom: Vec<T>, custom: Vec<T>,
} }
@ -121,22 +135,16 @@ impl<T: FromXml> Parser for MessageParser<T> {
} }
} }
} }
Event::End(_) => { Event::End(_) => Continuation::Final(Ok(Message {
if let Some(body) = state.body { from: state.from,
Continuation::Final(Ok(Message { id: state.id,
from: state.from, to: state.to,
id: state.id, r#type: state.r#type,
to: state.to, lang: state.lang,
r#type: state.r#type, subject: state.subject,
lang: state.lang, body: state.body,
subject: state.subject, custom: state.custom,
body, })),
custom: state.custom,
}))
} else {
Continuation::Final(Err(ffail!("Body not found")))
}
}
Event::Empty(_) => { Event::Empty(_) => {
let parser = T::parse(); let parser = T::parse();
match parser.consume(namespace, event) { match parser.consume(namespace, event) {
@ -153,7 +161,7 @@ impl<T: FromXml> Parser for MessageParser<T> {
InSubject(mut state) => match event { InSubject(mut state) => match event {
Event::Text(ref bytes) => { Event::Text(ref bytes) => {
let subject = fail_fast!(std::str::from_utf8(&*bytes)); let subject = fail_fast!(std::str::from_utf8(&*bytes));
state.subject = Some(subject.into()); state.subject = Some(Subject(Some(subject.into())));
Continuation::Continue(InSubject(state).into()) Continuation::Continue(InSubject(state).into())
} }
Event::End(_) => Continuation::Continue(Outer(state).into()), Event::End(_) => Continuation::Continue(Outer(state).into()),
@ -208,9 +216,14 @@ impl<T: ToXml> ToXml for Message<T> {
value: self.r#type.as_str().as_bytes().into(), value: self.r#type.as_str().as_bytes().into(),
}); });
events.push(Event::Start(bytes)); events.push(Event::Start(bytes));
events.push(Event::Start(BytesStart::new("body"))); if let Some(subject) = &self.subject {
events.push(Event::Text(BytesText::new(&self.body).into_owned())); subject.serialize(events);
events.push(Event::End(BytesEnd::new("body"))); }
if let Some(body) = &self.body {
events.push(Event::Start(BytesStart::new("body")));
events.push(Event::Text(BytesText::new(body).into_owned()));
events.push(Event::End(BytesEnd::new("body")));
}
events.push(Event::End(BytesEnd::new("message"))); events.push(Event::End(BytesEnd::new("message")));
} }
} }
@ -255,11 +268,72 @@ impl MessageType {
} }
} }
/// Error response to an IQ request.
///
/// https://xmpp.org/rfcs/rfc6120.html#stanzas-error
pub struct IqError {
pub r#type: IqErrorType,
pub condition: Option<IqErrorCondition>,
}
pub enum IqErrorCondition {
ItemNotFound,
}
impl IqErrorCondition {
pub fn as_str(&self) -> &'static str {
match self {
IqErrorCondition::ItemNotFound => "item-not-found",
}
}
}
pub enum IqErrorType {
/// Retry after providing credentials
Auth,
/// Do not retry (the error cannot be remedied)
Cancel,
/// Proceed (the condition was only a warning)
Continue,
/// Retry after changing the data sent
Modify,
/// Retry after waiting (the error is temporary)
Wait,
}
impl IqErrorType {
pub fn as_str(&self) -> &'static str {
match self {
IqErrorType::Auth => "auth",
IqErrorType::Cancel => "cancel",
IqErrorType::Continue => "continue",
IqErrorType::Modify => "modify",
IqErrorType::Wait => "wait",
}
}
}
impl ToXml for IqError {
fn serialize(&self, events: &mut Vec<Event<'static>>) {
let bytes = BytesStart::new(format!(r#"error xmlns="{}" type="{}""#, XMLNS, self.r#type.as_str()));
match self.condition {
None => {
events.push(Event::Empty(bytes));
}
Some(IqErrorCondition::ItemNotFound) => {
events.push(Event::Start(bytes));
let bytes2 = BytesStart::new(r#"item-not-found xmlns="urn:ietf:params:xml:ns:xmpp-stanzas""#);
events.push(Event::Empty(bytes2));
let bytes = BytesEnd::new("error");
events.push(Event::End(bytes));
}
}
}
}
#[derive(PartialEq, Eq, Debug)] #[derive(PartialEq, Eq, Debug)]
pub struct Iq<T> { pub struct Iq<T> {
pub from: Option<String>, pub from: Option<Jid>,
pub id: String, pub id: String,
pub to: Option<String>, pub to: Option<Jid>,
pub r#type: IqType, pub r#type: IqType,
pub body: T, pub body: T,
} }
@ -285,9 +359,9 @@ enum IqParserInner<T: FromXml> {
Final(IqParserState<T>), Final(IqParserState<T>),
} }
struct IqParserState<T> { struct IqParserState<T> {
pub from: Option<String>, pub from: Option<Jid>,
pub id: Option<String>, pub id: Option<String>,
pub to: Option<String>, pub to: Option<Jid>,
pub r#type: Option<IqType>, pub r#type: Option<IqType>,
pub body: Option<T>, pub body: Option<T>,
} }
@ -310,13 +384,15 @@ impl<T: FromXml> Parser for IqParser<T> {
let attr = fail_fast!(attr); let attr = fail_fast!(attr);
if attr.key.0 == b"from" { if attr.key.0 == b"from" {
let value = fail_fast!(std::str::from_utf8(&*attr.value)); let value = fail_fast!(std::str::from_utf8(&*attr.value));
state.from = Some(value.to_string()) let value = fail_fast!(Jid::from_string(value));
state.from = Some(value)
} else if attr.key.0 == b"id" { } else if attr.key.0 == b"id" {
let value = fail_fast!(std::str::from_utf8(&*attr.value)); let value = fail_fast!(std::str::from_utf8(&*attr.value));
state.id = Some(value.to_string()) state.id = Some(value.to_string())
} else if attr.key.0 == b"to" { } else if attr.key.0 == b"to" {
let value = fail_fast!(std::str::from_utf8(&*attr.value)); let value = fail_fast!(std::str::from_utf8(&*attr.value));
state.to = Some(value.to_string()) let value = fail_fast!(Jid::from_string(value));
state.to = Some(value)
} else if attr.key.0 == b"type" { } else if attr.key.0 == b"type" {
let value = fail_fast!(IqType::from_str(&*attr.value)); let value = fail_fast!(IqType::from_str(&*attr.value));
state.r#type = Some(value); state.r#type = Some(value);
@ -338,7 +414,7 @@ impl<T: FromXml> Parser for IqParser<T> {
} }
}, },
IqParserInner::Final(state) => { IqParserInner::Final(state) => {
if let Event::End(ref bytes) = event { if let Event::End(_) = event {
let id = fail_fast!(state.id.ok_or_else(|| ffail!("No id provided"))); let id = fail_fast!(state.id.ok_or_else(|| ffail!("No id provided")));
let r#type = fail_fast!(state.r#type.ok_or_else(|| ffail!("No type provided"))); let r#type = fail_fast!(state.r#type.ok_or_else(|| ffail!("No type provided")));
let body = fail_fast!(state.body.ok_or_else(|| ffail!("No body provided"))); let body = fail_fast!(state.body.ok_or_else(|| ffail!("No body provided")));
@ -393,15 +469,17 @@ impl<T: ToXml> ToXml for Iq<T> {
let mut start = BytesStart::new(start); let mut start = BytesStart::new(start);
let mut attrs = vec![]; let mut attrs = vec![];
if let Some(ref from) = self.from { if let Some(ref from) = self.from {
let value = from.to_string().into_bytes();
attrs.push(Attribute { attrs.push(Attribute {
key: QName(b"from"), key: QName(b"from"),
value: from.as_bytes().into(), value: value.into(),
}); });
}; };
if let Some(ref to) = self.to { if let Some(ref to) = self.to {
let value = to.to_string().into_bytes();
attrs.push(Attribute { attrs.push(Attribute {
key: QName(b"to"), key: QName(b"to"),
value: to.as_bytes().into(), value: value.into(),
}); });
} }
attrs.push(Attribute { attrs.push(Attribute {
@ -422,6 +500,7 @@ impl<T: ToXml> ToXml for Iq<T> {
#[derive(PartialEq, Eq, Debug)] #[derive(PartialEq, Eq, Debug)]
pub struct Presence<T> { pub struct Presence<T> {
pub id: Option<String>,
pub to: Option<Jid>, pub to: Option<Jid>,
pub from: Option<Jid>, pub from: Option<Jid>,
pub priority: Option<PresencePriority>, pub priority: Option<PresencePriority>,
@ -434,6 +513,7 @@ pub struct Presence<T> {
impl<T> Default for Presence<T> { impl<T> Default for Presence<T> {
fn default() -> Self { fn default() -> Self {
Self { Self {
id: Default::default(),
to: Default::default(), to: Default::default(),
from: Default::default(), from: Default::default(),
priority: Default::default(), priority: Default::default(),
@ -486,7 +566,8 @@ impl<T: FromXml> FromXml for Presence<T> {
type P = impl Parser<Output = Result<Presence<T>>>; type P = impl Parser<Output = Result<Presence<T>>>;
fn parse() -> Self::P { fn parse() -> Self::P {
|(namespace, event): (ResolveResult<'static>, &'static Event<'static>)| -> Result<Self> { #[coroutine]
|(mut namespace, mut event): (ResolveResult<'static>, &'static Event<'static>)| -> Result<Self> {
let (bytes, end) = match event { let (bytes, end) = match event {
Event::Start(bytes) => (bytes, false), Event::Start(bytes) => (bytes, false),
Event::Empty(bytes) => (bytes, true), Event::Empty(bytes) => (bytes, true),
@ -508,6 +589,10 @@ impl<T: FromXml> FromXml for Presence<T> {
let s = std::str::from_utf8(&attr.value)?; let s = std::str::from_utf8(&attr.value)?;
p.r#type = Some(s.into()); p.r#type = Some(s.into());
} }
b"id" => {
let s = std::str::from_utf8(&attr.value)?;
p.id = Option::from(s.to_string());
}
_ => {} _ => {}
} }
} }
@ -515,37 +600,37 @@ impl<T: FromXml> FromXml for Presence<T> {
return Ok(p); return Ok(p);
} }
loop { loop {
let (namespace, event) = yield; (namespace, event) = yield;
match event { match event {
Event::Start(bytes) => match bytes.name().0 { Event::Start(bytes) => match bytes.name().0 {
b"show" => { b"show" => {
let (_, event) = yield; (namespace, event) = yield;
let Event::Text(bytes) = event else { let Event::Text(bytes) = event else {
return Err(ffail!("Unexpected XML event: {event:?}")); return Err(ffail!("Unexpected XML event: {event:?}"));
}; };
let i = PresenceShow::from_str(bytes)?; let i = PresenceShow::from_str(bytes)?;
p.show = Some(i); p.show = Some(i);
let (_, event) = yield; (namespace, event) = yield;
let Event::End(_) = event else { let Event::End(_) = event else {
return Err(ffail!("Unexpected XML event: {event:?}")); return Err(ffail!("Unexpected XML event: {event:?}"));
}; };
} }
b"status" => { b"status" => {
let (_, event) = yield; (namespace, event) = yield;
let Event::Text(bytes) = event else { let Event::Text(bytes) = event else {
return Err(ffail!("Unexpected XML event: {event:?}")); return Err(ffail!("Unexpected XML event: {event:?}"));
}; };
let s = std::str::from_utf8(bytes)?; let s = std::str::from_utf8(bytes)?;
p.status.push(s.to_string()); p.status.push(s.to_string());
let (_, event) = yield; (namespace, event) = yield;
let Event::End(_) = event else { let Event::End(_) = event else {
return Err(ffail!("Unexpected XML event: {event:?}")); return Err(ffail!("Unexpected XML event: {event:?}"));
}; };
} }
b"priority" => { b"priority" => {
let (_, event) = yield; (namespace, event) = yield;
let Event::Text(bytes) = event else { let Event::Text(bytes) = event else {
return Err(ffail!("Unexpected XML event: {event:?}")); return Err(ffail!("Unexpected XML event: {event:?}"));
}; };
@ -553,7 +638,7 @@ impl<T: FromXml> FromXml for Presence<T> {
let i = s.parse()?; let i = s.parse()?;
p.priority = Some(PresencePriority(i)); p.priority = Some(PresencePriority(i));
let (_, event) = yield; (namespace, event) = yield;
let Event::End(_) = event else { let Event::End(_) = event else {
return Err(ffail!("Unexpected XML event: {event:?}")); return Err(ffail!("Unexpected XML event: {event:?}"));
}; };
@ -595,6 +680,12 @@ impl<T: ToXml> ToXml for Presence<T> {
value: from.to_string().as_bytes().into(), value: from.to_string().as_bytes().into(),
}]); }]);
} }
if let Some(ref id) = self.id {
start.extend_attributes([Attribute {
key: QName(b"id"),
value: id.to_string().as_bytes().into(),
}]);
}
events.push(Event::Start(start)); events.push(Event::Start(start));
if let Some(ref priority) = self.priority { if let Some(ref priority) = self.priority {
let s = priority.0.to_string(); let s = priority.0.to_string();
@ -604,6 +695,9 @@ impl<T: ToXml> ToXml for Presence<T> {
Event::End(BytesEnd::new("priority")), Event::End(BytesEnd::new("priority")),
]); ]);
} }
for c in &self.custom {
c.serialize(events);
}
events.push(Event::End(BytesEnd::new("presence"))); events.push(Event::End(BytesEnd::new("presence")));
} }
} }
@ -616,7 +710,7 @@ mod tests {
#[tokio::test] #[tokio::test]
async fn parse_message() { async fn parse_message() {
let input = r#"<message id="aacea" type="chat" to="nikita@vlnv.dev"><subject>daa</subject><body>bbb</body><unknown-stuff></unknown-stuff></message>"#; let input = r#"<message id="aacea" type="chat" to="chelik@example.com"><subject>daa</subject><body>bbb</body><unknown-stuff></unknown-stuff></message>"#;
let result: Message<Ignore> = crate::xml::parse(input).unwrap(); let result: Message<Ignore> = crate::xml::parse(input).unwrap();
assert_eq!( assert_eq!(
result, result,
@ -624,14 +718,14 @@ mod tests {
from: None, from: None,
id: Some("aacea".to_string()), id: Some("aacea".to_string()),
to: Some(Jid { to: Some(Jid {
name: Some(Name("nikita".into())), name: Some(Name("chelik".into())),
server: Server("vlnv.dev".into()), server: Server("example.com".into()),
resource: None resource: None
}), }),
r#type: MessageType::Chat, r#type: MessageType::Chat,
lang: None, lang: None,
subject: Some("daa".into()), subject: Some(Subject(Some("daa".into()))),
body: "bbb".into(), body: Some("bbb".into()),
custom: vec![Ignore], custom: vec![Ignore],
} }
) )
@ -639,7 +733,7 @@ mod tests {
#[tokio::test] #[tokio::test]
async fn parse_message_empty_custom() { async fn parse_message_empty_custom() {
let input = r#"<message id="aacea" type="chat" to="nikita@vlnv.dev"><subject>daa</subject><body>bbb</body><unknown-stuff/></message>"#; let input = r#"<message id="aacea" type="chat" to="chelik@example.com"><subject>daa</subject><body>bbb</body><unknown-stuff/></message>"#;
let result: Message<Ignore> = crate::xml::parse(input).unwrap(); let result: Message<Ignore> = crate::xml::parse(input).unwrap();
assert_eq!( assert_eq!(
result, result,
@ -647,14 +741,14 @@ mod tests {
from: None, from: None,
id: Some("aacea".to_string()), id: Some("aacea".to_string()),
to: Some(Jid { to: Some(Jid {
name: Some(Name("nikita".into())), name: Some(Name("chelik".into())),
server: Server("vlnv.dev".into()), server: Server("example.com".into()),
resource: None resource: None
}), }),
r#type: MessageType::Chat, r#type: MessageType::Chat,
lang: None, lang: None,
subject: Some("daa".into()), subject: Some(Subject(Some("daa".into()))),
body: "bbb".into(), body: Some("bbb".into()),
custom: vec![Ignore], custom: vec![Ignore],
} }
) )

View File

@ -2,8 +2,8 @@ use quick_xml::events::attributes::Attribute;
use quick_xml::events::{BytesEnd, BytesStart, Event}; use quick_xml::events::{BytesEnd, BytesStart, Event};
use quick_xml::name::{QName, ResolveResult}; use quick_xml::name::{QName, ResolveResult};
use anyhow::{Result, anyhow as ffail};
use crate::xml::*; use crate::xml::*;
use anyhow::{anyhow as ffail, Result};
use super::bind::Jid; use super::bind::Jid;
@ -21,7 +21,8 @@ impl FromXml for InfoQuery {
type P = impl Parser<Output = Result<Self>>; type P = impl Parser<Output = Result<Self>>;
fn parse() -> Self::P { fn parse() -> Self::P {
|(namespace, event): (ResolveResult<'static>, &'static Event<'static>)| -> Result<Self> { #[coroutine]
|(mut namespace, mut event): (ResolveResult<'static>, &'static Event<'static>)| -> Result<Self> {
let mut node = None; let mut node = None;
let mut identity = vec![]; let mut identity = vec![];
let mut feature = vec![]; let mut feature = vec![];
@ -48,7 +49,7 @@ impl FromXml for InfoQuery {
}); });
} }
loop { loop {
let (namespace, event) = yield; (namespace, event) = yield;
let bytes = match event { let bytes = match event {
Event::Start(bytes) => bytes, Event::Start(bytes) => bytes,
Event::Empty(bytes) => bytes, Event::Empty(bytes) => bytes,
@ -141,7 +142,8 @@ impl FromXml for Identity {
type P = impl Parser<Output = Result<Self>>; type P = impl Parser<Output = Result<Self>>;
fn parse() -> Self::P { fn parse() -> Self::P {
|(namespace, event): (ResolveResult<'static>, &'static Event<'static>)| -> Result<Self> { #[coroutine]
|(mut namespace, mut event): (ResolveResult<'static>, &'static Event<'static>)| -> Result<Self> {
let mut category = None; let mut category = None;
let mut name = None; let mut name = None;
let mut r#type = None; let mut r#type = None;
@ -174,17 +176,13 @@ impl FromXml for Identity {
let Some(r#type) = r#type else { let Some(r#type) = r#type else {
return Err(ffail!("No type provided")); return Err(ffail!("No type provided"));
}; };
let item = Identity { let item = Identity { category, name, r#type };
category,
name,
r#type,
};
if end { if end {
return Ok(item); return Ok(item);
} }
let (namespace, event) = yield; (namespace, event) = yield;
let Event::End(bytes) = event else { let Event::End(_) = event else {
return Err(ffail!("Unexpected XML event: {event:?}")); return Err(ffail!("Unexpected XML event: {event:?}"));
}; };
Ok(item) Ok(item)
@ -213,7 +211,8 @@ impl FromXml for Feature {
type P = impl Parser<Output = Result<Self>>; type P = impl Parser<Output = Result<Self>>;
fn parse() -> Self::P { fn parse() -> Self::P {
|(namespace, event): (ResolveResult<'static>, &'static Event<'static>)| -> Result<Self> { #[coroutine]
|(mut namespace, mut event): (ResolveResult<'static>, &'static Event<'static>)| -> Result<Self> {
let mut var = None; let mut var = None;
let (bytes, end) = match event { let (bytes, end) = match event {
Event::Start(bytes) => (bytes, false), Event::Start(bytes) => (bytes, false),
@ -238,8 +237,8 @@ impl FromXml for Feature {
return Ok(item); return Ok(item);
} }
let (namespace, event) = yield; (namespace, event) = yield;
let Event::End(bytes) = event else { let Event::End(_) = event else {
return Err(ffail!("Unexpected XML event: {event:?}")); return Err(ffail!("Unexpected XML event: {event:?}"));
}; };
Ok(item) Ok(item)
@ -262,9 +261,10 @@ impl FromXml for ItemQuery {
type P = impl Parser<Output = Result<Self>>; type P = impl Parser<Output = Result<Self>>;
fn parse() -> Self::P { fn parse() -> Self::P {
|(namespace, event): (ResolveResult<'static>, &'static Event<'static>)| -> Result<Self> { #[coroutine]
|(mut namespace, mut event): (ResolveResult<'static>, &'static Event<'static>)| -> Result<Self> {
let mut item = vec![]; let mut item = vec![];
let (bytes, end) = match event { let (_, end) = match event {
Event::Start(bytes) => (bytes, false), Event::Start(bytes) => (bytes, false),
Event::Empty(bytes) => (bytes, true), Event::Empty(bytes) => (bytes, true),
_ => return Err(ffail!("Unexpected XML event: {event:?}")), _ => return Err(ffail!("Unexpected XML event: {event:?}")),
@ -273,7 +273,7 @@ impl FromXml for ItemQuery {
return Ok(ItemQuery { item }); return Ok(ItemQuery { item });
} }
loop { loop {
let (namespace, event) = yield; (namespace, event) = yield;
let bytes = match event { let bytes = match event {
Event::Start(bytes) => bytes, Event::Start(bytes) => bytes,
Event::Empty(bytes) => bytes, Event::Empty(bytes) => bytes,
@ -300,7 +300,7 @@ impl FromXmlTag for ItemQuery {
impl ToXml for ItemQuery { impl ToXml for ItemQuery {
fn serialize(&self, events: &mut Vec<Event<'static>>) { fn serialize(&self, events: &mut Vec<Event<'static>>) {
let mut bytes = BytesStart::new(format!(r#"query xmlns="{}""#, XMLNS_ITEM)); let bytes = BytesStart::new(format!(r#"query xmlns="{}""#, XMLNS_ITEM));
let empty = self.item.is_empty(); let empty = self.item.is_empty();
if empty { if empty {
events.push(Event::Empty(bytes)); events.push(Event::Empty(bytes));
@ -346,7 +346,8 @@ impl FromXml for Item {
type P = impl Parser<Output = Result<Self>>; type P = impl Parser<Output = Result<Self>>;
fn parse() -> Self::P { fn parse() -> Self::P {
|(namespace, event): (ResolveResult<'static>, &'static Event<'static>)| -> Result<Self> { #[coroutine]
|(_, mut event): (ResolveResult<'static>, &'static Event<'static>)| -> Result<Self> {
let mut jid = None; let mut jid = None;
let mut name = None; let mut name = None;
let mut node = None; let mut node = None;
@ -382,8 +383,8 @@ impl FromXml for Item {
return Ok(item); return Ok(item);
} }
let (namespace, event) = yield; (_, event) = yield;
let Event::End(bytes) = event else { let Event::End(_) = event else {
return Err(ffail!("Unexpected XML event: {event:?}")); return Err(ffail!("Unexpected XML event: {event:?}"));
}; };
Ok(item) Ok(item)

View File

@ -1,23 +1,23 @@
#![feature( #![feature(coroutines, coroutine_trait, type_alias_impl_trait, impl_trait_in_assoc_type)]
coroutines,
coroutine_trait,
type_alias_impl_trait,
impl_trait_in_assoc_type
)]
pub mod bind; pub mod bind;
pub mod client; pub mod client;
pub mod disco; pub mod disco;
pub mod mam;
pub mod muc; pub mod muc;
mod prelude;
pub mod roster; pub mod roster;
pub mod sasl; pub mod sasl;
pub mod session; pub mod session;
pub mod stanzaerror; pub mod stanzaerror;
pub mod stream; pub mod stream;
pub mod streamerror;
pub mod tls; pub mod tls;
mod prelude;
pub mod xml; pub mod xml;
#[cfg(test)]
mod testkit;
// Implemented as a macro instead of a fn due to borrowck limitations // Implemented as a macro instead of a fn due to borrowck limitations
macro_rules! skip_text { macro_rules! skip_text {
($reader: ident, $buf: ident) => { ($reader: ident, $buf: ident) => {

View File

@ -0,0 +1,226 @@
use anyhow::{anyhow, Result};
use quick_xml::events::{BytesEnd, BytesStart, BytesText, Event};
use quick_xml::name::{Namespace, ResolveResult};
use crate::xml::*;
pub const MAM_XMLNS: &'static str = "urn:xmpp:mam:2";
pub const DATA_XMLNS: &'static str = "jabber:x:data";
pub const RESULT_SET_XMLNS: &'static str = "http://jabber.org/protocol/rsm";
#[derive(PartialEq, Eq, Debug, Clone)]
pub struct MessageArchiveRequest {
pub x: Option<X>,
}
#[derive(PartialEq, Eq, Debug, Clone)]
pub struct X {
pub fields: Vec<Field>,
}
#[derive(PartialEq, Eq, Debug, Clone)]
pub struct Field {
pub values: Vec<String>,
}
// Message archive response styled as a result set.
#[derive(PartialEq, Eq, Debug, Clone)]
pub struct Fin {
pub set: Set,
}
#[derive(PartialEq, Eq, Debug, Clone)]
pub struct Set {
pub count: Option<i32>,
}
impl ToXml for Fin {
fn serialize(&self, events: &mut Vec<Event<'static>>) {
let fin_bytes = BytesStart::new(format!(r#"fin xmlns="{}" complete=True"#, MAM_XMLNS));
let set_bytes = BytesStart::new(format!(r#"set xmlns="{}""#, RESULT_SET_XMLNS));
events.push(Event::Start(fin_bytes));
events.push(Event::Start(set_bytes));
if let &Some(count) = &self.set.count {
events.push(Event::Start(BytesStart::new("count")));
events.push(Event::Text(BytesText::new(count.to_string().as_str()).into_owned()));
events.push(Event::End(BytesEnd::new("count")));
}
events.push(Event::End(BytesEnd::new("set")));
events.push(Event::End(BytesEnd::new("fin")));
}
}
impl FromXmlTag for X {
const NAME: &'static str = "x";
const NS: &'static str = DATA_XMLNS;
}
impl FromXmlTag for MessageArchiveRequest {
const NAME: &'static str = "query";
const NS: &'static str = MAM_XMLNS;
}
impl FromXml for X {
type P = impl Parser<Output = Result<Self>>;
fn parse() -> Self::P {
#[coroutine]
|(mut namespace, mut event): (ResolveResult<'static>, &'static Event<'static>)| -> Result<Self> {
println!("X::parse {:?}", event);
let bytes = match event {
Event::Start(bytes) if bytes.name().0 == X::NAME.as_bytes() => bytes,
Event::Empty(bytes) if bytes.name().0 == X::NAME.as_bytes() => return Ok(X { fields: vec![] }),
_ => return Err(anyhow!("Unexpected XML event: {event:?}")),
};
let mut fields = vec![];
loop {
(namespace, event) = yield;
match event {
Event::Start(_) => {
// start of <field>
let mut values = vec![];
loop {
(namespace, event) = yield;
match event {
Event::Start(bytes) if bytes.name().0 == b"value" => {
// start of <value>
}
Event::End(bytes) if bytes.name().0 == b"field" => {
// end of </field>
break;
}
_ => return Err(anyhow!("Unexpected XML event: {event:?}")),
}
(namespace, event) = yield;
let text: String = match event {
Event::Text(bytes) => {
// text inside <value></value>
String::from_utf8(bytes.to_vec())?
}
_ => return Err(anyhow!("Unexpected XML event: {event:?}")),
};
(namespace, event) = yield;
match event {
Event::End(bytes) if bytes.name().0 == b"value" => {
// end of </value>
}
_ => return Err(anyhow!("Unexpected XML event: {event:?}")),
}
values.push(text);
}
fields.push(Field { values })
}
Event::End(bytes) if bytes.name().0 == X::NAME.as_bytes() => {
// end of <x/>
return Ok(X { fields });
}
_ => return Err(anyhow!("Unexpected XML event: {event:?}")),
}
}
}
}
}
impl FromXml for MessageArchiveRequest {
type P = impl Parser<Output = Result<Self>>;
fn parse() -> Self::P {
#[coroutine]
|(mut namespace, mut event): (ResolveResult<'static>, &'static Event<'static>)| -> Result<Self> {
println!("MessageArchiveRequest::parse {:?}", event);
let bytes = match event {
Event::Empty(_) => return Ok(MessageArchiveRequest { x: None }),
Event::Start(bytes) => bytes,
_ => return Err(anyhow!("Unexpected XML event: {event:?}")),
};
if bytes.name().0 != MessageArchiveRequest::NAME.as_bytes() {
return Err(anyhow!("Unexpected XML tag: {:?}", bytes.name()));
}
let ResolveResult::Bound(Namespace(ns)) = namespace else {
return Err(anyhow!("No namespace provided"));
};
if ns != MAM_XMLNS.as_bytes() {
return Err(anyhow!("Incorrect namespace"));
}
(namespace, event) = yield;
match event {
Event::End(bytes) if bytes.name().0 == MessageArchiveRequest::NAME.as_bytes() => {
Ok(MessageArchiveRequest { x: None })
}
Event::Start(bytes) | Event::Empty(bytes) if bytes.name().0 == X::NAME.as_bytes() => {
let x = delegate_parsing!(X, namespace, event)?;
Ok(MessageArchiveRequest { x: Some(x) })
}
_ => Err(anyhow!("Unexpected XML event: {event:?}")),
}
}
}
}
impl MessageArchiveRequest {}
#[cfg(test)]
mod tests {
use super::*;
use crate::bind::{Jid, Name, Server};
use crate::client::{Iq, IqType};
#[test]
fn test_parse_archive_query() {
let input = r#"<iq to='pubsub.shakespeare.lit' type='set' id='juliet1'><query xmlns='urn:xmpp:mam:2' queryid='f28'/></iq>"#;
let result: Iq<MessageArchiveRequest> = parse(input).unwrap();
assert_eq!(
result,
Iq {
from: None,
id: "juliet1".to_string(),
to: Option::from(Jid {
name: None,
server: Server("pubsub.shakespeare.lit".into()),
resource: None,
}),
r#type: IqType::Set,
body: MessageArchiveRequest { x: None },
}
);
}
#[test]
fn test_parse_query_messages_from_jid() {
let input = r#"<iq type='set' id='juliet1'><query xmlns='urn:xmpp:mam:2'><x xmlns='jabber:x:data' type='submit'><field var='FORM_TYPE' type='hidden'><value>value1</value></field><field var='with'><value>juliet@capulet.lit</value></field></x></query></iq>"#;
let result: Iq<MessageArchiveRequest> = parse(input).unwrap();
assert_eq!(
result,
Iq {
from: None,
id: "juliet1".to_string(),
to: None,
r#type: IqType::Set,
body: MessageArchiveRequest {
x: Some(X {
fields: vec![
Field {
values: vec!["value1".to_string()],
},
Field {
values: vec!["juliet@capulet.lit".to_string()],
},
]
})
},
}
);
}
#[test]
fn test_parse_query_messages_from_jid_with_unclosed_tag() {
let input = r#"<iq type='set' id='juliet1'><query xmlns='urn:xmpp:mam:2'><x xmlns='jabber:x:data' type='submit'><field var='FORM_TYPE' type='hidden'><value>value1</value></field><field var='with'><value>juliet@capulet.lit</value></field></query></iq>"#;
assert!(parse::<Iq<MessageArchiveRequest>>(input).is_err())
}
}

View File

@ -1,12 +1,16 @@
#![allow(unused_variables)] #![allow(unused_variables)]
use quick_xml::events::Event;
use quick_xml::name::ResolveResult;
use crate::xml::*;
use anyhow::{anyhow, Result}; use anyhow::{anyhow, Result};
use quick_xml::events::attributes::Attribute;
use quick_xml::events::{BytesEnd, BytesStart, BytesText, Event};
use quick_xml::name::{QName, ResolveResult};
use crate::bind::Jid;
use crate::xml::*;
pub const XMLNS: &'static str = "http://jabber.org/protocol/muc"; pub const XMLNS: &'static str = "http://jabber.org/protocol/muc";
pub const XMLNS_USER: &'static str = "http://jabber.org/protocol/muc#user";
pub const XMLNS_DELAY: &'static str = "urn:xmpp:delay";
#[derive(PartialEq, Eq, Debug, Default)] #[derive(PartialEq, Eq, Debug, Default)]
pub struct History { pub struct History {
@ -19,7 +23,8 @@ impl FromXml for History {
type P = impl Parser<Output = Result<Self>>; type P = impl Parser<Output = Result<Self>>;
fn parse() -> Self::P { fn parse() -> Self::P {
|(namespace, event): (ResolveResult<'static>, &'static Event<'static>)| -> Result<Self> { #[coroutine]
|(mut namespace, mut event): (ResolveResult<'static>, &'static Event<'static>)| -> Result<Self> {
let mut history = History::default(); let mut history = History::default();
let (bytes, end) = match event { let (bytes, end) = match event {
Event::Start(bytes) if bytes.name().0 == Self::NAME.as_bytes() => (bytes, false), Event::Start(bytes) if bytes.name().0 == Self::NAME.as_bytes() => (bytes, false),
@ -51,7 +56,7 @@ impl FromXml for History {
return Ok(history); return Ok(history);
} }
let (namespace, event) = yield; (namespace, event) = yield;
let Event::End(bytes) = event else { let Event::End(bytes) = event else {
return Err(anyhow!("Unexpected XML event: {event:?}")); return Err(anyhow!("Unexpected XML event: {event:?}"));
}; };
@ -73,17 +78,18 @@ impl FromXml for Password {
type P = impl Parser<Output = Result<Self>>; type P = impl Parser<Output = Result<Self>>;
fn parse() -> Self::P { fn parse() -> Self::P {
|(namespace, event): (ResolveResult<'static>, &'static Event<'static>)| -> Result<Self> { #[coroutine]
|(mut namespace, mut event): (ResolveResult<'static>, &'static Event<'static>)| -> Result<Self> {
let bytes = match event { let bytes = match event {
Event::Start(bytes) if bytes.name().0 == Self::NAME.as_bytes() => bytes, Event::Start(bytes) if bytes.name().0 == Self::NAME.as_bytes() => bytes,
_ => return Err(anyhow!("Unexpected XML event: {event:?}")), _ => return Err(anyhow!("Unexpected XML event: {event:?}")),
}; };
let (namespace, event) = yield; (namespace, event) = yield;
let Event::Text(bytes) = event else { let Event::Text(bytes) = event else {
return Err(anyhow!("Unexpected XML event: {event:?}")); return Err(anyhow!("Unexpected XML event: {event:?}"));
}; };
let s = std::str::from_utf8(bytes)?.to_string(); let s = std::str::from_utf8(bytes)?.to_string();
let (namespace, event) = yield; (namespace, event) = yield;
let Event::End(bytes) = event else { let Event::End(bytes) = event else {
return Err(anyhow!("Unexpected XML event: {event:?}")); return Err(anyhow!("Unexpected XML event: {event:?}"));
}; };
@ -108,7 +114,8 @@ impl FromXml for X {
type P = impl Parser<Output = Result<Self>>; type P = impl Parser<Output = Result<Self>>;
fn parse() -> Self::P { fn parse() -> Self::P {
|(namespace, event): (ResolveResult<'static>, &'static Event<'static>)| -> Result<Self> { #[coroutine]
|(mut namespace, mut event): (ResolveResult<'static>, &'static Event<'static>)| -> Result<Self> {
let mut res = X::default(); let mut res = X::default();
let (_, end) = match event { let (_, end) = match event {
Event::Start(bytes) => (bytes, false), Event::Start(bytes) => (bytes, false),
@ -120,7 +127,7 @@ impl FromXml for X {
} }
loop { loop {
let (namespace, event) = yield; (namespace, event) = yield;
let bytes = match event { let bytes = match event {
Event::Start(bytes) => bytes, Event::Start(bytes) => bytes,
Event::Empty(bytes) => bytes, Event::Empty(bytes) => bytes,
@ -143,9 +150,191 @@ impl FromXml for X {
} }
} }
/// Information about an MUC member. May contain [MUC status codes](https://xmpp.org/registrar/mucstatus.html).
#[derive(Debug, PartialEq, Eq)]
pub struct XUser {
pub item: XUserItem,
/// Code 110. The receiver is the user referred to in the presence stanza.
pub self_presence: bool,
/// Code 201. The room from which the presence stanza was sent was just created.
pub just_created: bool,
}
impl ToXml for XUser {
fn serialize(&self, output: &mut Vec<Event<'static>>) {
let mut tag = BytesStart::new("x");
tag.push_attribute(("xmlns", XMLNS_USER));
output.push(Event::Start(tag));
self.item.serialize(output);
if self.self_presence {
let mut meg = BytesStart::new("status");
meg.push_attribute(("code", "110"));
output.push(Event::Empty(meg));
}
if self.just_created {
let mut meg = BytesStart::new("status");
meg.push_attribute(("code", "201"));
output.push(Event::Empty(meg));
}
output.push(Event::End(BytesEnd::new("x")));
}
}
#[derive(Debug, PartialEq, Eq)]
pub struct XUserItem {
pub affiliation: Affiliation,
pub jid: Jid,
pub role: Role,
}
impl ToXml for XUserItem {
fn serialize(&self, output: &mut Vec<Event<'static>>) {
let mut meg = BytesStart::new("item");
meg.push_attribute(("affiliation", self.affiliation.to_str()));
meg.push_attribute(("role", self.role.to_str()));
meg.push_attribute(("jid", self.jid.to_string().as_str()));
output.push(Event::Empty(meg));
}
}
#[derive(Debug, PartialEq, Eq)]
pub enum Affiliation {
Owner,
Admin,
Member,
Outcast,
None,
}
impl Affiliation {
pub fn from_str(s: &str) -> Option<Self> {
match s {
"owner" => Some(Self::Owner),
"admin" => Some(Self::Admin),
"member" => Some(Self::Member),
"outcast" => Some(Self::Outcast),
"none" => Some(Self::None),
_ => None,
}
}
pub fn to_str(&self) -> &str {
match self {
Self::Owner => "owner",
Self::Admin => "admin",
Self::Member => "member",
Self::Outcast => "outcast",
Self::None => "none",
}
}
}
#[derive(Debug, PartialEq, Eq)]
pub enum Role {
Moderator,
Participant,
Visitor,
None,
}
impl Role {
pub fn from_str(s: &str) -> Option<Self> {
match s {
"moderator" => Some(Self::Moderator),
"participant" => Some(Self::Participant),
"visitor" => Some(Self::Visitor),
"none" => Some(Self::None),
_ => None,
}
}
pub fn to_str(&self) -> &str {
match self {
Self::Moderator => "moderator",
Self::Participant => "participant",
Self::Visitor => "visitor",
Self::None => "none",
}
}
}
#[derive(Debug, PartialEq, Eq)]
pub struct Delay {
pub from: Jid,
pub stamp: String,
}
impl ToXml for Delay {
fn serialize(&self, events: &mut Vec<Event>) {
let mut tag = BytesStart::new("delay");
tag.push_attribute(Attribute {
key: QName(b"xmlns"),
value: XMLNS_DELAY.as_bytes().into(),
});
tag.push_attribute(Attribute {
key: QName(b"from"),
value: self.from.to_string().into_bytes().into(),
});
tag.push_attribute(Attribute {
key: QName(b"stamp"),
value: self.stamp.as_bytes().into(),
});
events.push(Event::Empty(tag));
}
}
/// Message-stanza of a historic message.
///
/// Example:
/// ```xml
/// <message from="duqedadi@conference.example.com/misha" xml:lang="en" to="misha@example.com/tux" type="groupchat" id="7ca7cb14-b2af-49c9-bd90-05dabb1113a5">
/// <delay xmlns="urn:xmpp:delay" stamp="2024-05-17T16:05:28.440337Z" from="duqedadi@conference.example.com"/>
/// <body></body>
/// </message>
/// ```
#[derive(Debug, PartialEq, Eq)]
pub struct XmppHistoryMessage {
pub id: String,
pub to: Jid,
pub from: Jid,
pub delay: Delay,
pub body: String,
}
impl ToXml for XmppHistoryMessage {
fn serialize(&self, events: &mut Vec<Event<'static>>) {
let mut message_tag = BytesStart::new("message");
message_tag.push_attribute(Attribute {
key: QName(b"id"),
value: self.id.as_str().as_bytes().into(),
});
message_tag.push_attribute(Attribute {
key: QName(b"to"),
value: self.to.to_string().into_bytes().into(),
});
message_tag.push_attribute(Attribute {
key: QName(b"from"),
value: self.from.to_string().into_bytes().into(),
});
message_tag.push_attribute(Attribute {
key: QName(b"type"),
value: b"groupchat".into(),
});
events.push(Event::Start(message_tag));
self.delay.serialize(events);
let body_tag = BytesStart::new("body");
events.push(Event::Start(body_tag));
events.push(Event::Text(BytesText::new(self.body.to_string().as_str()).into_owned()));
events.push(Event::End(BytesEnd::new("body")));
events.push(Event::End(BytesEnd::new("message")));
}
}
#[cfg(test)] #[cfg(test)]
mod test { mod test {
use super::*; use super::*;
use crate::bind::{Name, Resource, Server};
use crate::testkit::assemble_string_from_event_flow;
#[test] #[test]
fn test_history_success_empty() { fn test_history_success_empty() {
@ -228,4 +417,40 @@ mod test {
}; };
assert_eq!(res, expected); assert_eq!(res, expected);
} }
#[test]
fn test_history_message_serialization() {
// Arrange
let history_message = XmppHistoryMessage {
id: "id".to_string(),
to: Jid {
name: Some(Name("sauer@example.com".into())),
server: Server("localhost".into()),
resource: Some(Resource("tester".into())),
},
from: Jid {
name: Some(Name("pepe".into())),
server: Server("rooms.localhost".into()),
resource: Some(Resource("sauer".into())),
},
delay: Delay {
from: Jid {
name: Some(Name("pepe".into())),
server: Server("rooms.localhost".into()),
resource: Some(Resource("tester".into())),
},
stamp: "2021-10-10T10:10:10Z".to_string(),
},
body: "Hello World.".to_string(),
};
let mut events = vec![];
let expected = r#"<message id="id" to="sauer@example.com@localhost/tester" from="pepe@rooms.localhost/sauer" type="groupchat"><delay xmlns="urn:xmpp:delay" from="pepe@rooms.localhost/tester" stamp="2021-10-10T10:10:10Z"/><body>Hello World.</body></message>"#;
// Act
history_message.serialize(&mut events);
let flow = assemble_string_from_event_flow(&events);
// Assert
assert_eq!(flow, expected);
}
} }

View File

@ -1,47 +1,34 @@
use quick_xml::events::{BytesStart, Event}; use quick_xml::events::{BytesStart, Event};
use crate::xml::*; use crate::xml::*;
use anyhow::{anyhow as ffail, Result}; use anyhow::{anyhow, Result};
use quick_xml::name::{Namespace, ResolveResult};
pub const XMLNS: &'static str = "jabber:iq:roster"; pub const XMLNS: &'static str = "jabber:iq:roster";
#[derive(PartialEq, Eq, Debug)] #[derive(PartialEq, Eq, Debug)]
pub struct RosterQuery; pub struct RosterQuery;
pub struct QueryParser(QueryParserInner);
enum QueryParserInner {
Initial,
InQuery,
}
impl Parser for QueryParser {
type Output = Result<RosterQuery>;
fn consume<'a>(
self: Self,
namespace: quick_xml::name::ResolveResult,
event: &quick_xml::events::Event<'a>,
) -> Continuation<Self, Self::Output> {
match self.0 {
QueryParserInner::Initial => match event {
Event::Start(_) => Continuation::Continue(QueryParser(QueryParserInner::InQuery)),
Event::Empty(_) => Continuation::Final(Ok(RosterQuery)),
_ => Continuation::Final(Err(ffail!("Unexpected XML event: {event:?}"))),
},
QueryParserInner::InQuery => match event {
Event::End(_) => Continuation::Final(Ok(RosterQuery)),
_ => Continuation::Final(Err(ffail!("Unexpected XML event: {event:?}"))),
},
}
}
}
impl FromXml for RosterQuery { impl FromXml for RosterQuery {
type P = QueryParser; type P = impl Parser<Output = Result<Self>>;
fn parse() -> Self::P { fn parse() -> Self::P {
QueryParser(QueryParserInner::Initial) #[coroutine]
|(mut namespace, mut event): (ResolveResult<'static>, &'static Event<'static>)| -> Result<Self> {
let ResolveResult::Bound(Namespace(ns)) = namespace else {
return Err(anyhow!("No namespace provided"));
};
match event {
Event::Start(_) => (),
Event::Empty(_) => return Ok(RosterQuery),
_ => return Err(anyhow!("Unexpected XML event: {event:?}")),
}
(namespace, event) = yield;
match event {
Event::End(_) => return Ok(RosterQuery),
_ => return Err(anyhow!("Unexpected XML event: {event:?}")),
}
}
} }
} }
@ -52,9 +39,43 @@ impl FromXmlTag for RosterQuery {
impl ToXml for RosterQuery { impl ToXml for RosterQuery {
fn serialize(&self, events: &mut Vec<Event<'static>>) { fn serialize(&self, events: &mut Vec<Event<'static>>) {
events.push(Event::Empty(BytesStart::new(format!( events.push(Event::Empty(BytesStart::new(format!(r#"query xmlns="{}""#, XMLNS))));
r#"query xmlns="{}""#, }
XMLNS }
))));
#[cfg(test)]
mod tests {
use super::*;
use crate::bind::{Jid, Name, Resource, Server};
use crate::client::{Iq, IqType};
#[test]
fn test_parse() -> Result<()> {
let input =
r#"<iq from='juliet@example.com/balcony' id='bv1bs71f' type='get'><query xmlns='jabber:iq:roster'/></iq>"#;
let result: Iq<RosterQuery> = parse(input)?;
assert_eq!(
result,
Iq {
from: Option::from(Jid {
name: Option::from(Name("juliet".into())),
server: Server("example.com".into()),
resource: Option::from(Resource("balcony".into())),
}),
id: "bv1bs71f".to_string(),
to: None,
r#type: IqType::Get,
body: RosterQuery,
}
);
Ok(())
}
#[test]
fn test_missing_namespace() {
let input = r#"<iq from='juliet@example.com/balcony' id='bv1bs71f' type='get'><query/></iq>"#;
assert!(parse::<Iq<RosterQuery>>(input).is_err());
} }
} }

View File

@ -1,7 +1,7 @@
use std::borrow::Borrow; use std::borrow::Borrow;
use anyhow::{anyhow, Result}; use anyhow::{anyhow, Result};
use quick_xml::events::{BytesStart, Event}; use quick_xml::events::{BytesEnd, BytesStart, Event};
use quick_xml::{NsReader, Writer}; use quick_xml::{NsReader, Writer};
use tokio::io::{AsyncBufRead, AsyncWrite}; use tokio::io::{AsyncBufRead, AsyncWrite};
@ -74,3 +74,16 @@ impl Success {
Ok(()) Ok(())
} }
} }
pub struct Failure;
impl Failure {
pub async fn write_xml(&self, writer: &mut Writer<impl AsyncWrite + Unpin>) -> Result<()> {
let event = BytesStart::new(r#"failure xmlns="urn:ietf:params:xml:ns:xmpp-sasl""#);
writer.write_event_async(Event::Start(event)).await?;
let event = BytesStart::new(r#"not-authorized"#);
writer.write_event_async(Event::Empty(event)).await?;
let event = BytesEnd::new(r#"failure"#);
writer.write_event_async(Event::End(event)).await?;
Ok(())
}
}

View File

@ -2,48 +2,30 @@ use quick_xml::events::{BytesStart, Event};
use crate::xml::*; use crate::xml::*;
use anyhow::{anyhow, Result}; use anyhow::{anyhow, Result};
use quick_xml::name::ResolveResult;
pub const XMLNS: &'static str = "urn:ietf:params:xml:ns:xmpp-session"; pub const XMLNS: &'static str = "urn:ietf:params:xml:ns:xmpp-session";
#[derive(PartialEq, Eq, Debug)] #[derive(PartialEq, Eq, Debug)]
pub struct Session; pub struct Session;
pub struct SessionParser(SessionParserInner);
enum SessionParserInner {
Initial,
InSession,
}
impl Parser for SessionParser {
type Output = Result<Session>;
fn consume<'a>(
self: Self,
namespace: quick_xml::name::ResolveResult,
event: &quick_xml::events::Event<'a>,
) -> Continuation<Self, Self::Output> {
match self.0 {
SessionParserInner::Initial => match event {
Event::Start(_) => {
Continuation::Continue(SessionParser(SessionParserInner::InSession))
}
Event::Empty(_) => Continuation::Final(Ok(Session)),
_ => Continuation::Final(Err(anyhow!("Unexpected XML event: {event:?}"))),
},
SessionParserInner::InSession => match event {
Event::End(_) => Continuation::Final(Ok(Session)),
_ => Continuation::Final(Err(anyhow!("Unexpected XML event: {event:?}"))),
},
}
}
}
impl FromXml for Session { impl FromXml for Session {
type P = SessionParser; type P = impl Parser<Output = Result<Self>>;
fn parse() -> Self::P { fn parse() -> Self::P {
SessionParser(SessionParserInner::Initial) #[coroutine]
|(mut namespace, mut event): (ResolveResult<'static>, &'static Event<'static>)| -> Result<Self> {
match event {
Event::Start(_) => (),
Event::Empty(_) => return Ok(Session),
_ => return Err(anyhow!("Unexpected XML event: {event:?}")),
}
(namespace, event) = yield;
match event {
Event::End(_) => return Ok(Session),
_ => return Err(anyhow!("Unexpected XML event: {event:?}")),
}
}
} }
} }
@ -54,9 +36,6 @@ impl FromXmlTag for Session {
impl ToXml for Session { impl ToXml for Session {
fn serialize(&self, events: &mut Vec<Event<'static>>) { fn serialize(&self, events: &mut Vec<Event<'static>>) {
events.push(Event::Empty(BytesStart::new(format!( events.push(Event::Empty(BytesStart::new(format!(r#"session xmlns="{}""#, XMLNS))));
r#"session xmlns="{}""#,
XMLNS
))));
} }
} }

View File

@ -6,8 +6,8 @@ use tokio::io::{AsyncBufRead, AsyncWrite};
use super::skip_text; use super::skip_text;
use anyhow::{anyhow, Result};
use crate::xml::ToXml; use crate::xml::ToXml;
use anyhow::{anyhow, Result};
pub static XMLNS: &'static str = "http://etherx.jabber.org/streams"; pub static XMLNS: &'static str = "http://etherx.jabber.org/streams";
pub static PREFIX: &'static str = "stream"; pub static PREFIX: &'static str = "stream";
@ -24,7 +24,17 @@ impl ClientStreamStart {
reader: &mut NsReader<impl AsyncBufRead + Unpin>, reader: &mut NsReader<impl AsyncBufRead + Unpin>,
buf: &mut Vec<u8>, buf: &mut Vec<u8>,
) -> Result<ClientStreamStart> { ) -> Result<ClientStreamStart> {
let incoming = skip_text!(reader, buf); let mut incoming = skip_text!(reader, buf);
if let Event::Decl(bytes) = incoming {
// this is <?xml ...> header
if let Some(encoding) = bytes.encoding() {
let encoding = encoding?;
if &*encoding != b"UTF-8" {
return Err(anyhow!("Unsupported encoding: {encoding:?}"));
}
}
incoming = skip_text!(reader, buf);
}
if let Event::Start(e) = incoming { if let Event::Start(e) = incoming {
let (ns, local) = reader.resolve_element(e.name()); let (ns, local) = reader.resolve_element(e.name());
if ns != ResolveResult::Bound(Namespace(XMLNS.as_bytes())) { if ns != ResolveResult::Bound(Namespace(XMLNS.as_bytes())) {
@ -44,10 +54,7 @@ impl ClientStreamStart {
let value = attr.unescape_value()?; let value = attr.unescape_value()?;
to = Some(value.to_string()); to = Some(value.to_string());
} }
( (ResolveResult::Bound(Namespace(b"http://www.w3.org/XML/1998/namespace")), b"lang") => {
ResolveResult::Bound(Namespace(b"http://www.w3.org/XML/1998/namespace")),
b"lang",
) => {
let value = attr.unescape_value()?; let value = attr.unescape_value()?;
lang = Some(value.to_string()); lang = Some(value.to_string());
} }
@ -124,21 +131,15 @@ pub struct Features {
} }
impl Features { impl Features {
pub async fn write_xml(&self, writer: &mut Writer<impl AsyncWrite + Unpin>) -> Result<()> { pub async fn write_xml(&self, writer: &mut Writer<impl AsyncWrite + Unpin>) -> Result<()> {
writer writer.write_event_async(Event::Start(BytesStart::new("stream:features"))).await?;
.write_event_async(Event::Start(BytesStart::new("stream:features")))
.await?;
if self.start_tls { if self.start_tls {
writer writer
.write_event_async(Event::Start(BytesStart::new( .write_event_async(Event::Start(BytesStart::new(
r#"starttls xmlns="urn:ietf:params:xml:ns:xmpp-tls""#, r#"starttls xmlns="urn:ietf:params:xml:ns:xmpp-tls""#,
))) )))
.await?; .await?;
writer writer.write_event_async(Event::Empty(BytesStart::new("required"))).await?;
.write_event_async(Event::Empty(BytesStart::new("required"))) writer.write_event_async(Event::End(BytesEnd::new("starttls"))).await?;
.await?;
writer
.write_event_async(Event::End(BytesEnd::new("starttls")))
.await?;
} }
if self.mechanisms { if self.mechanisms {
writer writer
@ -146,18 +147,10 @@ impl Features {
r#"mechanisms xmlns="urn:ietf:params:xml:ns:xmpp-sasl""#, r#"mechanisms xmlns="urn:ietf:params:xml:ns:xmpp-sasl""#,
))) )))
.await?; .await?;
writer writer.write_event_async(Event::Start(BytesStart::new(r#"mechanism"#))).await?;
.write_event_async(Event::Start(BytesStart::new(r#"mechanism"#))) writer.write_event_async(Event::Text(BytesText::new("PLAIN"))).await?;
.await?; writer.write_event_async(Event::End(BytesEnd::new("mechanism"))).await?;
writer writer.write_event_async(Event::End(BytesEnd::new("mechanisms"))).await?;
.write_event_async(Event::Text(BytesText::new("PLAIN")))
.await?;
writer
.write_event_async(Event::End(BytesEnd::new("mechanism")))
.await?;
writer
.write_event_async(Event::End(BytesEnd::new("mechanisms")))
.await?;
} }
if self.bind { if self.bind {
writer writer
@ -166,9 +159,7 @@ impl Features {
))) )))
.await?; .await?;
} }
writer writer.write_event_async(Event::End(BytesEnd::new("stream:features"))).await?;
.write_event_async(Event::End(BytesEnd::new("stream:features")))
.await?;
Ok(()) Ok(())
} }
} }
@ -178,32 +169,31 @@ mod test {
use super::*; use super::*;
#[tokio::test] #[tokio::test]
async fn client_stream_start_correct_parse() { async fn client_stream_start_correct_parse() -> Result<()> {
let input = r###"<stream:stream xmlns:stream="http://etherx.jabber.org/streams" to="vlnv.dev" version="1.0" xmlns="jabber:client" xml:lang="en" xmlns:xml="http://www.w3.org/XML/1998/namespace">"###; let input = r###"<stream:stream xmlns:stream="http://etherx.jabber.org/streams" to="example.com" version="1.0" xmlns="jabber:client" xml:lang="en" xmlns:xml="http://www.w3.org/XML/1998/namespace">"###;
let mut reader = NsReader::from_reader(input.as_bytes()); let mut reader = NsReader::from_reader(input.as_bytes());
let mut buf = vec![]; let mut buf = vec![];
let res = ClientStreamStart::parse(&mut reader, &mut buf) let res = ClientStreamStart::parse(&mut reader, &mut buf).await?;
.await
.unwrap();
assert_eq!( assert_eq!(
res, res,
ClientStreamStart { ClientStreamStart {
to: "vlnv.dev".to_owned(), to: "example.com".to_owned(),
lang: Some("en".to_owned()), lang: Some("en".to_owned()),
version: "1.0".to_owned() version: "1.0".to_owned()
} }
) );
Ok(())
} }
#[tokio::test] #[tokio::test]
async fn server_stream_start_write() { async fn server_stream_start_write() {
let input = ServerStreamStart { let input = ServerStreamStart {
from: "vlnv.dev".to_owned(), from: "example.com".to_owned(),
lang: "en".to_owned(), lang: "en".to_owned(),
id: "stream_id".to_owned(), id: "stream_id".to_owned(),
version: "1.0".to_owned(), version: "1.0".to_owned(),
}; };
let expected = r###"<stream:stream from="vlnv.dev" version="1.0" xmlns="jabber:client" xmlns:stream="http://etherx.jabber.org/streams" xml:lang="en" id="stream_id">"###; let expected = r###"<stream:stream from="example.com" version="1.0" xmlns="jabber:client" xmlns:stream="http://etherx.jabber.org/streams" xml:lang="en" id="stream_id">"###;
let mut output: Vec<u8> = vec![]; let mut output: Vec<u8> = vec![];
let mut writer = Writer::new(&mut output); let mut writer = Writer::new(&mut output);
input.write_xml(&mut writer).await.unwrap(); input.write_xml(&mut writer).await.unwrap();

View File

@ -0,0 +1,41 @@
use crate::xml::ToXml;
use quick_xml::events::{BytesEnd, BytesStart, Event};
/// Stream error condition
///
/// [Spec](https://xmpp.org/rfcs/rfc6120.html#streams-error-conditions).
pub enum StreamErrorKind {
/// The server has experienced a misconfiguration or other internal error that prevents it from servicing the stream.
InternalServerError,
/// The server is being shut down and all active streams are being closed.
SystemShutdown,
}
impl StreamErrorKind {
pub fn from_str(s: &str) -> Option<Self> {
match s {
"internal-server-error" => Some(Self::InternalServerError),
"system-shutdown" => Some(Self::SystemShutdown),
_ => None,
}
}
pub fn as_str(&self) -> &'static str {
match self {
Self::InternalServerError => "internal-server-error",
Self::SystemShutdown => "system-shutdown",
}
}
}
pub struct StreamError {
pub kind: StreamErrorKind,
}
impl ToXml for StreamError {
fn serialize(&self, events: &mut Vec<Event<'static>>) {
events.push(Event::Start(BytesStart::new("stream:error")));
events.push(Event::Empty(BytesStart::new(format!(
r#"{} xmlns="urn:ietf:params:xml:ns:xmpp-streams""#,
self.kind.as_str()
))));
events.push(Event::End(BytesEnd::new("stream:error")));
}
}

View File

@ -0,0 +1,12 @@
use quick_xml::events::Event;
use quick_xml::Writer;
use std::io::Cursor;
pub fn assemble_string_from_event_flow(events: &Vec<Event<'_>>) -> String {
let mut writer = Writer::new(Cursor::new(Vec::new()));
for event in events {
writer.write_event(event).unwrap();
}
let result = writer.into_inner().into_inner();
String::from_utf8(result).unwrap()
}

View File

@ -12,10 +12,7 @@ pub static XMLNS: &'static str = "urn:ietf:params:xml:ns:xmpp-tls";
pub struct StartTLS; pub struct StartTLS;
impl StartTLS { impl StartTLS {
pub async fn parse( pub async fn parse(reader: &mut NsReader<impl AsyncBufRead + Unpin>, buf: &mut Vec<u8>) -> Result<StartTLS> {
reader: &mut NsReader<impl AsyncBufRead + Unpin>,
buf: &mut Vec<u8>,
) -> Result<StartTLS> {
let incoming = skip_text!(reader, buf); let incoming = skip_text!(reader, buf);
if let Event::Empty(ref e) = incoming { if let Event::Empty(ref e) = incoming {
if e.name().0 == b"starttls" { if e.name().0 == b"starttls" {

View File

@ -1,59 +1,36 @@
use super::*; use super::*;
use derive_more::From;
#[derive(Default, Debug, PartialEq, Eq)] #[derive(Default, Debug, PartialEq, Eq)]
pub struct Ignore; pub struct Ignore;
#[derive(From)]
pub struct IgnoreParser(IgnoreParserInner);
enum IgnoreParserInner {
Initial,
InTag { name: Vec<u8>, depth: u8 },
}
impl Parser for IgnoreParser {
type Output = Result<Ignore>;
fn consume<'a>(
self: Self,
_: ResolveResult,
event: &Event<'a>,
) -> Continuation<Self, Self::Output> {
match self.0 {
IgnoreParserInner::Initial => match event {
Event::Start(bytes) => {
let name = bytes.name().0.to_owned();
Continuation::Continue(IgnoreParserInner::InTag { name, depth: 0 }.into())
}
Event::Empty(_) => Continuation::Final(Ok(Ignore)),
_ => Continuation::Final(Ok(Ignore)),
},
IgnoreParserInner::InTag { name, depth } => match event {
Event::End(bytes) if name == bytes.name().0 => {
if depth == 0 {
Continuation::Final(Ok(Ignore))
} else {
Continuation::Continue(
IgnoreParserInner::InTag {
name,
depth: depth - 1,
}
.into(),
)
}
}
_ => Continuation::Continue(IgnoreParserInner::InTag { name, depth }.into()),
},
}
}
}
impl FromXml for Ignore { impl FromXml for Ignore {
type P = IgnoreParser; type P = impl Parser<Output = Result<Self>>;
fn parse() -> Self::P { fn parse() -> Self::P {
IgnoreParserInner::Initial.into() #[coroutine]
|(mut namespace, mut event): (ResolveResult<'static>, &'static Event<'static>)| -> Result<Self> {
let mut depth = match event {
Event::Start(bytes) => 0,
Event::Empty(_) => return Ok(Ignore),
_ => return Ok(Ignore),
};
loop {
(namespace, event) = yield;
match event {
Event::End(_) => {
if depth == 0 {
return Ok(Ignore);
} else {
depth -= 1;
}
}
Event::Start(_) => {
depth += 1;
}
_ => (),
}
}
}
} }
} }

View File

@ -1,18 +1,48 @@
use std::ops::Coroutine; use std::ops::Coroutine;
use std::pin::Pin; use std::pin::Pin;
use quick_xml::NsReader;
use quick_xml::events::Event; use quick_xml::events::Event;
use quick_xml::name::ResolveResult; use quick_xml::name::ResolveResult;
use quick_xml::NsReader;
use anyhow::Result; use anyhow::Result;
mod ignore; mod ignore;
pub use ignore::Ignore; pub use ignore::Ignore;
/// Types which can be parsed from an XML input stream.
///
/// Example:
/// ```
/// #![feature(type_alias_impl_trait)]
/// #![feature(impl_trait_in_assoc_type)]
/// #![feature(coroutines)]
/// # use proto_xmpp::xml::FromXml;
/// # use quick_xml::events::Event;
/// # use quick_xml::name::ResolveResult;
/// # use proto_xmpp::xml::Parser;
/// # use anyhow::Result;
///
/// struct MyStruct;
/// impl FromXml for MyStruct {
/// type P = impl Parser<Output = Result<Self>>;
///
/// fn parse() -> Self::P {
/// #[coroutine]
/// |(mut namespace, mut event): (ResolveResult<'static>, &'static Event<'static>)| -> Result<Self> {
/// (namespace, event) = yield;
/// Ok(MyStruct)
/// }
/// }
/// }
/// ```
pub trait FromXml: Sized { pub trait FromXml: Sized {
/// The type of parser instances.
///
/// If the result type of the [parse] is anonymous, this type member can be defined by using `impl Trait`.
type P: Parser<Output = Result<Self>>; type P: Parser<Output = Result<Self>>;
/// Creates a new instance of a parser with an initial state.
fn parse() -> Self::P; fn parse() -> Self::P;
} }
@ -25,28 +55,28 @@ pub trait FromXmlTag: FromXml {
const NS: &'static str; const NS: &'static str;
} }
/// A stateful parser instance which consumes XML events until the parsing is complete.
///
/// Usually implemented with the experimental coroutine syntax, which yields to consume the next XML event,
/// and returns the final result when the parsing is done.
pub trait Parser: Sized { pub trait Parser: Sized {
type Output; type Output;
fn consume<'a>( /// Advance the parsing by one XML event.
self: Self, ///
namespace: ResolveResult, /// This method consumes `self`, but if the parsing is incomplete,
event: &Event<'a>, /// it will return the next state of the parser in the returned result.
) -> Continuation<Self, Self::Output>; /// Otherwise, it will return the final result of parsing.
fn consume<'a>(self: Self, namespace: ResolveResult, event: &Event<'a>) -> Continuation<Self, Self::Output>;
} }
impl<T, Out> Parser for T impl<T, Out> Parser for T
where where
T: Coroutine<(ResolveResult<'static>, &'static Event<'static>), Yield = (), Return = Out> T: Coroutine<(ResolveResult<'static>, &'static Event<'static>), Yield = (), Return = Out> + Unpin,
+ Unpin,
{ {
type Output = Out; type Output = Out;
fn consume<'a>( fn consume<'a>(mut self: Self, namespace: ResolveResult, event: &Event<'a>) -> Continuation<Self, Self::Output> {
mut self: Self,
namespace: ResolveResult,
event: &Event<'a>,
) -> Continuation<Self, Self::Output> {
let s = Pin::new(&mut self); let s = Pin::new(&mut self);
// this is a very rude workaround fixing the fact that rust coroutines // this is a very rude workaround fixing the fact that rust coroutines
// 1. don't support higher-kinded lifetimes (i.e. no `impl for <'a> Coroutine<Event<'a>>) // 1. don't support higher-kinded lifetimes (i.e. no `impl for <'a> Coroutine<Event<'a>>)
@ -59,8 +89,11 @@ where
} }
} }
/// The result of a single parser iteration.
pub enum Continuation<Parser, Res> { pub enum Continuation<Parser, Res> {
/// The parsing is complete and the final result is available.
Final(Res), Final(Res),
/// The parsing is not complete and more XML events are required.
Continue(Parser), Continue(Parser),
} }
@ -98,8 +131,8 @@ macro_rules! delegate_parsing {
Continuation::Final(Ok(res)) => break Ok(res.into()), Continuation::Final(Ok(res)) => break Ok(res.into()),
Continuation::Final(Err(err)) => break Err(err), Continuation::Final(Err(err)) => break Err(err),
Continuation::Continue(p) => { Continuation::Continue(p) => {
let (namespace, event) = yield; ($namespace, $event) = yield;
parser = p.consume(namespace, event); parser = p.consume($namespace, $event);
} }
} }
} }

View File

@ -37,52 +37,51 @@ impl AuthBody {
mod test { mod test {
use super::*; use super::*;
#[test] #[test]
fn test_returning_auth_body() { fn test_returning_auth_body() -> Result<()> {
let orig = b"\x00login\x00pass"; let orig = b"\x00login\x00pass";
let encoded = general_purpose::STANDARD.encode(orig); let encoded = general_purpose::STANDARD.encode(orig);
let expected = AuthBody { let expected = AuthBody {
login: "login".to_string(), login: "login".to_string(),
password: "pass".to_string(), password: "pass".to_string(),
}; };
let result = AuthBody::from_str(encoded.as_bytes()).unwrap(); let result = AuthBody::from_str(encoded.as_bytes())?;
assert_eq!(expected, result); assert_eq!(expected, result);
Ok(())
} }
#[test] #[test]
fn test_ignoring_first_segment() { fn test_ignoring_first_segment() -> Result<()> {
let orig = b"ignored\x00login\x00pass"; let orig = b"ignored\x00login\x00pass";
let encoded = general_purpose::STANDARD.encode(orig); let encoded = general_purpose::STANDARD.encode(orig);
let expected = AuthBody { let expected = AuthBody {
login: "login".to_string(), login: "login".to_string(),
password: "pass".to_string(), password: "pass".to_string(),
}; };
let result = AuthBody::from_str(encoded.as_bytes()).unwrap(); let result = AuthBody::from_str(encoded.as_bytes())?;
assert_eq!(expected, result); assert_eq!(expected, result);
Ok(())
} }
#[test] #[test]
fn test_returning_auth_body_with_empty_strings() { fn test_returning_auth_body_with_empty_strings() -> Result<()> {
let orig = b"\x00\x00"; let orig = b"\x00\x00";
let encoded = general_purpose::STANDARD.encode(orig); let encoded = general_purpose::STANDARD.encode(orig);
let expected = AuthBody { let expected = AuthBody {
login: "".to_string(), login: "".to_string(),
password: "".to_string(), password: "".to_string(),
}; };
let result = AuthBody::from_str(encoded.as_bytes()).unwrap(); let result = AuthBody::from_str(encoded.as_bytes())?;
assert_eq!(expected, result); assert_eq!(expected, result);
Ok(())
} }
#[test] #[test]
fn test_fail_if_size_less_then_3() { fn test_fail_if_size_less_then_3() {
let orig = b"login\x00pass"; let orig = b"login\x00pass";
let encoded = general_purpose::STANDARD.encode(orig); let encoded = general_purpose::STANDARD.encode(orig);
let expected = AuthBody {
login: "login".to_string(),
password: "pass".to_string(),
};
let result = AuthBody::from_str(encoded.as_bytes()); let result = AuthBody::from_str(encoded.as_bytes());
assert!(result.is_err()); assert!(result.is_err());
@ -92,10 +91,6 @@ mod test {
fn test_fail_if_size_greater_then_3() { fn test_fail_if_size_greater_then_3() {
let orig = b"first\x00login\x00pass\x00other"; let orig = b"first\x00login\x00pass\x00other";
let encoded = general_purpose::STANDARD.encode(orig); let encoded = general_purpose::STANDARD.encode(orig);
let expected = AuthBody {
login: "login".to_string(),
password: "pass".to_string(),
};
let result = AuthBody::from_str(encoded.as_bytes()); let result = AuthBody::from_str(encoded.as_bytes());
assert!(result.is_err()); assert!(result.is_err());

View File

@ -8,11 +8,12 @@ Some useful commands for development and testing.
Following commands require `OpenSSL` to be installed. It is provided as `openssl` package in Arch Linux. Following commands require `OpenSSL` to be installed. It is provided as `openssl` package in Arch Linux.
Generate self-signed TLS certificate: Generate self-signed TLS certificate. Mind the common name (CN) field, it should match the domain name of the server.
Example for localhost:
openssl req -x509 -newkey rsa:4096 -sha256 -days 365 -noenc \ openssl req -x509 -newkey rsa:4096 -sha256 -days 365 -noenc \
-keyout certs/xmpp.key -out certs/xmpp.pem \ -keyout certs/xmpp.key -out certs/xmpp.pem \
-subj "/CN=example.com" -subj "/CN=localhost"
Print content of a TLS certificate: Print content of a TLS certificate:

View File

@ -19,9 +19,15 @@ server_name = "irc.localhost"
listen_on = "127.0.0.1:5222" listen_on = "127.0.0.1:5222"
cert = "./certs/xmpp.pem" cert = "./certs/xmpp.pem"
key = "./certs/xmpp.key" key = "./certs/xmpp.key"
hostname = "localhost"
[storage] [storage]
db_path = "db.sqlite" db_path = "db.sqlite"
[tracing]
# otlp grpc endpoint
endpoint = "http://jaeger:4317"
service_name = "lavina"
``` ```
## With Docker Compose ## With Docker Compose
@ -40,6 +46,15 @@ services:
- '5222:5222' # xmpp - '5222:5222' # xmpp
- '6667:6667' # irc non-tls - '6667:6667' # irc non-tls
- '127.0.0.1:1380:8080' # management http (private) - '127.0.0.1:1380:8080' # management http (private)
# if you want to observe traces
jaeger:
image: "jaegertracing/all-in-one:1.56"
ports:
- "16686:16686" # web ui
- "4317:4317" # grpc ingest endpoint
environment:
- COLLECTOR_OTLP_ENABLED=true
- SPAN_STORAGE_TYPE=memory
``` ```
## With Cargo ## With Cargo
@ -52,3 +67,35 @@ Or you can build it and run manually:
cargo build --release cargo build --release
./target/release/lavina --config config.toml ./target/release/lavina --config config.toml
## Migrations
### Prerequisites
Install sqlx-cli into ~/.local/bin:
cargo install --locked sqlx-cli
### Steps
Migrations run on every application start. For manual run, use sqlx:
sqlx mig run \
--source ./crates/lavina-core/migrations/ \
--database-url sqlite://db.sqlite
To see current status:
sqlx mig info \
--source ./crates/lavina-core/migrations/ \
--database-url sqlite://db.sqlite
sqlx mig info outputs
0/installed first
1/installed msg author
2/installed created at for messages
3/installed dialogs
4/installed new challenges
5/pending message datetime

View File

@ -1 +1 @@
nightly-2024-02-08 nightly-2024-06-18

View File

@ -1 +1,2 @@
max_width = 120 max_width = 120
chain_width = 120

View File

@ -12,13 +12,16 @@ use prometheus::{Encoder, Registry as MetricsRegistry, TextEncoder};
use serde::{Deserialize, Serialize}; use serde::{Deserialize, Serialize};
use tokio::net::TcpListener; use tokio::net::TcpListener;
use lavina_core::auth::UpdatePasswordResult;
use lavina_core::player::{ChangeRoomTopicResult, PlayerConnectionResult, PlayerId, SendMessageResult};
use lavina_core::prelude::*; use lavina_core::prelude::*;
use lavina_core::repo::Storage; use lavina_core::room::RoomId;
use lavina_core::room::RoomRegistry;
use lavina_core::terminator::Terminator; use lavina_core::terminator::Terminator;
use lavina_core::LavinaCore;
use mgmt_api::*; use mgmt_api::*;
mod clustering;
type HttpResult<T> = std::result::Result<T, Infallible>; type HttpResult<T> = std::result::Result<T, Infallible>;
#[derive(Deserialize, Debug)] #[derive(Deserialize, Debug)]
@ -26,24 +29,18 @@ pub struct ServerConfig {
pub listen_on: SocketAddr, pub listen_on: SocketAddr,
} }
pub async fn launch( pub async fn launch(config: ServerConfig, metrics: MetricsRegistry, core: LavinaCore) -> Result<Terminator> {
config: ServerConfig,
metrics: MetricsRegistry,
rooms: RoomRegistry,
storage: Storage,
) -> Result<Terminator> {
log::info!("Starting the http service"); log::info!("Starting the http service");
let listener = TcpListener::bind(config.listen_on).await?; let listener = TcpListener::bind(config.listen_on).await?;
log::debug!("Listener started"); log::debug!("Listener started");
let terminator = Terminator::spawn(|rx| main_loop(listener, metrics, rooms, storage, rx.map(|_| ()))); let terminator = Terminator::spawn(|rx| main_loop(listener, metrics, core, rx.map(|_| ())));
Ok(terminator) Ok(terminator)
} }
async fn main_loop( async fn main_loop(
listener: TcpListener, listener: TcpListener,
metrics: MetricsRegistry, metrics: MetricsRegistry,
rooms: RoomRegistry, core: LavinaCore,
storage: Storage,
termination: impl Future<Output = ()>, termination: impl Future<Output = ()>,
) -> Result<()> { ) -> Result<()> {
pin!(termination); pin!(termination);
@ -55,13 +52,10 @@ async fn main_loop(
let (stream, _) = result?; let (stream, _) = result?;
let stream = TokioIo::new(stream); let stream = TokioIo::new(stream);
let metrics = metrics.clone(); let metrics = metrics.clone();
let rooms = rooms.clone(); let core = core.clone();
let storage = storage.clone();
tokio::task::spawn(async move { tokio::task::spawn(async move {
let registry = metrics.clone(); let svc_fn = service_fn(|r| route(&metrics, &core, r));
let rooms = rooms.clone(); let server = http1::Builder::new().serve_connection(stream, svc_fn);
let storage = storage.clone();
let server = http1::Builder::new().serve_connection(stream, service_fn(move |r| route(registry.clone(), rooms.clone(), storage.clone(), r)));
if let Err(err) = server.await { if let Err(err) = server.await {
tracing::error!("Error serving connection: {:?}", err); tracing::error!("Error serving connection: {:?}", err);
} }
@ -73,90 +67,151 @@ async fn main_loop(
Ok(()) Ok(())
} }
#[tracing::instrument(skip_all)]
async fn route( async fn route(
registry: MetricsRegistry, registry: &MetricsRegistry,
rooms: RoomRegistry, core: &LavinaCore,
storage: Storage,
request: Request<hyper::body::Incoming>, request: Request<hyper::body::Incoming>,
) -> HttpResult<Response<Full<Bytes>>> { ) -> HttpResult<Response<Full<Bytes>>> {
propagade_span_from_headers(&request);
let res = match (request.method(), request.uri().path()) { let res = match (request.method(), request.uri().path()) {
(&Method::GET, "/metrics") => endpoint_metrics(registry), (&Method::GET, "/metrics") => endpoint_metrics(registry),
(&Method::GET, "/rooms") => endpoint_rooms(rooms).await, (&Method::GET, "/rooms") => endpoint_rooms(core).await,
(&Method::POST, paths::CREATE_PLAYER) => endpoint_create_player(request, storage).await.or5xx(), (&Method::POST, paths::CREATE_PLAYER) => endpoint_create_player(request, core).await.or5xx(),
(&Method::POST, paths::SET_PASSWORD) => endpoint_set_password(request, storage).await.or5xx(), (&Method::POST, paths::STOP_PLAYER) => endpoint_stop_player(request, core).await.or5xx(),
_ => not_found(), (&Method::POST, paths::SET_PASSWORD) => endpoint_set_password(request, core).await.or5xx(),
(&Method::POST, rooms::paths::SEND_MESSAGE) => endpoint_send_room_message(request, core).await.or5xx(),
(&Method::POST, rooms::paths::SET_TOPIC) => endpoint_set_room_topic(request, core).await.or5xx(),
_ => clustering::route(core, request).await.unwrap_or_else(endpoint_not_found),
}; };
Ok(res) Ok(res)
} }
fn endpoint_metrics(registry: MetricsRegistry) -> Response<Full<Bytes>> { fn endpoint_metrics(registry: &MetricsRegistry) -> Response<Full<Bytes>> {
let mf = registry.gather(); let mf = registry.gather();
let mut buffer = vec![]; let mut buffer = vec![];
TextEncoder.encode(&mf, &mut buffer).expect("write to vec cannot fail"); TextEncoder.encode(&mf, &mut buffer).expect("write to vec cannot fail");
Response::new(Full::new(Bytes::from(buffer))) Response::new(Full::new(Bytes::from(buffer)))
} }
async fn endpoint_rooms(rooms: RoomRegistry) -> Response<Full<Bytes>> { #[tracing::instrument(skip_all)]
async fn endpoint_rooms(core: &LavinaCore) -> Response<Full<Bytes>> {
// TODO introduce management API types independent from core-domain types // TODO introduce management API types independent from core-domain types
// TODO remove `Serialize` implementations from all core-domain types // TODO remove `Serialize` implementations from all core-domain types
let room_list = rooms.get_all_rooms().await.to_body(); let room_list = core.get_all_rooms().await.to_body();
Response::new(room_list) Response::new(room_list)
} }
#[tracing::instrument(skip_all)]
async fn endpoint_create_player( async fn endpoint_create_player(
request: Request<hyper::body::Incoming>, request: Request<hyper::body::Incoming>,
mut storage: Storage, core: &LavinaCore,
) -> Result<Response<Full<Bytes>>> { ) -> Result<Response<Full<Bytes>>> {
let str = request.collect().await?.to_bytes(); let str = request.collect().await?.to_bytes();
let Ok(res) = serde_json::from_slice::<CreatePlayerRequest>(&str[..]) else { let Ok(res) = serde_json::from_slice::<CreatePlayerRequest>(&str[..]) else {
let payload = ErrorResponse { return Ok(malformed_request());
code: errors::MALFORMED_REQUEST,
message: "The request payload contains incorrect JSON value",
}
.to_body();
let mut response = Response::new(payload);
*response.status_mut() = StatusCode::BAD_REQUEST;
return Ok(response);
}; };
storage.create_user(&res.name).await?; core.create_player(&PlayerId::from(res.name)?).await?;
log::info!("Player {} created", res.name); log::info!("Player {} created", res.name);
let mut response = Response::new(Full::<Bytes>::default()); let mut response = Response::new(Full::<Bytes>::default());
*response.status_mut() = StatusCode::CREATED; *response.status_mut() = StatusCode::CREATED;
Ok(response) Ok(response)
} }
#[tracing::instrument(skip_all)]
async fn endpoint_stop_player(
request: Request<hyper::body::Incoming>,
core: &LavinaCore,
) -> Result<Response<Full<Bytes>>> {
let str = request.collect().await?.to_bytes();
let Ok(res) = serde_json::from_slice::<StopPlayerRequest>(&str[..]) else {
return Ok(malformed_request());
};
let Ok(player_id) = PlayerId::from(res.name) else {
return Ok(player_not_found());
};
let Some(()) = core.stop_player(&player_id).await? else {
return Ok(player_not_found());
};
Ok(empty_204_request())
}
#[tracing::instrument(skip_all)]
async fn endpoint_set_password( async fn endpoint_set_password(
request: Request<hyper::body::Incoming>, request: Request<hyper::body::Incoming>,
mut storage: Storage, core: &LavinaCore,
) -> Result<Response<Full<Bytes>>> { ) -> Result<Response<Full<Bytes>>> {
let str = request.collect().await?.to_bytes(); let str = request.collect().await?.to_bytes();
let Ok(res) = serde_json::from_slice::<ChangePasswordRequest>(&str[..]) else { let Ok(res) = serde_json::from_slice::<ChangePasswordRequest>(&str[..]) else {
let payload = ErrorResponse { return Ok(malformed_request());
code: errors::MALFORMED_REQUEST,
message: "The request payload contains incorrect JSON value",
}
.to_body();
let mut response = Response::new(payload);
*response.status_mut() = StatusCode::BAD_REQUEST;
return Ok(response);
}; };
let Some(_) = storage.set_password(&res.player_name, &res.password).await? else { let verdict = core.set_password(&res.player_name, &res.password).await?;
let payload = ErrorResponse { match verdict {
code: errors::PLAYER_NOT_FOUND, UpdatePasswordResult::PasswordUpdated => {}
message: "No such player exists", UpdatePasswordResult::UserNotFound => {
return Ok(player_not_found());
} }
.to_body(); }
let mut response = Response::new(payload); Ok(empty_204_request())
*response.status_mut() = StatusCode::UNPROCESSABLE_ENTITY;
return Ok(response);
};
log::info!("Password changed for player {}", res.player_name);
let mut response = Response::new(Full::<Bytes>::default());
*response.status_mut() = StatusCode::NO_CONTENT;
Ok(response)
} }
pub fn not_found() -> Response<Full<Bytes>> { #[tracing::instrument(skip_all)]
async fn endpoint_send_room_message(
request: Request<hyper::body::Incoming>,
core: &LavinaCore,
) -> Result<Response<Full<Bytes>>> {
let str = request.collect().await?.to_bytes();
let Ok(req) = serde_json::from_slice::<rooms::SendMessageReq>(&str[..]) else {
return Ok(malformed_request());
};
let Ok(room_id) = RoomId::try_from(req.room_id) else {
return Ok(room_not_found());
};
let Ok(player_id) = PlayerId::from(req.author_id) else {
return Ok(player_not_found());
};
let mut connection = match core.connect_to_player(&player_id).await? {
PlayerConnectionResult::Success(connection) => connection,
PlayerConnectionResult::PlayerNotFound => {
return Ok(player_not_found());
}
};
let res = connection.send_message(room_id, req.message.into()).await?;
match res {
SendMessageResult::NoSuchRoom => Ok(room_not_found()),
SendMessageResult::Success(_) => Ok(empty_204_request()),
}
}
#[tracing::instrument(skip_all)]
async fn endpoint_set_room_topic(
request: Request<hyper::body::Incoming>,
core: &LavinaCore,
) -> Result<Response<Full<Bytes>>> {
let str = request.collect().await?.to_bytes();
let Ok(req) = serde_json::from_slice::<rooms::SetTopicReq>(&str[..]) else {
return Ok(malformed_request());
};
let Ok(room_id) = RoomId::try_from(req.room_id) else {
return Ok(room_not_found());
};
let Ok(player_id) = PlayerId::from(req.author_id) else {
return Ok(player_not_found());
};
let mut connection = match core.connect_to_player(&player_id).await? {
PlayerConnectionResult::Success(connection) => connection,
PlayerConnectionResult::PlayerNotFound => {
return Ok(player_not_found());
}
};
let res = connection.change_topic(room_id, req.topic.into()).await?;
match res {
ChangeRoomTopicResult::Success => Ok(empty_204_request()),
ChangeRoomTopicResult::NoSuchRoom => Ok(room_not_found()),
}
}
fn endpoint_not_found() -> Response<Full<Bytes>> {
let payload = ErrorResponse { let payload = ErrorResponse {
code: errors::INVALID_PATH, code: errors::INVALID_PATH,
message: "The path does not exist", message: "The path does not exist",
@ -168,25 +223,64 @@ pub fn not_found() -> Response<Full<Bytes>> {
response response
} }
fn player_not_found() -> Response<Full<Bytes>> {
let payload = ErrorResponse {
code: errors::PLAYER_NOT_FOUND,
message: "No such player exists",
}
.to_body();
let mut response = Response::new(payload);
*response.status_mut() = StatusCode::UNPROCESSABLE_ENTITY;
response
}
fn room_not_found() -> Response<Full<Bytes>> {
let payload = ErrorResponse {
code: rooms::errors::ROOM_NOT_FOUND,
message: "No such room exists",
}
.to_body();
let mut response = Response::new(payload);
*response.status_mut() = StatusCode::UNPROCESSABLE_ENTITY;
response
}
fn malformed_request() -> Response<Full<Bytes>> {
let payload = ErrorResponse {
code: errors::MALFORMED_REQUEST,
message: "The request payload contains incorrect JSON value",
}
.to_body();
let mut response = Response::new(payload);
*response.status_mut() = StatusCode::BAD_REQUEST;
return response;
}
fn empty_204_request() -> Response<Full<Bytes>> {
let mut response = Response::new(Full::<Bytes>::default());
*response.status_mut() = StatusCode::NO_CONTENT;
response
}
trait Or5xx { trait Or5xx {
fn or5xx(self) -> Response<Full<Bytes>>; fn or5xx(self) -> Response<Full<Bytes>>;
} }
impl Or5xx for Result<Response<Full<Bytes>>> { impl Or5xx for Result<Response<Full<Bytes>>> {
fn or5xx(self) -> Response<Full<Bytes>> { fn or5xx(self) -> Response<Full<Bytes>> {
match self { self.unwrap_or_else(|e| {
Ok(e) => e, let mut response = Response::new(Full::new(e.to_string().into()));
Err(e) => { *response.status_mut() = StatusCode::INTERNAL_SERVER_ERROR;
let mut response = Response::new(Full::new(e.to_string().into())); response
*response.status_mut() = StatusCode::INTERNAL_SERVER_ERROR; })
response
}
}
} }
} }
trait ToBody { trait ToBody {
fn to_body(&self) -> Full<Bytes>; fn to_body(&self) -> Full<Bytes>;
} }
impl<T> ToBody for T impl<T> ToBody for T
where where
T: Serialize, T: Serialize,
@ -197,3 +291,24 @@ where
Full::new(Bytes::from(buffer)) Full::new(Bytes::from(buffer))
} }
} }
fn propagade_span_from_headers<T>(req: &Request<T>) {
use opentelemetry::propagation::Extractor;
use tracing::Span;
use tracing_opentelemetry::OpenTelemetrySpanExt;
struct HttpReqExtractor<'a, T> {
req: &'a Request<T>,
}
impl<'a, T> Extractor for HttpReqExtractor<'a, T> {
fn get(&self, key: &str) -> Option<&str> {
self.req.headers().get(key).and_then(|v| v.to_str().ok())
}
fn keys(&self) -> Vec<&str> {
self.req.headers().keys().map(|k| k.as_str()).collect()
}
}
let ctx = opentelemetry::global::get_text_map_propagator(|pp| pp.extract(&HttpReqExtractor { req }));
Span::current().set_parent(ctx);
}

72
src/http/clustering.rs Normal file
View File

@ -0,0 +1,72 @@
use http_body_util::{BodyExt, Full};
use hyper::body::Bytes;
use hyper::{Method, Request, Response};
use super::Or5xx;
use crate::http::{empty_204_request, malformed_request, player_not_found, room_not_found};
use lavina_core::clustering::room::{paths, JoinRoomReq, SendMessageReq};
use lavina_core::player::PlayerId;
use lavina_core::room::RoomId;
use lavina_core::LavinaCore;
// TODO move this into core
pub async fn route(core: &LavinaCore, request: Request<hyper::body::Incoming>) -> Option<Response<Full<Bytes>>> {
match (request.method(), request.uri().path()) {
(&Method::POST, paths::JOIN) => Some(endpoint_cluster_join_room(request, core).await.or5xx()),
(&Method::POST, paths::ADD_MESSAGE) => Some(endpoint_cluster_add_message(request, core).await.or5xx()),
_ => None,
}
}
#[tracing::instrument(skip_all, name = "endpoint_cluster_join_room")]
async fn endpoint_cluster_join_room(
request: Request<hyper::body::Incoming>,
core: &LavinaCore,
) -> lavina_core::prelude::Result<Response<Full<Bytes>>> {
let str = request.collect().await?.to_bytes();
let Ok(req) = serde_json::from_slice::<JoinRoomReq>(&str[..]) else {
return Ok(malformed_request());
};
tracing::info!("Incoming request: {:?}", &req);
let Ok(room_id) = RoomId::try_from(req.room_id) else {
dbg!(&req.room_id);
return Ok(room_not_found());
};
let Ok(player_id) = PlayerId::from(req.player_id) else {
dbg!(&req.player_id);
return Ok(player_not_found());
};
core.cluster_join_room(room_id, &player_id).await?;
Ok(empty_204_request())
}
#[tracing::instrument(skip_all, name = "endpoint_cluster_add_message")]
async fn endpoint_cluster_add_message(
request: Request<hyper::body::Incoming>,
core: &LavinaCore,
) -> lavina_core::prelude::Result<Response<Full<Bytes>>> {
let str = request.collect().await?.to_bytes();
let Ok(req) = serde_json::from_slice::<SendMessageReq>(&str[..]) else {
return Ok(malformed_request());
};
tracing::info!("Incoming request: {:?}", &req);
let Ok(created_at) = chrono::DateTime::parse_from_rfc3339(req.created_at) else {
dbg!(&req.created_at);
return Ok(malformed_request());
};
let Ok(room_id) = RoomId::try_from(req.room_id) else {
dbg!(&req.room_id);
return Ok(room_not_found());
};
let Ok(player_id) = PlayerId::from(req.player_id) else {
dbg!(&req.player_id);
return Ok(player_not_found());
};
let res = core.cluster_send_room_message(room_id, &player_id, req.message.into(), created_at.to_utc()).await?;
if let Some(_) = res {
Ok(empty_204_request())
} else {
Ok(room_not_found())
}
}

View File

@ -6,13 +6,23 @@ use std::path::Path;
use clap::Parser; use clap::Parser;
use figment::providers::Format; use figment::providers::Format;
use figment::{providers::Toml, Figment}; use figment::{providers::Toml, Figment};
use opentelemetry::global::set_text_map_propagator;
use opentelemetry::KeyValue;
use opentelemetry_otlp::WithExportConfig;
use opentelemetry_sdk::propagation::TraceContextPropagator;
use opentelemetry_sdk::trace::{BatchConfig, RandomIdGenerator, Sampler};
use opentelemetry_sdk::{runtime, Resource};
use opentelemetry_semantic_conventions::resource::SERVICE_NAME;
use opentelemetry_semantic_conventions::SCHEMA_URL;
use prometheus::Registry as MetricsRegistry; use prometheus::Registry as MetricsRegistry;
use serde::Deserialize; use serde::Deserialize;
use tracing_opentelemetry::OpenTelemetryLayer;
use tracing_subscriber::fmt::Subscriber;
use tracing_subscriber::prelude::*;
use lavina_core::player::PlayerRegistry;
use lavina_core::prelude::*; use lavina_core::prelude::*;
use lavina_core::repo::Storage; use lavina_core::repo::Storage;
use lavina_core::room::RoomRegistry; use lavina_core::LavinaCore;
#[derive(Deserialize, Debug)] #[derive(Deserialize, Debug)]
struct ServerConfig { struct ServerConfig {
@ -20,6 +30,14 @@ struct ServerConfig {
irc: projection_irc::ServerConfig, irc: projection_irc::ServerConfig,
xmpp: projection_xmpp::ServerConfig, xmpp: projection_xmpp::ServerConfig,
storage: lavina_core::repo::StorageConfig, storage: lavina_core::repo::StorageConfig,
cluster: lavina_core::clustering::ClusterConfig,
tracing: Option<TracingConfig>,
}
#[derive(Deserialize, Debug)]
struct TracingConfig {
endpoint: String,
service_name: String,
} }
#[derive(Parser)] #[derive(Parser)]
@ -37,9 +55,9 @@ fn load_config() -> Result<ServerConfig> {
#[tokio::main] #[tokio::main]
async fn main() -> Result<()> { async fn main() -> Result<()> {
set_up_logging()?;
let sleep = ctrl_c()?; let sleep = ctrl_c()?;
let config = load_config()?; let config = load_config()?;
set_up_logging(&config.tracing)?;
tracing::info!("Booting up"); tracing::info!("Booting up");
tracing::info!("Loaded config: {config:?}"); tracing::info!("Loaded config: {config:?}");
@ -48,21 +66,15 @@ async fn main() -> Result<()> {
irc: irc_config, irc: irc_config,
xmpp: xmpp_config, xmpp: xmpp_config,
storage: storage_config, storage: storage_config,
cluster: cluster_config,
tracing: _,
} = config; } = config;
let mut metrics = MetricsRegistry::new(); let mut metrics = MetricsRegistry::new();
let storage = Storage::open(storage_config).await?; let storage = Storage::open(storage_config).await?;
let rooms = RoomRegistry::new(&mut metrics, storage.clone())?; let core = LavinaCore::new(&mut metrics, cluster_config, storage).await?;
let mut players = PlayerRegistry::empty(rooms.clone(), &mut metrics)?; let telemetry_terminator = http::launch(telemetry_config, metrics.clone(), core.clone()).await?;
let telemetry_terminator = http::launch(telemetry_config, metrics.clone(), rooms.clone(), storage.clone()).await?; let irc = projection_irc::launch(irc_config, core.clone(), metrics.clone()).await?;
let irc = projection_irc::launch( let xmpp = projection_xmpp::launch(xmpp_config, core.clone(), metrics.clone()).await?;
irc_config,
players.clone(),
rooms.clone(),
metrics.clone(),
storage.clone(),
)
.await?;
let xmpp = projection_xmpp::launch(xmpp_config, players.clone(), rooms.clone(), metrics.clone(), storage.clone()).await?;
tracing::info!("Started"); tracing::info!("Started");
sleep.await; sleep.await;
@ -71,10 +83,8 @@ async fn main() -> Result<()> {
xmpp.terminate().await?; xmpp.terminate().await?;
irc.terminate().await?; irc.terminate().await?;
telemetry_terminator.terminate().await?; telemetry_terminator.terminate().await?;
players.shutdown_all().await?; let storage = core.shutdown().await;
drop(players); storage.close().await;
drop(rooms);
storage.close().await?;
tracing::info!("Shutdown complete"); tracing::info!("Shutdown complete");
Ok(()) Ok(())
} }
@ -99,7 +109,45 @@ fn ctrl_c() -> Result<impl Future<Output = ()>> {
Ok(recv(chan)) Ok(recv(chan))
} }
fn set_up_logging() -> Result<()> { fn set_up_logging(tracing_config: &Option<TracingConfig>) -> Result<()> {
tracing_subscriber::fmt::init(); let subscriber = tracing_subscriber::registry().with(tracing_subscriber::fmt::layer());
let targets = {
use std::{env, str::FromStr};
use tracing_subscriber::filter::Targets;
match env::var("RUST_LOG") {
Ok(var) => Targets::from_str(&var)
.map_err(|e| {
eprintln!("Ignoring `RUST_LOG={:?}`: {}", var, e);
})
.unwrap_or_default(),
Err(env::VarError::NotPresent) => Targets::new().with_default(Subscriber::DEFAULT_MAX_LEVEL),
Err(e) => {
eprintln!("Ignoring `RUST_LOG`: {}", e);
Targets::new().with_default(Subscriber::DEFAULT_MAX_LEVEL)
}
}
};
if let Some(config) = tracing_config {
let trace_config = opentelemetry_sdk::trace::Config::default()
.with_sampler(Sampler::ParentBased(Box::new(Sampler::TraceIdRatioBased(1.0))))
.with_id_generator(RandomIdGenerator::default())
.with_resource(Resource::from_schema_url(
[KeyValue::new(SERVICE_NAME, config.service_name.to_string())],
SCHEMA_URL,
));
let trace_exporter = opentelemetry_otlp::new_exporter().tonic().with_endpoint(&config.endpoint);
let tracer = opentelemetry_otlp::new_pipeline()
.tracing()
.with_trace_config(trace_config)
.with_batch_config(BatchConfig::default())
.with_exporter(trace_exporter)
.install_batch(runtime::Tokio)?;
let subscriber = subscriber.with(OpenTelemetryLayer::new(tracer));
set_text_map_propagator(TraceContextPropagator::new());
targets.with_subscriber(subscriber).try_init()?;
} else {
targets.with_subscriber(subscriber).try_init()?;
}
Ok(()) Ok(())
} }