Compare commits

...

34 Commits

Author SHA1 Message Date
Mikhail a22017ee8f Sketch x tag internals 2024-05-15 13:05:24 +02:00
Nikita Vilunov b2a06ef984 xmpp: add x-user element to muc presence response 2024-05-14 03:21:06 +02:00
Nikita Vilunov 89918d9de1 xmpp: add item-not-found error condition to room disco#info iq 2024-05-13 16:42:52 +02:00
Nikita Vilunov 26720a2a08 core: separate the model from the logic implementation (#66)
This separates the core in two layers – the model objects and the `LavinaCore` service. Service is responsible for implementing the application logic and exposing it as core's public API to projections, while the model objects will be independent of each other and responsible only for managing and owning in-memory data.

The model objects include:
1. `Storage` – the open connection to the SQLite DB.
2. `PlayerRegistry` – creates, stores refs to, and stops player actors.
3. `RoomRegistry` – manages active rooms.
4. `DialogRegistry` – manages active dialogs.
5. `Broadcasting` – manages subscriptions of players to rooms on remote cluster nodes.
6. `LavinaClient` – manages HTTP connections to remote cluster nodes.
7. `ClusterMetadata` – read-only configuration of the cluster metadata, i.e. allocation of entities to nodes.

As a result:
1. Model objects will be fully independent of each other, e.g. it's no longer necessary to provide a `Storage` to all registries, or to provide `PlayerRegistry` and `DialogRegistry` to each other.
2. Model objects will no longer be `Arc`-wrapped; instead the whole `Services` object will be `Arc`ed and provided to projections.
3. The public API of `lavina-core` will be properly delimited by the APIs of `LavinaCore`, `PlayerConnection` and so on.
4. `LavinaCore` and `PlayerConnection` will also contain APIs of all features, unlike it was previously with `RoomRegistry` and `DialogRegistry`. This is unfortunate, but it could be improved in future.

Reviewed-on: lavina/lavina#66
2024-05-13 14:32:45 +00:00
Nikita Vilunov d1a72e7978 move repo methods into submodules and clean up warnings 2024-05-10 23:50:34 +02:00
Nikita Vilunov 6749103726 scalability: initial support for remote rooms (#61)
Reviewed-on: lavina/lavina#61
2024-05-10 20:44:24 +00:00
Mikhail 3b454ad7cd xmpp: unit-tests for resource bind it and muc presence
Reviewed-on: lavina/lavina#64
Co-authored-by: Mikhail <mikhail@liamets.dev>
Co-committed-by: Mikhail <mikhail@liamets.dev>
2024-05-10 13:35:34 +00:00
Mikhail 5512a74999 Check if user is a member before inserting a membership (#62)
It would typically fail on insertion due to uniqueness constraints: user id - room id.

Reviewed-on: lavina/lavina#62
Co-authored-by: Mikhail <mikhail@liamets.dev>
Co-committed-by: Mikhail <mikhail@liamets.dev>
2024-05-08 22:10:32 +00:00
Nikita Vilunov 7f2c6a1013 continue propagated traces in http request handlers 2024-05-05 19:24:58 +02:00
Nikita Vilunov bb0fe3bf0b use borrows in http endpoint handlers 2024-05-05 19:24:42 +02:00
Nikita Vilunov 8ac64ba8f5 get rid of storage usages in projections 2024-05-05 19:24:23 +02:00
homycdev abe9a26925 irc: implement WHOIS command (#43)
Reviewed-on: lavina/lavina#43
Co-authored-by: homycdev <abdulkhamid98@gmail.com>
Co-committed-by: homycdev <abdulkhamid98@gmail.com>
2024-05-05 17:21:40 +00:00
Mikhail adf1d8c14c xmpp: Implement Message Archive Management stub for XEP-0313 (#60)
https://xmpp.org/extensions/xep-0313.html
Reviewed-on: lavina/lavina#60
Co-authored-by: Mikhail <mikhail@liamets.dev>
Co-committed-by: Mikhail <mikhail@liamets.dev>
2024-05-05 15:12:58 +00:00
Nikita Vilunov 9a09ff717e management api endpoints for rooms 2024-05-01 17:30:31 +02:00
Nikita Vilunov a87f7c9d73 xmpp: extract common fragments of integration tests into functions 2024-04-29 23:56:18 +02:00
Nikita Vilunov 25605322a0 player shutdown API (#58)
Reviewed-on: lavina/lavina#58
2024-04-29 17:24:43 +00:00
Nikita Vilunov 31f9da9b05 xmpp: fix incorrect auth test 2024-04-29 19:13:32 +02:00
Nikita Vilunov c1dc2df150 xmpp: document xml parsing types 2024-04-28 17:29:31 +02:00
Nikita Vilunov c69513f38b xmpp: use mutable namespace and event in parser coroutines 2024-04-28 17:19:31 +02:00
Nikita Vilunov 8ec9ecfe2c xmpp: handle incorrect credentials by replying with an error 2024-04-28 17:11:29 +02:00
Nikita Vilunov a047d55ab5 xmpp: handle correctly unavailable self-presence and improve basic test scenario 2024-04-28 15:43:22 +02:00
Nikita Vilunov ea81ddadfc dialog message persistence 2024-04-27 12:58:27 +02:00
Nikita Vilunov 4b5ab02322 start next version 2024-04-26 13:43:43 +02:00
Nikita Vilunov 843d0e9c82 bump version 2024-04-26 13:31:47 +02:00
Nikita Vilunov 72f5010988 clean up http.rs a little 2024-04-26 12:28:13 +02:00
Nikita Vilunov 4ff09ea05f tracing otlp exporter and instrumentation annotations (#57)
Resolves #56

Reviewed-on: lavina/lavina#57
2024-04-26 10:16:23 +00:00
Nikita Vilunov ec49489ef1 validate that rooms and dialogs are owned exclusively on shutdown 2024-04-23 19:18:46 +02:00
Nikita Vilunov d305f5bf77 argon2-based password hashing (#55)
Reviewed-on: lavina/lavina#55
2024-04-23 16:31:00 +00:00
Nikita Vilunov 799da8366c basic dialog implementation with irc and xmpp support (#53)
Reviewed-on: lavina/lavina#53
2024-04-23 16:26:40 +00:00
Nikita Vilunov d805061d5b refactor auth logic into a common module (#54)
Reviewed-on: lavina/lavina#54
2024-04-23 10:10:10 +00:00
Nikita Vilunov 6c08d69f41 sasl: remove unused code 2024-04-23 00:41:54 +02:00
Nikita Vilunov 12d30ca5c2 irc: implement server-time capability for incoming messages (#52)
Spec: https://ircv3.net/specs/extensions/server-time
Reviewed-on: lavina/lavina#52
2024-04-21 21:00:44 +00:00
Nikita Vilunov ddb348bee9 refactor lavina core by grouping public services into a new LavinaCore struct.
this will be useful in future when additional services will be introduced and passed as dependencies
2024-04-21 19:45:50 +02:00
Nikita Vilunov 5a09b743c9 return AlreadyJoined when a player attempts to join a room they are already in 2024-04-20 17:09:44 +02:00
62 changed files with 4066 additions and 774 deletions

2
.gitignore vendored
View File

@ -1,4 +1,4 @@
/target
/db.sqlite
*.sqlite
.idea/
.DS_Store

542
Cargo.lock generated
View File

@ -114,12 +114,57 @@ version = "1.0.82"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "f538837af36e6f6a9be0faa67f9a314f8119e4e4b5867c6ab40ed60360142519"
[[package]]
name = "argon2"
version = "0.5.3"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "3c3610892ee6e0cbce8ae2700349fcf8f98adb0dbfbee85aec3c9179d29cc072"
dependencies = [
"base64ct",
"blake2",
"cpufeatures",
"password-hash",
]
[[package]]
name = "assert_matches"
version = "1.5.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "9b34d609dfbaf33d6889b2b7106d3ca345eacad44200913df5ba02bfd31d2ba9"
[[package]]
name = "async-stream"
version = "0.3.5"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "cd56dd203fef61ac097dd65721a419ddccb106b2d2b70ba60a6b529f03961a51"
dependencies = [
"async-stream-impl",
"futures-core",
"pin-project-lite",
]
[[package]]
name = "async-stream-impl"
version = "0.3.5"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "16e62a023e7c117e27523144c5d2459f4397fcc3cab0085af8e2224f643a0193"
dependencies = [
"proc-macro2",
"quote",
"syn 2.0.60",
]
[[package]]
name = "async-trait"
version = "0.1.80"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "c6fa2087f2753a7da8cc1c0dbfcf89579dd57458e36769de5ac750b4671737ca"
dependencies = [
"proc-macro2",
"quote",
"syn 2.0.60",
]
[[package]]
name = "atoi"
version = "2.0.0"
@ -144,6 +189,51 @@ version = "1.2.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "f1fdabc7756949593fe60f30ec81974b613357de856987752631dea1e3394c80"
[[package]]
name = "axum"
version = "0.6.20"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "3b829e4e32b91e643de6eafe82b1d90675f5874230191a4ffbc1b336dec4d6bf"
dependencies = [
"async-trait",
"axum-core",
"bitflags 1.3.2",
"bytes",
"futures-util",
"http 0.2.12",
"http-body 0.4.6",
"hyper 0.14.28",
"itoa",
"matchit 0.7.3",
"memchr",
"mime",
"percent-encoding",
"pin-project-lite",
"rustversion",
"serde",
"sync_wrapper",
"tower",
"tower-layer",
"tower-service",
]
[[package]]
name = "axum-core"
version = "0.3.4"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "759fa577a247914fd3f7f76d62972792636412fbfd634cd452f6a385a74d2d2c"
dependencies = [
"async-trait",
"bytes",
"futures-util",
"http 0.2.12",
"http-body 0.4.6",
"mime",
"rustversion",
"tower-layer",
"tower-service",
]
[[package]]
name = "backtrace"
version = "0.3.71"
@ -192,6 +282,15 @@ dependencies = [
"serde",
]
[[package]]
name = "blake2"
version = "0.10.6"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "46502ad458c9a52b69d4d4d32775c788b7a1b85e8bc9d482d92250fc0e3f8efe"
dependencies = [
"digest",
]
[[package]]
name = "block-buffer"
version = "0.10.4"
@ -339,6 +438,15 @@ version = "2.4.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "19d374276b40fb8bbdee95aef7c7fa6b5316ec764510eb64b8dd0e2ed0d7e7f5"
[[package]]
name = "crossbeam-channel"
version = "0.5.12"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "ab3db02a9c5b5121e1e42fbdb1aeb65f5e02624cc58c43f2884c6ccac0b82f95"
dependencies = [
"crossbeam-utils",
]
[[package]]
name = "crossbeam-queue"
version = "0.3.11"
@ -601,8 +709,10 @@ source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "94b22e06ecb0110981051723910cbf0b5f5e09a2062dd7663334ee79a9d1286c"
dependencies = [
"cfg-if",
"js-sys",
"libc",
"wasi",
"wasm-bindgen",
]
[[package]]
@ -611,6 +721,37 @@ version = "0.28.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "4271d37baee1b8c7e4b708028c57d816cf9d2434acb33a549475f78c181f6253"
[[package]]
name = "glob"
version = "0.3.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "d2fabcfbdc87f4758337ca535fb41a6d701b65693ce38287d856d1674551ec9b"
[[package]]
name = "h2"
version = "0.3.26"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "81fe527a889e1532da5c525686d96d4c2e74cdd345badf8dfef9f6b39dd5f5e8"
dependencies = [
"bytes",
"fnv",
"futures-core",
"futures-sink",
"futures-util",
"http 0.2.12",
"indexmap 2.2.6",
"slab",
"tokio",
"tokio-util",
"tracing",
]
[[package]]
name = "hashbrown"
version = "0.12.3"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "8a9ee70c43aaf417c914396645a0fa852624801b24ebb7ae78fe8272889ac888"
[[package]]
name = "hashbrown"
version = "0.14.3"
@ -627,7 +768,7 @@ version = "0.8.4"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "e8094feaf31ff591f651a2664fb9cfd92bba7a60ce3197265e9482ebe753c8f7"
dependencies = [
"hashbrown",
"hashbrown 0.14.3",
]
[[package]]
@ -684,6 +825,17 @@ dependencies = [
"windows-sys 0.52.0",
]
[[package]]
name = "http"
version = "0.2.12"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "601cbb57e577e2f5ef5be8e7b83f0f63994f25aa94d673e54a92d5c516d101f1"
dependencies = [
"bytes",
"fnv",
"itoa",
]
[[package]]
name = "http"
version = "1.1.0"
@ -695,6 +847,17 @@ dependencies = [
"itoa",
]
[[package]]
name = "http-body"
version = "0.4.6"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "7ceab25649e9960c0311ea418d17bee82c0dcec1bd053b5f9a66e265a693bed2"
dependencies = [
"bytes",
"http 0.2.12",
"pin-project-lite",
]
[[package]]
name = "http-body"
version = "1.0.0"
@ -702,7 +865,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "1cac85db508abc24a2e48553ba12a996e87244a0395ce011e62b37158745d643"
dependencies = [
"bytes",
"http",
"http 1.1.0",
]
[[package]]
@ -713,8 +876,8 @@ checksum = "0475f8b2ac86659c21b64320d5d653f9efe42acd2a4e560073ec61a155a34f1d"
dependencies = [
"bytes",
"futures-core",
"http",
"http-body",
"http 1.1.0",
"http-body 1.0.0",
"pin-project-lite",
]
@ -730,6 +893,30 @@ version = "1.0.3"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "df3b46402a9d5adb4c86a0cf463f42e19994e3ee891101b1841f30a545cb49a9"
[[package]]
name = "hyper"
version = "0.14.28"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "bf96e135eb83a2a8ddf766e426a841d8ddd7449d5f00d34ea02b41d2f19eef80"
dependencies = [
"bytes",
"futures-channel",
"futures-core",
"futures-util",
"h2",
"http 0.2.12",
"http-body 0.4.6",
"httparse",
"httpdate",
"itoa",
"pin-project-lite",
"socket2",
"tokio",
"tower-service",
"tracing",
"want",
]
[[package]]
name = "hyper"
version = "1.3.1"
@ -739,8 +926,8 @@ dependencies = [
"bytes",
"futures-channel",
"futures-util",
"http",
"http-body",
"http 1.1.0",
"http-body 1.0.0",
"httparse",
"httpdate",
"itoa",
@ -750,6 +937,18 @@ dependencies = [
"want",
]
[[package]]
name = "hyper-timeout"
version = "0.4.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "bbb958482e8c7be4bc3cf272a766a2b0bf1a6755e7a6ae777f017a31d11b13b1"
dependencies = [
"hyper 0.14.28",
"pin-project-lite",
"tokio",
"tokio-io-timeout",
]
[[package]]
name = "hyper-util"
version = "0.1.3"
@ -759,9 +958,9 @@ dependencies = [
"bytes",
"futures-channel",
"futures-util",
"http",
"http-body",
"hyper",
"http 1.1.0",
"http-body 1.0.0",
"hyper 1.3.1",
"pin-project-lite",
"socket2",
"tokio",
@ -803,6 +1002,16 @@ dependencies = [
"unicode-normalization",
]
[[package]]
name = "indexmap"
version = "1.9.3"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "bd070e393353796e801d209ad339e89596eb4c8d430d18ede6a1cced8fafbd99"
dependencies = [
"autocfg",
"hashbrown 0.12.3",
]
[[package]]
name = "indexmap"
version = "2.2.6"
@ -810,7 +1019,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "168fb715dda47215e360912c096649d23d58bf392ac62f73919e831745e40f26"
dependencies = [
"equivalent",
"hashbrown",
"hashbrown 0.14.3",
]
[[package]]
@ -851,20 +1060,25 @@ dependencies = [
[[package]]
name = "lavina"
version = "0.0.2-dev"
version = "0.0.3-dev"
dependencies = [
"anyhow",
"assert_matches",
"chrono",
"clap",
"derive_more",
"figment",
"futures-util",
"http-body-util",
"hyper",
"hyper 1.3.1",
"hyper-util",
"lavina-core",
"mgmt-api",
"nonempty",
"opentelemetry",
"opentelemetry-otlp",
"opentelemetry-semantic-conventions",
"opentelemetry_sdk",
"projection-irc",
"projection-xmpp",
"prometheus",
@ -874,16 +1088,24 @@ dependencies = [
"serde_json",
"tokio",
"tracing",
"tracing-opentelemetry",
"tracing-subscriber",
]
[[package]]
name = "lavina-core"
version = "0.0.2-dev"
version = "0.0.3-dev"
dependencies = [
"anyhow",
"argon2",
"chrono",
"mgmt-api",
"opentelemetry",
"prometheus",
"rand_core",
"reqwest",
"reqwest-middleware",
"reqwest-tracing",
"serde",
"sqlx",
"tokio",
@ -944,6 +1166,18 @@ version = "0.4.21"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "90ed8c1e510134f979dbc4f070f87d4313098b704861a105fe34231c70a3901c"
[[package]]
name = "matchit"
version = "0.7.3"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "0e7465ac9959cc2b1404e8e2367b43684a6d13790fe23056cc8c6c5a6b7bcb94"
[[package]]
name = "matchit"
version = "0.8.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "540f1c43aed89909c0cc0cc604e3bb2f7e7a341a3728a9e6cfe760e733cd11ed"
[[package]]
name = "md-5"
version = "0.10.6"
@ -962,7 +1196,7 @@ checksum = "6c8640c5d730cb13ebd907d8d04b52f55ac9a2eec55b440c8892f40d56c76c1d"
[[package]]
name = "mgmt-api"
version = "0.0.2-dev"
version = "0.0.3-dev"
dependencies = [
"serde",
]
@ -1097,6 +1331,89 @@ version = "1.19.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "3fdb12b2476b595f9358c5161aa467c2438859caa136dec86c26fdd2efe17b92"
[[package]]
name = "opentelemetry"
version = "0.22.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "900d57987be3f2aeb70d385fff9b27fb74c5723cc9a52d904d4f9c807a0667bf"
dependencies = [
"futures-core",
"futures-sink",
"js-sys",
"once_cell",
"pin-project-lite",
"thiserror",
"urlencoding",
]
[[package]]
name = "opentelemetry-otlp"
version = "0.15.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "1a016b8d9495c639af2145ac22387dcb88e44118e45320d9238fbf4e7889abcb"
dependencies = [
"async-trait",
"futures-core",
"http 0.2.12",
"opentelemetry",
"opentelemetry-proto",
"opentelemetry-semantic-conventions",
"opentelemetry_sdk",
"prost",
"thiserror",
"tokio",
"tonic",
]
[[package]]
name = "opentelemetry-proto"
version = "0.5.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "3a8fddc9b68f5b80dae9d6f510b88e02396f006ad48cac349411fbecc80caae4"
dependencies = [
"opentelemetry",
"opentelemetry_sdk",
"prost",
"tonic",
]
[[package]]
name = "opentelemetry-semantic-conventions"
version = "0.14.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "f9ab5bd6c42fb9349dcf28af2ba9a0667f697f9bdcca045d39f2cec5543e2910"
[[package]]
name = "opentelemetry_sdk"
version = "0.22.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "9e90c7113be649e31e9a0f8b5ee24ed7a16923b322c3c5ab6367469c049d6b7e"
dependencies = [
"async-trait",
"crossbeam-channel",
"futures-channel",
"futures-executor",
"futures-util",
"glob",
"once_cell",
"opentelemetry",
"ordered-float",
"percent-encoding",
"rand",
"thiserror",
"tokio",
"tokio-stream",
]
[[package]]
name = "ordered-float"
version = "4.2.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "a76df7075c7d4d01fdcb46c912dd17fba5b60c78ea480b475f2b6ab6f666584e"
dependencies = [
"num-traits",
]
[[package]]
name = "overload"
version = "0.1.1"
@ -1126,6 +1443,17 @@ dependencies = [
"windows-targets 0.48.5",
]
[[package]]
name = "password-hash"
version = "0.5.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "346f04948ba92c43e8469c1ee6736c7563d71012b17d40745260fe106aac2166"
dependencies = [
"base64ct",
"rand_core",
"subtle",
]
[[package]]
name = "paste"
version = "1.0.14"
@ -1259,10 +1587,11 @@ dependencies = [
[[package]]
name = "projection-irc"
version = "0.0.2-dev"
version = "0.0.3-dev"
dependencies = [
"anyhow",
"bitflags 2.5.0",
"chrono",
"futures-util",
"lavina-core",
"nonempty",
@ -1277,7 +1606,7 @@ dependencies = [
[[package]]
name = "projection-xmpp"
version = "0.0.2-dev"
version = "0.0.3-dev"
dependencies = [
"anyhow",
"assert_matches",
@ -1311,9 +1640,32 @@ dependencies = [
"thiserror",
]
[[package]]
name = "prost"
version = "0.12.4"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "d0f5d036824e4761737860779c906171497f6d55681139d8312388f8fe398922"
dependencies = [
"bytes",
"prost-derive",
]
[[package]]
name = "prost-derive"
version = "0.12.4"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "19de2de2a00075bf566bee3bd4db014b11587e84184d3f7a791bc17f1a8e9e48"
dependencies = [
"anyhow",
"itertools",
"proc-macro2",
"quote",
"syn 2.0.60",
]
[[package]]
name = "proto-irc"
version = "0.0.2-dev"
version = "0.0.3-dev"
dependencies = [
"anyhow",
"assert_matches",
@ -1325,7 +1677,7 @@ dependencies = [
[[package]]
name = "proto-xmpp"
version = "0.0.2-dev"
version = "0.0.3-dev"
dependencies = [
"anyhow",
"assert_matches",
@ -1425,18 +1777,18 @@ checksum = "adad44e29e4c806119491a7f06f03de4d1af22c3a680dd47f1e6e179439d1f56"
[[package]]
name = "reqwest"
version = "0.12.3"
version = "0.12.4"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "3e6cc1e89e689536eb5aeede61520e874df5a4707df811cd5da4aa5fbb2aae19"
checksum = "566cafdd92868e0939d3fb961bd0dc25fcfaaed179291093b3d43e6b3150ea10"
dependencies = [
"base64 0.22.0",
"bytes",
"futures-core",
"futures-util",
"http",
"http-body",
"http 1.1.0",
"http-body 1.0.0",
"http-body-util",
"hyper",
"hyper 1.3.1",
"hyper-util",
"ipnet",
"js-sys",
@ -1458,6 +1810,39 @@ dependencies = [
"winreg",
]
[[package]]
name = "reqwest-middleware"
version = "0.3.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "0209efb52486ad88136190094ee214759ef7507068b27992256ed6610eb71a01"
dependencies = [
"anyhow",
"async-trait",
"http 1.1.0",
"reqwest",
"serde",
"thiserror",
"tower-service",
]
[[package]]
name = "reqwest-tracing"
version = "0.5.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "b253954a1979e02eabccd7e9c3d61d8f86576108baa160775e7f160bb4e800a3"
dependencies = [
"anyhow",
"async-trait",
"getrandom",
"http 1.1.0",
"matchit 0.8.2",
"opentelemetry",
"reqwest",
"reqwest-middleware",
"tracing",
"tracing-opentelemetry",
]
[[package]]
name = "ring"
version = "0.17.8"
@ -1552,6 +1937,12 @@ dependencies = [
"untrusted",
]
[[package]]
name = "rustversion"
version = "1.0.15"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "80af6f9131f277a45a3fba6ce8e2258037bb0477a67e610d3c1fe046ab31de47"
[[package]]
name = "ryu"
version = "1.0.17"
@ -1560,7 +1951,7 @@ checksum = "e86697c916019a8588c99b5fac3cead74ec0b4b819707a682fd4d23fa0ce1ba1"
[[package]]
name = "sasl"
version = "0.0.2-dev"
version = "0.0.3-dev"
dependencies = [
"anyhow",
"base64 0.22.0",
@ -1774,6 +2165,7 @@ dependencies = [
"atoi",
"byteorder",
"bytes",
"chrono",
"crc",
"crossbeam-queue",
"either",
@ -1785,7 +2177,7 @@ dependencies = [
"futures-util",
"hashlink",
"hex",
"indexmap",
"indexmap 2.2.6",
"log",
"memchr",
"once_cell",
@ -1832,6 +2224,7 @@ dependencies = [
"sha2",
"sqlx-core",
"sqlx-mysql",
"sqlx-postgres",
"sqlx-sqlite",
"syn 1.0.109",
"tempfile",
@ -1849,6 +2242,7 @@ dependencies = [
"bitflags 2.5.0",
"byteorder",
"bytes",
"chrono",
"crc",
"digest",
"dotenvy",
@ -1890,6 +2284,7 @@ dependencies = [
"base64 0.21.7",
"bitflags 2.5.0",
"byteorder",
"chrono",
"crc",
"dotenvy",
"etcetera",
@ -1925,6 +2320,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "b244ef0a8414da0bed4bb1910426e890b19e5e9bccc27ada6b797d05c55ae0aa"
dependencies = [
"atoi",
"chrono",
"flume",
"futures-channel",
"futures-core",
@ -2068,6 +2464,16 @@ dependencies = [
"windows-sys 0.48.0",
]
[[package]]
name = "tokio-io-timeout"
version = "1.2.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "30b74022ada614a1b4834de765f9bb43877f910cc8ce4be40e89042c9223a8bf"
dependencies = [
"pin-project-lite",
"tokio",
]
[[package]]
name = "tokio-macros"
version = "2.2.0"
@ -2089,6 +2495,31 @@ dependencies = [
"tokio",
]
[[package]]
name = "tokio-stream"
version = "0.1.15"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "267ac89e0bec6e691e5813911606935d77c476ff49024f98abcea3e7b15e37af"
dependencies = [
"futures-core",
"pin-project-lite",
"tokio",
]
[[package]]
name = "tokio-util"
version = "0.7.10"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "5419f34732d9eb6ee4c3578b7989078579b7f039cbbb9ca2c4da015749371e15"
dependencies = [
"bytes",
"futures-core",
"futures-sink",
"pin-project-lite",
"tokio",
"tracing",
]
[[package]]
name = "toml"
version = "0.8.12"
@ -2116,13 +2547,40 @@ version = "0.22.11"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "fb686a972ccef8537b39eead3968b0e8616cb5040dbb9bba93007c8e07c9215f"
dependencies = [
"indexmap",
"indexmap 2.2.6",
"serde",
"serde_spanned",
"toml_datetime",
"winnow",
]
[[package]]
name = "tonic"
version = "0.11.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "76c4eb7a4e9ef9d4763600161f12f5070b92a578e1b634db88a6887844c91a13"
dependencies = [
"async-stream",
"async-trait",
"axum",
"base64 0.21.7",
"bytes",
"h2",
"http 0.2.12",
"http-body 0.4.6",
"hyper 0.14.28",
"hyper-timeout",
"percent-encoding",
"pin-project",
"prost",
"tokio",
"tokio-stream",
"tower",
"tower-layer",
"tower-service",
"tracing",
]
[[package]]
name = "tower"
version = "0.4.13"
@ -2131,9 +2589,13 @@ checksum = "b8fa9be0de6cf49e536ce1851f987bd21a43b771b09473c3549a6c853db37c1c"
dependencies = [
"futures-core",
"futures-util",
"indexmap 1.9.3",
"pin-project",
"pin-project-lite",
"rand",
"slab",
"tokio",
"tokio-util",
"tower-layer",
"tower-service",
"tracing",
@ -2195,6 +2657,24 @@ dependencies = [
"tracing-core",
]
[[package]]
name = "tracing-opentelemetry"
version = "0.23.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "a9be14ba1bbe4ab79e9229f7f89fab8d120b865859f10527f31c033e599d2284"
dependencies = [
"js-sys",
"once_cell",
"opentelemetry",
"opentelemetry_sdk",
"smallvec",
"tracing",
"tracing-core",
"tracing-log",
"tracing-subscriber",
"web-time",
]
[[package]]
name = "tracing-subscriber"
version = "0.3.18"
@ -2416,6 +2896,16 @@ dependencies = [
"wasm-bindgen",
]
[[package]]
name = "web-time"
version = "1.1.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "5a6580f308b1fad9207618087a65c04e7a10bc77e02c8e84e9b00dd4b12fa0bb"
dependencies = [
"js-sys",
"wasm-bindgen",
]
[[package]]
name = "whoami"
version = "1.5.1"

View File

@ -10,7 +10,7 @@ members = [
]
[workspace.package]
version = "0.0.2-dev"
version = "0.0.3-dev"
[workspace.dependencies]
nom = "7.1.3"
@ -31,6 +31,8 @@ base64 = "0.22.0"
lavina-core = { path = "crates/lavina-core" }
tracing-subscriber = "0.3.16"
sasl = { path = "crates/sasl" }
chrono = "0.4.37"
reqwest = { version = "0.12.0", default-features = false, features = ["json"] }
[package]
name = "lavina"
@ -58,8 +60,14 @@ projection-irc = { path = "crates/projection-irc" }
projection-xmpp = { path = "crates/projection-xmpp" }
mgmt-api = { path = "crates/mgmt-api" }
clap.workspace = true
opentelemetry = "0.22.0"
opentelemetry-semantic-conventions = "0.14.0"
opentelemetry_sdk = { version = "0.22.1", features = ["rt-tokio"] }
opentelemetry-otlp = "0.15.0"
tracing-opentelemetry = "0.23.0"
chrono.workspace = true
[dev-dependencies]
assert_matches.workspace = true
regex = "1.7.1"
reqwest = { version = "0.12.0", default-features = false }
reqwest.workspace = true

30
config.0.toml Normal file
View File

@ -0,0 +1,30 @@
[telemetry]
listen_on = "127.0.0.1:8080"
[irc]
listen_on = "127.0.0.1:6667"
server_name = "irc.localhost"
[xmpp]
listen_on = "127.0.0.1:5222"
cert = "./certs/xmpp.pem"
key = "./certs/xmpp.key"
hostname = "localhost"
[storage]
db_path = "db.0.sqlite"
[cluster]
addresses = [
"127.0.0.1:8080",
"127.0.0.1:8081",
]
[cluster.metadata]
node_id = 0
main_owner = 0
rooms = { aaaaa = 1, test = 0 }
[tracing]
endpoint = "http://localhost:4317"
service_name = "lavina-0"

30
config.1.toml Normal file
View File

@ -0,0 +1,30 @@
[telemetry]
listen_on = "127.0.0.1:8081"
[irc]
listen_on = "127.0.0.1:6668"
server_name = "irc.localhost"
[xmpp]
listen_on = "127.0.0.1:5223"
cert = "./certs/xmpp.pem"
key = "./certs/xmpp.key"
hostname = "localhost"
[storage]
db_path = "db.1.sqlite"
[cluster]
addresses = [
"127.0.0.1:8080",
"127.0.0.1:8081",
]
[cluster.metadata]
node_id = 1
main_owner = 0
rooms = { aaaaa = 1, test = 0 }
[tracing]
endpoint = "http://localhost:4317"
service_name = "lavina-1"

View File

@ -13,3 +13,11 @@ hostname = "localhost"
[storage]
db_path = "db.sqlite"
[cluster]
addresses = []
[cluster.metadata]
node_id = 0
main_owner = 0
rooms = {}

View File

@ -5,9 +5,16 @@ version.workspace = true
[dependencies]
anyhow.workspace = true
sqlx = { version = "0.7.4", features = ["sqlite", "migrate"] }
sqlx = { version = "0.7.4", features = ["sqlite", "migrate", "chrono"] }
serde.workspace = true
tokio.workspace = true
tracing.workspace = true
prometheus.workspace = true
chrono = "0.4.37"
chrono.workspace = true
argon2 = { version = "0.5.3" }
rand_core = { version = "0.6.4", features = ["getrandom"] }
reqwest.workspace = true
reqwest-middleware = { version = "0.3", features = ["json"] }
opentelemetry = "0.22.0"
mgmt-api = { path = "../mgmt-api" }
reqwest-tracing = { version = "0.5", features = ["opentelemetry_0_22"] }

View File

@ -0,0 +1,17 @@
create table dialogs(
id integer primary key autoincrement not null,
participant_1 integer not null,
participant_2 integer not null,
created_at timestamp not null,
message_count integer not null default 0,
unique (participant_1, participant_2)
);
create table dialog_messages(
dialog_id integer not null,
id integer not null, -- unique per dialog, sequential in one dialog
author_id integer not null,
content string not null,
created_at timestamp not null,
primary key (dialog_id, id)
);

View File

@ -0,0 +1,4 @@
create table challenges_argon2_password(
user_id integer primary key not null,
hash string not null
);

View File

@ -0,0 +1,58 @@
use anyhow::{anyhow, Result};
use argon2::password_hash::{PasswordHash, PasswordHasher, PasswordVerifier, SaltString};
use argon2::Argon2;
use rand_core::OsRng;
use crate::LavinaCore;
pub enum Verdict {
Authenticated,
UserNotFound,
InvalidPassword,
}
pub enum UpdatePasswordResult {
PasswordUpdated,
UserNotFound,
}
impl LavinaCore {
#[tracing::instrument(skip(self, provided_password), name = "Services::authenticate")]
pub async fn authenticate(&self, login: &str, provided_password: &str) -> Result<Verdict> {
let Some(stored_user) = self.services.storage.retrieve_user_by_name(login).await? else {
return Ok(Verdict::UserNotFound);
};
if let Some(argon2_hash) = stored_user.argon2_hash {
let argon2 = Argon2::default();
let password_hash =
PasswordHash::new(&argon2_hash).map_err(|e| anyhow!("Failed to parse password hash: {e:?}"))?;
let password_verifier = argon2.verify_password(provided_password.as_bytes(), &password_hash);
if password_verifier.is_ok() {
return Ok(Verdict::Authenticated);
}
}
if let Some(expected_password) = stored_user.password {
if expected_password == provided_password {
return Ok(Verdict::Authenticated);
}
}
Ok(Verdict::InvalidPassword)
}
#[tracing::instrument(skip(self, provided_password), name = "Services::set_password")]
pub async fn set_password(&self, login: &str, provided_password: &str) -> Result<UpdatePasswordResult> {
let Some(u) = self.services.storage.retrieve_user_by_name(login).await? else {
return Ok(UpdatePasswordResult::UserNotFound);
};
let salt = SaltString::generate(&mut OsRng);
let argon2 = Argon2::default();
let password_hash = argon2
.hash_password(provided_password.as_bytes(), &salt)
.map_err(|e| anyhow!("Failed to hash password: {e:?}"))?;
self.services.storage.set_argon2_challenge(u.id, password_hash.to_string().as_str()).await?;
tracing::info!("Password changed for player {login}");
Ok(UpdatePasswordResult::PasswordUpdated)
}
}

View File

@ -0,0 +1,56 @@
use std::collections::HashMap;
use std::net::SocketAddr;
use anyhow::{anyhow, Result};
use reqwest::Client;
use reqwest_middleware::{ClientBuilder, ClientWithMiddleware};
use reqwest_tracing::{DefaultSpanBackend, TracingMiddleware};
use serde::{Deserialize, Serialize};
pub mod broadcast;
pub mod room;
type Addresses = Vec<SocketAddr>;
#[derive(Deserialize, Debug, Clone)]
pub struct ClusterConfig {
pub metadata: ClusterMetadata,
pub addresses: Addresses,
}
#[derive(Deserialize, Debug, Clone)]
pub struct ClusterMetadata {
/// The node id of the current node.
pub node_id: u32,
/// Owns all rooms in the cluster except the ones specified in `rooms`.
pub main_owner: u32,
pub rooms: HashMap<String, u32>,
}
pub struct LavinaClient {
addresses: Addresses,
client: ClientWithMiddleware,
}
impl LavinaClient {
pub fn new(addresses: Addresses) -> Self {
let client = ClientBuilder::new(Client::new()).with(TracingMiddleware::<DefaultSpanBackend>::new()).build();
Self { addresses, client }
}
async fn send_request(&self, node_id: u32, path: &str, req: impl Serialize) -> Result<()> {
let Some(address) = self.addresses.get(node_id as usize) else {
return Err(anyhow!("Unknown node"));
};
match self.client.post(format!("http://{}{}", address, path)).json(&req).send().await {
Ok(res) => {
if res.status().is_server_error() || res.status().is_client_error() {
tracing::error!("Cluster request failed: {:?}", res);
return Err(anyhow!("Server error"));
}
Ok(())
}
Err(e) => Err(e.into()),
}
}
}

View File

@ -0,0 +1,53 @@
use std::collections::{HashMap, HashSet};
use chrono::{DateTime, Utc};
use tokio::sync::Mutex;
use crate::player::{PlayerId, Updates};
use crate::prelude::Str;
use crate::room::RoomId;
use crate::Services;
/// Receives updates from other nodes and broadcasts them to local player actors.
struct BroadcastingInner {
subscriptions: HashMap<RoomId, HashSet<PlayerId>>,
}
pub struct Broadcasting(Mutex<BroadcastingInner>);
impl Broadcasting {
pub fn new() -> Self {
let inner = BroadcastingInner {
subscriptions: HashMap::new(),
};
Self(Mutex::new(inner))
}
}
impl Services {
/// Broadcasts the given update to subscribed player actors on local node.
#[tracing::instrument(skip(self, message, created_at))]
pub async fn broadcast(&self, room_id: RoomId, author_id: PlayerId, message: Str, created_at: DateTime<Utc>) {
let inner = self.broadcasting.0.lock().await;
let Some(subscribers) = inner.subscriptions.get(&room_id) else {
return;
};
let update = Updates::NewMessage {
room_id: room_id.clone(),
author_id: author_id.clone(),
body: message.clone(),
created_at: created_at.clone(),
};
for i in subscribers {
if i == &author_id {
continue;
}
let Some(player) = self.players.get_player(i).await else {
continue;
};
player.update(update.clone()).await;
}
}
pub async fn subscribe(&self, subscriber: PlayerId, room_id: RoomId) {
self.broadcasting.0.lock().await.subscriptions.entry(room_id).or_insert_with(HashSet::new).insert(subscriber);
}
}

View File

@ -0,0 +1,88 @@
use anyhow::Result;
use serde::{Deserialize, Serialize};
use crate::clustering::LavinaClient;
use crate::player::PlayerId;
use crate::prelude::Str;
use crate::room::RoomId;
use crate::LavinaCore;
pub mod paths {
pub const JOIN: &'static str = "/cluster/rooms/join";
pub const LEAVE: &'static str = "/cluster/rooms/leave";
pub const ADD_MESSAGE: &'static str = "/cluster/rooms/add_message";
pub const SET_TOPIC: &'static str = "/cluster/rooms/set_topic";
}
#[derive(Serialize, Deserialize, Debug)]
pub struct JoinRoomReq<'a> {
pub room_id: &'a str,
pub player_id: &'a str,
}
#[derive(Serialize, Deserialize, Debug)]
pub struct LeaveRoomReq<'a> {
pub room_id: &'a str,
pub player_id: &'a str,
}
#[derive(Serialize, Deserialize, Debug)]
pub struct SendMessageReq<'a> {
pub room_id: &'a str,
pub player_id: &'a str,
pub message: &'a str,
pub created_at: &'a str,
}
#[derive(Serialize, Deserialize, Debug)]
pub struct SetRoomTopicReq<'a> {
pub room_id: &'a str,
pub player_id: &'a str,
pub topic: &'a str,
}
impl LavinaClient {
#[tracing::instrument(skip(self, req), name = "LavinaClient::join_room")]
pub async fn join_room(&self, node_id: u32, req: JoinRoomReq<'_>) -> Result<()> {
self.send_request(node_id, paths::JOIN, req).await
}
#[tracing::instrument(skip(self, req), name = "LavinaClient::leave_room")]
pub async fn leave_room(&self, node_id: u32, req: LeaveRoomReq<'_>) -> Result<()> {
self.send_request(node_id, paths::LEAVE, req).await
}
#[tracing::instrument(skip(self, req), name = "LavinaClient::send_room_message")]
pub async fn send_room_message(&self, node_id: u32, req: SendMessageReq<'_>) -> Result<()> {
self.send_request(node_id, paths::ADD_MESSAGE, req).await
}
#[tracing::instrument(skip(self, req), name = "LavinaClient::set_room_topic")]
pub async fn set_room_topic(&self, node_id: u32, req: SetRoomTopicReq<'_>) -> Result<()> {
self.send_request(node_id, paths::SET_TOPIC, req).await
}
}
impl LavinaCore {
pub async fn cluster_join_room(&self, room_id: RoomId, player_id: &PlayerId) -> Result<()> {
let room_handle = self.services.rooms.get_or_create_room(&self.services, room_id).await?;
let storage_id =
self.services.storage.create_or_retrieve_user_id_by_name(player_id.as_inner().as_ref()).await?;
room_handle.add_member(&self.services, &player_id, storage_id).await;
Ok(())
}
pub async fn cluster_send_room_message(
&self,
room_id: RoomId,
player_id: &PlayerId,
message: Str,
created_at: chrono::DateTime<chrono::Utc>,
) -> Result<Option<()>> {
let Some(room_handle) = self.services.rooms.get_room(&self.services, &room_id).await else {
return Ok(None);
};
room_handle.send_message(&self.services, &player_id, message, created_at).await;
Ok(Some(()))
}
}

View File

@ -0,0 +1,150 @@
//! Domain of dialogs conversations between two participants.
//!
//! Dialogs are different from rooms in that they are always between two participants.
//! There are no admins or other roles in dialogs, both participants have equal rights.
use std::collections::HashMap;
use chrono::{DateTime, Utc};
use tokio::sync::RwLock as AsyncRwLock;
use crate::player::{PlayerId, Updates};
use crate::prelude::*;
use crate::Services;
/// Id of a conversation between two players.
///
/// Dialogs are identified by the pair of participants' ids. The order of ids does not matter.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub struct DialogId(PlayerId, PlayerId);
impl DialogId {
pub fn new(a: PlayerId, b: PlayerId) -> DialogId {
if a.as_inner() < b.as_inner() {
DialogId(a, b)
} else {
DialogId(b, a)
}
}
pub fn as_inner(&self) -> (&PlayerId, &PlayerId) {
(&self.0, &self.1)
}
pub fn into_inner(self) -> (PlayerId, PlayerId) {
(self.0, self.1)
}
}
struct Dialog {
storage_id: u32,
player_storage_id_1: u32,
player_storage_id_2: u32,
message_count: u32,
}
struct DialogRegistryInner {
dialogs: HashMap<DialogId, AsyncRwLock<Dialog>>,
}
pub(crate) struct DialogRegistry(AsyncRwLock<DialogRegistryInner>);
impl Services {
#[tracing::instrument(skip(self, body, created_at))]
pub async fn send_dialog_message(
&self,
from: PlayerId,
to: PlayerId,
body: Str,
created_at: &DateTime<Utc>,
) -> Result<()> {
let guard = self.dialogs.0.read().await;
let id = DialogId::new(from.clone(), to.clone());
let dialog = guard.dialogs.get(&id);
if let Some(d) = dialog {
let mut d = d.write().await;
self.storage
.insert_dialog_message(d.storage_id, d.message_count, from.as_inner(), &body, created_at)
.await?;
d.message_count += 1;
} else {
drop(guard);
let mut guard2 = self.dialogs.0.write().await;
// double check in case concurrent access has loaded this dialog
if let Some(d) = guard2.dialogs.get(&id) {
let mut d = d.write().await;
self.storage
.insert_dialog_message(d.storage_id, d.message_count, from.as_inner(), &body, created_at)
.await?;
d.message_count += 1;
} else {
let (p1, p2) = id.as_inner();
tracing::info!("Dialog {id:?} not found locally, trying to load from storage");
let stored_dialog = match self.storage.retrieve_dialog(p1.as_inner(), p2.as_inner()).await? {
Some(t) => t,
None => {
tracing::info!("Dialog {id:?} does not exist, creating a new one in storage");
self.storage.initialize_dialog(p1.as_inner(), p2.as_inner(), created_at).await?
}
};
tracing::info!("Dialog {id:?} loaded");
self.storage
.insert_dialog_message(
stored_dialog.id,
stored_dialog.message_count,
from.as_inner(),
&body,
created_at,
)
.await?;
let dialog = Dialog {
storage_id: stored_dialog.id,
player_storage_id_1: stored_dialog.participant_1,
player_storage_id_2: stored_dialog.participant_2,
message_count: stored_dialog.message_count + 1,
};
guard2.dialogs.insert(id.clone(), AsyncRwLock::new(dialog));
}
drop(guard2);
}
// TODO send message to the other player and persist it
let Some(player) = self.players.get_player(&to).await else {
tracing::debug!("Player {to:?} not active, not sending message");
return Ok(());
};
let update = Updates::NewDialogMessage {
sender: from.clone(),
receiver: to.clone(),
body: body.clone(),
created_at: created_at.clone(),
};
player.update(update).await;
return Ok(());
}
}
impl DialogRegistry {
pub fn new() -> DialogRegistry {
DialogRegistry(AsyncRwLock::new(DialogRegistryInner {
dialogs: HashMap::new(),
}))
}
pub fn shutdown(self) {}
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn test_dialog_id_new() {
let a = PlayerId::from("a").unwrap();
let b = PlayerId::from("b").unwrap();
let id1 = DialogId::new(a.clone(), b.clone());
let id2 = DialogId::new(a.clone(), b.clone());
// Dialog ids are invariant with respect to the order of participants
assert_eq!(id1, id2);
assert_eq!(id1.as_inner(), (&a, &b));
assert_eq!(id2.as_inner(), (&a, &b));
}
}

View File

@ -1,4 +1,20 @@
//! Domain definitions and implementation of common chat logic.
use std::ops::Deref;
use std::sync::Arc;
use anyhow::Result;
use prometheus::Registry as MetricsRegistry;
use crate::clustering::broadcast::Broadcasting;
use crate::clustering::{ClusterConfig, ClusterMetadata, LavinaClient};
use crate::dialog::DialogRegistry;
use crate::player::{PlayerConnection, PlayerId, PlayerRegistry};
use crate::repo::Storage;
use crate::room::{RoomHandle, RoomId, RoomInfo, RoomRegistry};
pub mod auth;
pub mod clustering;
pub mod dialog;
pub mod player;
pub mod prelude;
pub mod repo;
@ -6,3 +22,91 @@ pub mod room;
pub mod terminator;
mod table;
#[derive(Clone)]
pub struct LavinaCore {
services: Arc<Services>,
}
impl Deref for LavinaCore {
type Target = Services;
fn deref(&self) -> &Self::Target {
&self.services
}
}
impl LavinaCore {
pub async fn connect_to_player(&self, player_id: &PlayerId) -> PlayerConnection {
self.services.players.connect_to_player(&self, player_id).await
}
pub async fn get_room(&self, room_id: &RoomId) -> Option<RoomHandle> {
self.services.rooms.get_room(&self.services, room_id).await
}
pub async fn create_player(&self, player_id: &PlayerId) -> Result<()> {
self.services.storage.create_user(player_id.as_inner()).await
}
pub async fn get_all_rooms(&self) -> Vec<RoomInfo> {
self.services.rooms.get_all_rooms().await
}
pub async fn stop_player(&self, player_id: &PlayerId) -> Result<Option<()>> {
self.services.players.stop_player(player_id).await
}
}
pub struct Services {
pub(crate) players: PlayerRegistry,
pub(crate) rooms: RoomRegistry,
pub(crate) dialogs: DialogRegistry,
pub(crate) broadcasting: Broadcasting,
pub(crate) client: LavinaClient,
pub(crate) storage: Storage,
pub(crate) cluster_metadata: ClusterMetadata,
}
impl LavinaCore {
pub async fn new(
metrics: &mut MetricsRegistry,
cluster_config: ClusterConfig,
storage: Storage,
) -> Result<LavinaCore> {
// TODO shutdown all services in reverse order on error
let broadcasting = Broadcasting::new();
let client = LavinaClient::new(cluster_config.addresses.clone());
let rooms = RoomRegistry::new(metrics)?;
let dialogs = DialogRegistry::new();
let players = PlayerRegistry::empty(metrics)?;
let services = Services {
players,
rooms,
dialogs,
broadcasting,
client,
storage,
cluster_metadata: cluster_config.metadata,
};
Ok(LavinaCore {
services: Arc::new(services),
})
}
pub async fn shutdown(self) -> Storage {
let _ = self.players.shutdown_all().await;
let services = match Arc::try_unwrap(self.services) {
Ok(e) => e,
Err(_) => {
panic!("failed to acquire services ownership on shutdown");
}
};
let _ = services.players.shutdown();
let _ = services.dialogs.shutdown();
let _ = services.rooms.shutdown();
services.storage
}
}

View File

@ -8,17 +8,19 @@
//! A player actor is a serial handler of commands from a single player. It is preferable to run all per-player validations in the player actor,
//! so that they don't overload the room actor.
use std::collections::{HashMap, HashSet};
use std::sync::Arc;
use chrono::{DateTime, Utc};
use prometheus::{IntGauge, Registry as MetricsRegistry};
use serde::Serialize;
use tokio::sync::mpsc::{channel, Receiver, Sender};
use tokio::sync::RwLock;
use tracing::{Instrument, Span};
use crate::clustering::room::*;
use crate::prelude::*;
use crate::repo::Storage;
use crate::room::{RoomHandle, RoomId, RoomInfo, RoomRegistry};
use crate::room::{RoomHandle, RoomId, RoomInfo};
use crate::table::{AnonTable, Key as AnonKey};
use crate::LavinaCore;
/// Opaque player identifier. Cannot contain spaces, must be shorter than 32.
#[derive(Debug, Clone, PartialEq, Eq, Hash, Serialize)]
@ -52,12 +54,13 @@ pub struct ConnectionId(pub AnonKey);
/// The connection is used to send commands to the player actor and to receive updates that might be sent to the client.
pub struct PlayerConnection {
pub connection_id: ConnectionId,
pub receiver: Receiver<Updates>,
pub receiver: Receiver<ConnectionMessage>,
player_handle: PlayerHandle,
}
impl PlayerConnection {
/// Handled in [Player::send_message].
pub async fn send_message(&mut self, room_id: RoomId, body: Str) -> Result<()> {
/// Handled in [Player::send_room_message].
#[tracing::instrument(skip(self, body), name = "PlayerConnection::send_message")]
pub async fn send_message(&mut self, room_id: RoomId, body: Str) -> Result<SendMessageResult> {
let (promise, deferred) = oneshot();
let cmd = ClientCommand::SendMessage { room_id, body, promise };
self.player_handle.send(ActorCommand::ClientCommand(cmd, self.connection_id.clone())).await;
@ -65,6 +68,7 @@ impl PlayerConnection {
}
/// Handled in [Player::join_room].
#[tracing::instrument(skip(self), name = "PlayerConnection::join_room")]
pub async fn join_room(&mut self, room_id: RoomId) -> Result<JoinResult> {
let (promise, deferred) = oneshot();
let cmd = ClientCommand::JoinRoom { room_id, promise };
@ -72,7 +76,8 @@ impl PlayerConnection {
Ok(deferred.await?)
}
/// Handled in [Player::change_topic].
/// Handled in [Player::change_room_topic].
#[tracing::instrument(skip(self, new_topic), name = "PlayerConnection::change_topic")]
pub async fn change_topic(&mut self, room_id: RoomId, new_topic: Str) -> Result<()> {
let (promise, deferred) = oneshot();
let cmd = ClientCommand::ChangeTopic {
@ -85,6 +90,7 @@ impl PlayerConnection {
}
/// Handled in [Player::leave_room].
#[tracing::instrument(skip(self), name = "PlayerConnection::leave_room")]
pub async fn leave_room(&mut self, room_id: RoomId) -> Result<()> {
let (promise, deferred) = oneshot();
let cmd = ClientCommand::LeaveRoom { room_id, promise };
@ -97,25 +103,48 @@ impl PlayerConnection {
}
/// Handled in [Player::get_rooms].
#[tracing::instrument(skip(self), name = "PlayerConnection::get_rooms")]
pub async fn get_rooms(&self) -> Result<Vec<RoomInfo>> {
let (promise, deferred) = oneshot();
let cmd = ClientCommand::GetRooms { promise };
self.player_handle.send(ActorCommand::ClientCommand(cmd, self.connection_id.clone())).await;
Ok(deferred.await?)
}
/// Handler in [Player::send_dialog_message].
#[tracing::instrument(skip(self, body), name = "PlayerConnection::send_dialog_message")]
pub async fn send_dialog_message(&self, recipient: PlayerId, body: Str) -> Result<()> {
let (promise, deferred) = oneshot();
let cmd = ClientCommand::SendDialogMessage {
recipient,
body,
promise,
};
self.player_handle.send(ActorCommand::ClientCommand(cmd, self.connection_id.clone())).await;
Ok(deferred.await?)
}
/// Handler in [Player::check_user_existence].
#[tracing::instrument(skip(self), name = "PlayerConnection::check_user_existence")]
pub async fn check_user_existence(&self, recipient: PlayerId) -> Result<GetInfoResult> {
let (promise, deferred) = oneshot();
let cmd = ClientCommand::GetInfo { recipient, promise };
self.player_handle.send(ActorCommand::ClientCommand(cmd, self.connection_id.clone())).await;
Ok(deferred.await?)
}
}
/// Handle to a player actor.
#[derive(Clone)]
pub struct PlayerHandle {
tx: Sender<ActorCommand>,
tx: Sender<(ActorCommand, Span)>,
}
impl PlayerHandle {
pub async fn subscribe(&self) -> PlayerConnection {
let (sender, receiver) = channel(32);
let (promise, deferred) = oneshot();
let cmd = ActorCommand::AddConnection { sender, promise };
let _ = self.tx.send(cmd).await;
self.send(cmd).await;
let connection_id = deferred.await.unwrap();
PlayerConnection {
connection_id,
@ -125,8 +154,9 @@ impl PlayerHandle {
}
async fn send(&self, command: ActorCommand) {
let span = tracing::span!(tracing::Level::INFO, "PlayerHandle::send");
// TODO either handle the error or doc why it is safe to ignore
let _ = self.tx.send(command).await;
let _ = self.tx.send((command, span)).await;
}
pub async fn update(&self, update: Updates) {
@ -138,7 +168,7 @@ impl PlayerHandle {
enum ActorCommand {
/// Establish a new connection.
AddConnection {
sender: Sender<Updates>,
sender: Sender<ConnectionMessage>,
promise: Promise<ConnectionId>,
},
/// Terminate an existing connection.
@ -163,7 +193,7 @@ pub enum ClientCommand {
SendMessage {
room_id: RoomId,
body: Str,
promise: Promise<()>,
promise: Promise<SendMessageResult>,
},
ChangeTopic {
room_id: RoomId,
@ -173,13 +203,33 @@ pub enum ClientCommand {
GetRooms {
promise: Promise<Vec<RoomInfo>>,
},
SendDialogMessage {
recipient: PlayerId,
body: Str,
promise: Promise<()>,
},
GetInfo {
recipient: PlayerId,
promise: Promise<GetInfoResult>,
},
}
pub enum GetInfoResult {
UserExists,
UserDoesntExist,
}
pub enum JoinResult {
Success(RoomInfo),
AlreadyJoined,
Banned,
}
pub enum SendMessageResult {
Success(DateTime<Utc>),
NoSuchRoom,
}
/// Player update event type which is sent to a player actor and from there to a connection handler.
#[derive(Clone, Debug)]
pub enum Updates {
@ -191,6 +241,7 @@ pub enum Updates {
room_id: RoomId,
author_id: PlayerId,
body: Str,
created_at: DateTime<Utc>,
},
RoomJoined {
room_id: RoomId,
@ -202,46 +253,78 @@ pub enum Updates {
},
/// The player was banned from the room and left it immediately.
BannedFrom(RoomId),
NewDialogMessage {
sender: PlayerId,
receiver: PlayerId,
body: Str,
created_at: DateTime<Utc>,
},
}
/// Handle to a player registry — a shared data structure containing information about players.
#[derive(Clone)]
pub struct PlayerRegistry(Arc<RwLock<PlayerRegistryInner>>);
pub(crate) struct PlayerRegistry(RwLock<PlayerRegistryInner>);
impl PlayerRegistry {
pub fn empty(
room_registry: RoomRegistry,
storage: Storage,
metrics: &mut MetricsRegistry,
) -> Result<PlayerRegistry> {
pub fn empty(metrics: &mut MetricsRegistry) -> Result<PlayerRegistry> {
let metric_active_players = IntGauge::new("chat_players_active", "Number of alive player actors")?;
metrics.register(Box::new(metric_active_players.clone()))?;
let inner = PlayerRegistryInner {
room_registry,
storage,
players: HashMap::new(),
metric_active_players,
};
Ok(PlayerRegistry(Arc::new(RwLock::new(inner))))
Ok(PlayerRegistry(RwLock::new(inner)))
}
pub async fn get_or_launch_player(&mut self, id: &PlayerId) -> PlayerHandle {
pub fn shutdown(self) {
let res = self.0.into_inner();
drop(res);
}
#[tracing::instrument(skip(self), name = "PlayerRegistry::get_player")]
pub async fn get_player(&self, id: &PlayerId) -> Option<PlayerHandle> {
let inner = self.0.read().await;
inner.players.get(id).map(|(handle, _)| handle.clone())
}
#[tracing::instrument(skip(self), name = "PlayerRegistry::stop_player")]
pub async fn stop_player(&self, id: &PlayerId) -> Result<Option<()>> {
let mut inner = self.0.write().await;
if let Some((handle, fiber)) = inner.players.remove(id) {
handle.send(ActorCommand::Stop).await;
drop(handle);
fiber.await?;
inner.metric_active_players.dec();
Ok(Some(()))
} else {
Ok(None)
}
}
#[tracing::instrument(skip(self, core), name = "PlayerRegistry::get_or_launch_player")]
pub async fn get_or_launch_player(&self, core: &LavinaCore, id: &PlayerId) -> PlayerHandle {
let inner = self.0.read().await;
if let Some((handle, _)) = inner.players.get(id) {
handle.clone()
} else {
drop(inner);
let mut inner = self.0.write().await;
if let Some((handle, _)) = inner.players.get(id) {
handle.clone()
} else {
let (handle, fiber) = Player::launch(id.clone(), inner.room_registry.clone(), inner.storage.clone()).await;
let (handle, fiber) = Player::launch(id.clone(), core.clone()).await;
inner.players.insert(id.clone(), (handle.clone(), fiber));
inner.metric_active_players.inc();
handle
}
}
}
pub async fn connect_to_player(&mut self, id: &PlayerId) -> PlayerConnection {
let player_handle = self.get_or_launch_player(id).await;
#[tracing::instrument(skip(self, core), name = "PlayerRegistry::connect_to_player")]
pub async fn connect_to_player(&self, core: &LavinaCore, id: &PlayerId) -> PlayerConnection {
let player_handle = self.get_or_launch_player(core, id).await;
player_handle.subscribe().await
}
pub async fn shutdown_all(&mut self) -> Result<()> {
pub async fn shutdown_all(&self) -> Result<()> {
let mut inner = self.0.write().await;
for (i, (k, j)) in inner.players.drain() {
k.send(ActorCommand::Stop).await;
@ -256,31 +339,33 @@ impl PlayerRegistry {
/// The player registry state representation.
struct PlayerRegistryInner {
room_registry: RoomRegistry,
storage: Storage,
/// Active player actors.
players: HashMap<PlayerId, (PlayerHandle, JoinHandle<Player>)>,
metric_active_players: IntGauge,
}
enum RoomRef {
Local(RoomHandle),
Remote { node_id: u32 },
}
/// Player actor inner state representation.
struct Player {
player_id: PlayerId,
storage_id: u32,
connections: AnonTable<Sender<Updates>>,
my_rooms: HashMap<RoomId, RoomHandle>,
connections: AnonTable<Sender<ConnectionMessage>>,
my_rooms: HashMap<RoomId, RoomRef>,
banned_from: HashSet<RoomId>,
rx: Receiver<ActorCommand>,
rx: Receiver<(ActorCommand, Span)>,
handle: PlayerHandle,
rooms: RoomRegistry,
storage: Storage,
services: LavinaCore,
}
impl Player {
async fn launch(player_id: PlayerId, rooms: RoomRegistry, storage: Storage) -> (PlayerHandle, JoinHandle<Player>) {
async fn launch(player_id: PlayerId, core: LavinaCore) -> (PlayerHandle, JoinHandle<Player>) {
let (tx, rx) = channel(32);
let handle = PlayerHandle { tx };
let handle_clone = handle.clone();
let storage_id = storage.retrieve_user_id_by_name(player_id.as_inner()).await.unwrap().unwrap();
let storage_id = core.services.storage.retrieve_user_id_by_name(player_id.as_inner()).await.unwrap().unwrap();
let player = Player {
player_id,
storage_id,
@ -292,24 +377,40 @@ impl Player {
banned_from: HashSet::new(),
rx,
handle,
rooms,
storage,
services: core,
};
let fiber = tokio::task::spawn(player.main_loop());
(handle_clone, fiber)
}
fn room_location(&self, room_id: &RoomId) -> Option<u32> {
let res = self.services.cluster_metadata.rooms.get(room_id.as_inner().as_ref()).copied();
let node = res.unwrap_or(self.services.cluster_metadata.main_owner);
if node == self.services.cluster_metadata.node_id {
None
} else {
Some(node)
}
}
async fn main_loop(mut self) -> Self {
let rooms = self.storage.get_rooms_of_a_user(self.storage_id).await.unwrap();
let rooms = self.services.storage.get_rooms_of_a_user(self.storage_id).await.unwrap();
for room_id in rooms {
let room = self.rooms.get_room(&room_id).await;
if let Some(remote_node) = self.room_location(&room_id) {
self.my_rooms.insert(room_id.clone(), RoomRef::Remote { node_id: remote_node });
self.services.subscribe(self.player_id.clone(), room_id).await;
} else {
let room = self.services.rooms.get_room(&self.services, &room_id).await;
if let Some(room) = room {
self.my_rooms.insert(room_id, room);
self.my_rooms.insert(room_id, RoomRef::Local(room));
} else {
tracing::error!("Room #{room_id:?} not found");
}
}
}
while let Some(cmd) = self.rx.recv().await {
let (cmd, span) = cmd;
let should_stop = async {
match cmd {
ActorCommand::AddConnection { sender, promise } => {
let connection_id = self.connections.insert(sender);
@ -317,13 +418,27 @@ impl Player {
log::warn!("Connection {connection_id:?} terminated before finalization");
self.terminate_connection(connection_id);
}
false
}
ActorCommand::TerminateConnection(connection_id) => {
self.terminate_connection(connection_id);
false
}
ActorCommand::Update(update) => self.handle_update(update).await,
ActorCommand::ClientCommand(cmd, connection_id) => self.handle_cmd(cmd, connection_id).await,
ActorCommand::Stop => break,
ActorCommand::Update(update) => {
self.handle_update(update).await;
false
}
ActorCommand::ClientCommand(cmd, connection_id) => {
self.handle_cmd(cmd, connection_id).await;
false
}
ActorCommand::Stop => true,
}
}
.instrument(span)
.await;
if should_stop {
break;
}
}
log::debug!("Shutting down player actor #{:?}", self.player_id);
@ -331,8 +446,9 @@ impl Player {
}
/// Handle an incoming update by changing the internal state and broadcasting it to all connections if necessary.
#[tracing::instrument(skip(self, update), name = "Player::handle_update")]
async fn handle_update(&mut self, update: Updates) {
log::info!(
log::debug!(
"Player received an update, broadcasting to {} connections",
self.connections.len()
);
@ -344,7 +460,7 @@ impl Player {
_ => {}
}
for (_, connection) in &self.connections {
let _ = connection.send(update.clone()).await;
let _ = connection.send(ConnectionMessage::Update(update.clone())).await;
}
}
@ -366,39 +482,71 @@ impl Player {
let _ = promise.send(());
}
ClientCommand::SendMessage { room_id, body, promise } => {
self.send_message(connection_id, room_id, body).await;
let _ = promise.send(());
let result = self.send_room_message(connection_id, room_id, body).await;
let _ = promise.send(result);
}
ClientCommand::ChangeTopic {
room_id,
new_topic,
promise,
} => {
self.change_topic(connection_id, room_id, new_topic).await;
self.change_room_topic(connection_id, room_id, new_topic).await;
let _ = promise.send(());
}
ClientCommand::GetRooms { promise } => {
let result = self.get_rooms().await;
let _ = promise.send(result);
}
ClientCommand::SendDialogMessage {
recipient,
body,
promise,
} => {
self.send_dialog_message(connection_id, recipient, body).await;
let _ = promise.send(());
}
ClientCommand::GetInfo { recipient, promise } => {
let result = self.check_user_existence(recipient).await;
let _ = promise.send(result);
}
}
}
#[tracing::instrument(skip(self), name = "Player::join_room")]
async fn join_room(&mut self, connection_id: ConnectionId, room_id: RoomId) -> JoinResult {
if self.banned_from.contains(&room_id) {
return JoinResult::Banned;
}
if self.my_rooms.contains_key(&room_id) {
return JoinResult::AlreadyJoined;
}
let room = match self.rooms.get_or_create_room(room_id.clone()).await {
if let Some(remote_node) = self.room_location(&room_id) {
let req = JoinRoomReq {
room_id: room_id.as_inner(),
player_id: self.player_id.as_inner(),
};
self.services.client.join_room(remote_node, req).await.unwrap();
let room_storage_id =
self.services.storage.create_or_retrieve_room_id_by_name(room_id.as_inner()).await.unwrap();
self.services.storage.add_room_member(room_storage_id, self.storage_id).await.unwrap();
self.my_rooms.insert(room_id.clone(), RoomRef::Remote { node_id: remote_node });
JoinResult::Success(RoomInfo {
id: room_id,
topic: "unknown".into(),
members: vec![],
})
} else {
let room = match self.services.rooms.get_or_create_room(&self.services, room_id.clone()).await {
Ok(room) => room,
Err(e) => {
log::error!("Failed to get or create room: {e}");
todo!();
}
};
room.add_member(&self.player_id, self.storage_id).await;
room.add_member(&self.services, &self.player_id, self.storage_id).await;
room.subscribe(&self.player_id, self.handle.clone()).await;
self.my_rooms.insert(room_id.clone(), room.clone());
self.my_rooms.insert(room_id.clone(), RoomRef::Local(room.clone()));
let room_info = room.get_room_info().await;
let update = Updates::RoomJoined {
room_id,
@ -407,12 +555,28 @@ impl Player {
self.broadcast_update(update, connection_id).await;
JoinResult::Success(room_info)
}
}
#[tracing::instrument(skip(self), name = "Player::leave_room")]
async fn leave_room(&mut self, connection_id: ConnectionId, room_id: RoomId) {
let room = self.my_rooms.remove(&room_id);
if let Some(room) = room {
match room {
RoomRef::Local(room) => {
room.unsubscribe(&self.player_id).await;
room.remove_member(&self.player_id, self.storage_id).await;
room.remove_member(&self.services, &self.player_id, self.storage_id).await;
}
RoomRef::Remote { node_id } => {
let req = LeaveRoomReq {
room_id: room_id.as_inner(),
player_id: self.player_id.as_inner(),
};
self.services.client.leave_room(node_id, req).await.unwrap();
let room_storage_id =
self.services.storage.create_or_retrieve_room_id_by_name(room_id.as_inner()).await.unwrap();
self.services.storage.remove_room_member(room_storage_id, self.storage_id).await.unwrap();
}
}
}
let update = Updates::RoomLeft {
room_id,
@ -421,48 +585,141 @@ impl Player {
self.broadcast_update(update, connection_id).await;
}
async fn send_message(&mut self, connection_id: ConnectionId, room_id: RoomId, body: Str) {
#[tracing::instrument(skip(self, body), name = "Player::send_room_message")]
async fn send_room_message(
&mut self,
connection_id: ConnectionId,
room_id: RoomId,
body: Str,
) -> SendMessageResult {
let Some(room) = self.my_rooms.get(&room_id) else {
tracing::info!("no room found");
return;
return SendMessageResult::NoSuchRoom;
};
room.send_message(&self.player_id, body.clone()).await;
let created_at = Utc::now();
match room {
RoomRef::Local(room) => {
room.send_message(&self.services, &self.player_id, body.clone(), created_at.clone()).await;
}
RoomRef::Remote { node_id } => {
let req = SendMessageReq {
room_id: room_id.as_inner(),
player_id: self.player_id.as_inner(),
message: &*body,
created_at: &*created_at.to_rfc3339(),
};
self.services.client.send_room_message(*node_id, req).await.unwrap();
self.services
.broadcast(
room_id.clone(),
self.player_id.clone(),
body.clone(),
created_at.clone(),
)
.await;
}
}
let update = Updates::NewMessage {
room_id,
author_id: self.player_id.clone(),
body,
created_at,
};
self.broadcast_update(update, connection_id).await;
SendMessageResult::Success(created_at)
}
async fn change_topic(&mut self, connection_id: ConnectionId, room_id: RoomId, new_topic: Str) {
#[tracing::instrument(skip(self, new_topic), name = "Player::change_room_topic")]
async fn change_room_topic(&mut self, connection_id: ConnectionId, room_id: RoomId, new_topic: Str) {
let Some(room) = self.my_rooms.get(&room_id) else {
tracing::info!("no room found");
return;
};
room.set_topic(&self.player_id, new_topic.clone()).await;
match room {
RoomRef::Local(room) => {
room.set_topic(&self.services, &self.player_id, new_topic.clone()).await;
}
RoomRef::Remote { node_id } => {
let req = SetRoomTopicReq {
room_id: room_id.as_inner(),
player_id: self.player_id.as_inner(),
topic: &*new_topic,
};
self.services.client.set_room_topic(*node_id, req).await.unwrap();
}
}
let update = Updates::RoomTopicChanged { room_id, new_topic };
self.broadcast_update(update, connection_id).await;
}
#[tracing::instrument(skip(self), name = "Player::get_rooms")]
async fn get_rooms(&self) -> Vec<RoomInfo> {
let mut response = vec![];
for (_, handle) in &self.my_rooms {
for (room_id, handle) in &self.my_rooms {
match handle {
RoomRef::Local(handle) => {
response.push(handle.get_room_info().await);
}
RoomRef::Remote { .. } => {
let room_info = RoomInfo {
id: room_id.clone(),
topic: "unknown".into(),
members: vec![],
};
response.push(room_info);
}
}
}
response
}
#[tracing::instrument(skip(self, body), name = "Player::send_dialog_message")]
async fn send_dialog_message(&self, connection_id: ConnectionId, recipient: PlayerId, body: Str) {
let created_at = Utc::now();
self.services
.send_dialog_message(self.player_id.clone(), recipient.clone(), body.clone(), &created_at)
.await
.unwrap();
let update = Updates::NewDialogMessage {
sender: self.player_id.clone(),
receiver: recipient.clone(),
body,
created_at,
};
self.broadcast_update(update, connection_id).await;
}
#[tracing::instrument(skip(self), name = "Player::check_user_existence")]
async fn check_user_existence(&self, recipient: PlayerId) -> GetInfoResult {
if self.services.storage.check_user_existence(recipient.as_inner().as_ref()).await.unwrap() {
GetInfoResult::UserExists
} else {
GetInfoResult::UserDoesntExist
}
}
/// Broadcasts an update to all connections except the one with the given id.
///
/// This is called after handling a client command.
/// Sending the update to the connection which sent the command is handled by the connection itself.
#[tracing::instrument(skip(self, update), name = "Player::broadcast_update")]
async fn broadcast_update(&self, update: Updates, except: ConnectionId) {
for (a, b) in &self.connections {
if ConnectionId(a) == except {
continue;
}
let _ = b.send(update.clone()).await;
let _ = b.send(ConnectionMessage::Update(update.clone())).await;
}
}
}
pub enum ConnectionMessage {
Update(Updates),
Stop(StopReason),
}
#[derive(Debug)]
pub enum StopReason {
ServerShutdown,
InternalError,
}

View File

@ -0,0 +1,20 @@
use anyhow::Result;
use crate::repo::Storage;
impl Storage {
#[tracing::instrument(skip(self), name = "Storage::set_argon2_challenge")]
pub async fn set_argon2_challenge(&self, user_id: u32, hash: &str) -> Result<()> {
let mut executor = self.conn.lock().await;
sqlx::query(
"insert into challenges_argon2_password(user_id, hash)
values (?, ?)
on conflict(user_id) do update set hash = excluded.hash;",
)
.bind(user_id)
.bind(hash)
.execute(&mut *executor)
.await?;
Ok(())
}
}

View File

@ -0,0 +1,91 @@
use anyhow::{anyhow, Result};
use chrono::{DateTime, Utc};
use sqlx::FromRow;
use crate::repo::Storage;
impl Storage {
#[tracing::instrument(skip(self), name = "Storage::retrieve_dialog")]
pub async fn retrieve_dialog(&self, participant_1: &str, participant_2: &str) -> Result<Option<StoredDialog>> {
let mut executor = self.conn.lock().await;
let res = sqlx::query_as(
"select r.id, r.participant_1, r.participant_2, r.message_count
from dialogs r join users u1 on r.participant_1 = u1.id join users u2 on r.participant_2 = u2.id
where u1.name = ? and u2.name = ?;",
)
.bind(participant_1)
.bind(participant_2)
.fetch_optional(&mut *executor)
.await?;
Ok(res)
}
#[tracing::instrument(skip(self, content, created_at), name = "Storage::insert_dialog_message")]
pub async fn insert_dialog_message(
&self,
dialog_id: u32,
id: u32,
author_id: &str,
content: &str,
created_at: &DateTime<Utc>,
) -> Result<()> {
let mut executor = self.conn.lock().await;
let res: Option<(u32,)> = sqlx::query_as("select id from users where name = ?;")
.bind(author_id)
.fetch_optional(&mut *executor)
.await?;
let Some((author_id,)) = res else {
return Err(anyhow!("No such user"));
};
sqlx::query(
"insert into dialog_messages(dialog_id, id, author_id, content, created_at)
values (?, ?, ?, ?, ?);
update dialogs set message_count = message_count + 1 where id = ?;",
)
.bind(dialog_id)
.bind(id)
.bind(author_id)
.bind(content)
.bind(created_at)
.bind(dialog_id)
.execute(&mut *executor)
.await?;
Ok(())
}
#[tracing::instrument(skip(self, created_at), name = "Storage::initialize_dialog")]
pub async fn initialize_dialog(
&self,
participant_1: &str,
participant_2: &str,
created_at: &DateTime<Utc>,
) -> Result<StoredDialog> {
let mut executor = self.conn.lock().await;
let res: StoredDialog = sqlx::query_as(
"insert into dialogs(participant_1, participant_2, created_at)
values (
(select id from users where name = ?),
(select id from users where name = ?),
?
)
returning id, participant_1, participant_2, message_count;",
)
.bind(participant_1)
.bind(participant_2)
.bind(&created_at)
.fetch_one(&mut *executor)
.await?;
Ok(res)
}
}
#[derive(FromRow)]
pub struct StoredDialog {
pub id: u32,
pub participant_1: u32,
pub participant_2: u32,
pub message_count: u32,
}

View File

@ -1,16 +1,16 @@
//! Storage and persistence logic.
use std::str::FromStr;
use std::sync::Arc;
use anyhow::anyhow;
use serde::Deserialize;
use sqlx::sqlite::SqliteConnectOptions;
use sqlx::{ConnectOptions, Connection, FromRow, Sqlite, SqliteConnection, Transaction};
use sqlx::{ConnectOptions, Connection, SqliteConnection};
use tokio::sync::Mutex;
use crate::prelude::*;
mod auth;
mod dialog;
mod room;
mod user;
@ -19,9 +19,8 @@ pub struct StorageConfig {
pub db_path: String,
}
#[derive(Clone)]
pub struct Storage {
conn: Arc<Mutex<SqliteConnection>>,
conn: Mutex<SqliteConnection>,
}
impl Storage {
pub async fn open(config: StorageConfig) -> Result<Storage> {
@ -33,145 +32,17 @@ impl Storage {
migrator.run(&mut conn).await?;
log::info!("Migrations passed");
let conn = Arc::new(Mutex::new(conn));
let conn = Mutex::new(conn);
Ok(Storage { conn })
}
pub async fn retrieve_user_by_name(&mut self, name: &str) -> Result<Option<StoredUser>> {
let mut executor = self.conn.lock().await;
let res = sqlx::query_as(
"select u.id, u.name, c.password
from users u left join challenges_plain_password c on u.id = c.user_id
where u.name = ?;",
)
.bind(name)
.fetch_optional(&mut *executor)
.await?;
Ok(res)
}
pub async fn retrieve_room_by_name(&self, name: &str) -> Result<Option<StoredRoom>> {
let mut executor = self.conn.lock().await;
let res = sqlx::query_as(
"select id, name, topic, message_count
from rooms
where name = ?;",
)
.bind(name)
.fetch_optional(&mut *executor)
.await?;
Ok(res)
}
pub async fn create_new_room(&mut self, name: &str, topic: &str) -> Result<u32> {
let mut executor = self.conn.lock().await;
let (id,): (u32,) = sqlx::query_as(
"insert into rooms(name, topic)
values (?, ?)
returning id;",
)
.bind(name)
.bind(topic)
.fetch_one(&mut *executor)
.await?;
Ok(id)
}
pub async fn insert_message(&mut self, room_id: u32, id: u32, content: &str, author_id: &str) -> Result<()> {
let mut executor = self.conn.lock().await;
let res: Option<(u32,)> = sqlx::query_as("select id from users where name = ?;")
.bind(author_id)
.fetch_optional(&mut *executor)
.await?;
let Some((author_id,)) = res else {
return Err(anyhow!("No such user"));
};
sqlx::query(
"insert into messages(room_id, id, content, author_id, created_at)
values (?, ?, ?, ?, ?);
update rooms set message_count = message_count + 1 where id = ?;",
)
.bind(room_id)
.bind(id)
.bind(content)
.bind(author_id)
.bind(chrono::Utc::now().to_string())
.bind(room_id)
.execute(&mut *executor)
.await?;
Ok(())
}
pub async fn close(self) -> Result<()> {
let res = match Arc::try_unwrap(self.conn) {
Ok(e) => e,
Err(_) => return Err(fail("failed to acquire DB ownership on shutdown")),
};
let res = res.into_inner();
res.close().await?;
Ok(())
}
pub async fn create_user(&mut self, name: &str) -> Result<()> {
let query = sqlx::query(
"insert into users(name)
values (?);",
)
.bind(name);
let mut executor = self.conn.lock().await;
query.execute(&mut *executor).await?;
Ok(())
}
pub async fn set_password<'a>(&'a mut self, name: &'a str, pwd: &'a str) -> Result<Option<()>> {
async fn inner(txn: &mut Transaction<'_, Sqlite>, name: &str, pwd: &str) -> Result<Option<()>> {
let id: Option<(u32,)> = sqlx::query_as("select * from users where name = ? limit 1;")
.bind(name)
.fetch_optional(&mut **txn)
.await?;
let Some((id,)) = id else {
return Ok(None);
};
sqlx::query("insert or replace into challenges_plain_password(user_id, password) values (?, ?);")
.bind(id)
.bind(pwd)
.execute(&mut **txn)
.await?;
Ok(Some(()))
}
let mut executor = self.conn.lock().await;
let mut tx = executor.begin().await?;
let res = inner(&mut tx, name, pwd).await;
match res {
Ok(e) => {
tx.commit().await?;
Ok(e)
}
pub async fn close(self) {
let res = self.conn.into_inner();
match res.close().await {
Ok(_) => {}
Err(e) => {
tx.rollback().await?;
Err(e)
tracing::error!("Failed to close the DB connection: {e:?}");
}
}
}
}
#[derive(FromRow)]
pub struct StoredUser {
pub id: u32,
pub name: String,
pub password: Option<String>,
}
#[derive(FromRow)]
pub struct StoredRoom {
pub id: u32,
pub name: String,
pub topic: String,
pub message_count: u32,
}

View File

@ -1,8 +1,105 @@
use anyhow::Result;
use anyhow::{anyhow, Result};
use chrono::{DateTime, Utc};
use sqlx::FromRow;
use crate::repo::Storage;
use crate::room::RoomId;
#[derive(FromRow)]
pub struct StoredRoom {
pub id: u32,
pub name: String,
pub topic: String,
pub message_count: u32,
}
impl Storage {
#[tracing::instrument(skip(self), name = "Storage::retrieve_room_by_name")]
pub async fn retrieve_room_by_name(&self, name: &str) -> Result<Option<StoredRoom>> {
let mut executor = self.conn.lock().await;
let res = sqlx::query_as(
"select id, name, topic, message_count
from rooms
where name = ?;",
)
.bind(name)
.fetch_optional(&mut *executor)
.await?;
Ok(res)
}
#[tracing::instrument(skip(self, topic), name = "Storage::create_new_room")]
pub async fn create_new_room(&self, name: &str, topic: &str) -> Result<u32> {
let mut executor = self.conn.lock().await;
let (id,): (u32,) = sqlx::query_as(
"insert into rooms(name, topic)
values (?, ?)
returning id;",
)
.bind(name)
.bind(topic)
.fetch_one(&mut *executor)
.await?;
Ok(id)
}
#[tracing::instrument(skip(self, content, created_at), name = "Storage::insert_room_message")]
pub async fn insert_room_message(
&self,
room_id: u32,
id: u32,
content: &str,
author_id: &str,
created_at: &DateTime<Utc>,
) -> Result<()> {
let mut executor = self.conn.lock().await;
let res: Option<(u32,)> = sqlx::query_as("select id from users where name = ?;")
.bind(author_id)
.fetch_optional(&mut *executor)
.await?;
let Some((author_id,)) = res else {
return Err(anyhow!("No such user"));
};
sqlx::query(
"insert into messages(room_id, id, content, author_id, created_at)
values (?, ?, ?, ?, ?);
update rooms set message_count = message_count + 1 where id = ?;",
)
.bind(room_id)
.bind(id)
.bind(content)
.bind(author_id)
.bind(created_at.to_string())
.bind(room_id)
.execute(&mut *executor)
.await?;
Ok(())
}
#[tracing::instrument(skip(self), name = "Storage::is_room_member")]
pub async fn is_room_member(&self, room_id: u32, player_id: u32) -> Result<bool> {
let mut executor = self.conn.lock().await;
let res: (u32,) = sqlx::query_as(
"
select
count(*)
from
memberships
where
user_id = ? and room_id = ?;
",
)
.bind(player_id)
.bind(room_id)
.fetch_one(&mut *executor)
.await?;
Ok(res.0 > 0)
}
#[tracing::instrument(skip(self), name = "Storage::add_room_member")]
pub async fn add_room_member(&self, room_id: u32, player_id: u32) -> Result<()> {
let mut executor = self.conn.lock().await;
sqlx::query(
@ -17,6 +114,7 @@ impl Storage {
Ok(())
}
#[tracing::instrument(skip(self), name = "Storage::remove_room_member")]
pub async fn remove_room_member(&self, room_id: u32, player_id: u32) -> Result<()> {
let mut executor = self.conn.lock().await;
sqlx::query(
@ -31,7 +129,8 @@ impl Storage {
Ok(())
}
pub async fn set_room_topic(&mut self, id: u32, topic: &str) -> Result<()> {
#[tracing::instrument(skip(self, topic), name = "Storage::set_room_topic")]
pub async fn set_room_topic(&self, id: u32, topic: &str) -> Result<()> {
let mut executor = self.conn.lock().await;
sqlx::query(
"update rooms
@ -45,4 +144,36 @@ impl Storage {
Ok(())
}
#[tracing::instrument(skip(self), name = "Storage::create_or_retrieve_room_id_by_name")]
pub async fn create_or_retrieve_room_id_by_name(&self, name: &str) -> Result<u32> {
// TODO we don't need any info except the name on non-owning nodes, should remove stubs here
let mut executor = self.conn.lock().await;
let res: (u32,) = sqlx::query_as(
"insert into rooms(name, topic)
values (?, '')
on conflict(name) do update set name = excluded.name
returning id;",
)
.bind(name)
.fetch_one(&mut *executor)
.await?;
Ok(res.0)
}
#[tracing::instrument(skip(self), name = "Storage::get_rooms_of_a_user")]
pub async fn get_rooms_of_a_user(&self, user_id: u32) -> Result<Vec<RoomId>> {
let mut executor = self.conn.lock().await;
let res: Vec<(String,)> = sqlx::query_as(
"select r.name
from memberships m inner join rooms r on m.room_id = r.id
where m.user_id = ?;",
)
.bind(user_id)
.fetch_all(&mut *executor)
.await?;
res.into_iter().map(|(room_id,)| RoomId::from(room_id)).collect()
}
}

View File

@ -1,9 +1,91 @@
use anyhow::Result;
use sqlx::{Connection, FromRow, Sqlite, Transaction};
use crate::repo::Storage;
use crate::room::RoomId;
#[derive(FromRow)]
pub struct StoredUser {
pub id: u32,
pub name: String,
pub password: Option<String>,
pub argon2_hash: Option<Box<str>>,
}
impl Storage {
#[tracing::instrument(skip(self), name = "Storage::retrieve_user_by_name")]
pub async fn retrieve_user_by_name(&self, name: &str) -> Result<Option<StoredUser>> {
let mut executor = self.conn.lock().await;
let res = sqlx::query_as(
"select u.id, u.name, c.password, a.hash as argon2_hash
from users u left join challenges_plain_password c on u.id = c.user_id
left join challenges_argon2_password a on u.id = a.user_id
where u.name = ?;",
)
.bind(name)
.fetch_optional(&mut *executor)
.await?;
Ok(res)
}
#[tracing::instrument(skip(self), name = "Storage::check_user_existence")]
pub async fn check_user_existence(&self, username: &str) -> Result<bool> {
let mut executor = self.conn.lock().await;
let result: Option<(String,)> = sqlx::query_as("select name from users where name = ?;")
.bind(username)
.fetch_optional(&mut *executor)
.await?;
Ok(result.is_some())
}
#[tracing::instrument(skip(self), name = "Storage::create_user")]
pub async fn create_user(&self, name: &str) -> Result<()> {
let query = sqlx::query(
"insert into users(name)
values (?);",
)
.bind(name);
let mut executor = self.conn.lock().await;
query.execute(&mut *executor).await?;
Ok(())
}
#[tracing::instrument(skip(self, pwd), name = "Storage::set_password")]
pub async fn set_password(&self, name: &str, pwd: &str) -> Result<Option<()>> {
async fn inner(txn: &mut Transaction<'_, Sqlite>, name: &str, pwd: &str) -> Result<Option<()>> {
let id: Option<(u32,)> = sqlx::query_as("select * from users where name = ? limit 1;")
.bind(name)
.fetch_optional(&mut **txn)
.await?;
let Some((id,)) = id else {
return Ok(None);
};
sqlx::query("insert or replace into challenges_plain_password(user_id, password) values (?, ?);")
.bind(id)
.bind(pwd)
.execute(&mut **txn)
.await?;
Ok(Some(()))
}
let mut executor = self.conn.lock().await;
let mut tx = executor.begin().await?;
let res = inner(&mut tx, name, pwd).await;
match res {
Ok(e) => {
tx.commit().await?;
Ok(e)
}
Err(e) => {
tx.rollback().await?;
Err(e)
}
}
}
#[tracing::instrument(skip(self), name = "Storage::retrieve_user_id_by_name")]
pub async fn retrieve_user_id_by_name(&self, name: &str) -> Result<Option<u32>> {
let mut executor = self.conn.lock().await;
let res: Option<(u32,)> = sqlx::query_as("select u.id from users u where u.name = ?;")
@ -14,17 +96,19 @@ impl Storage {
Ok(res.map(|(id,)| id))
}
pub async fn get_rooms_of_a_user(&self, user_id: u32) -> Result<Vec<RoomId>> {
#[tracing::instrument(skip(self), name = "Storage::create_or_retrieve_user_id_by_name")]
pub async fn create_or_retrieve_user_id_by_name(&self, name: &str) -> Result<u32> {
let mut executor = self.conn.lock().await;
let res: Vec<(String,)> = sqlx::query_as(
"select r.name
from memberships m inner join rooms r on m.room_id = r.id
where m.user_id = ?;",
let res: (u32,) = sqlx::query_as(
"insert into users(name)
values (?)
on conflict(name) do update set name = excluded.name
returning id;",
)
.bind(user_id)
.fetch_all(&mut *executor)
.bind(name)
.fetch_one(&mut *executor)
.await?;
res.into_iter().map(|(room_id,)| RoomId::from(room_id)).collect()
Ok(res.0)
}
}

View File

@ -2,17 +2,19 @@
use std::collections::HashSet;
use std::{collections::HashMap, hash::Hash, sync::Arc};
use chrono::{DateTime, Utc};
use prometheus::{IntGauge, Registry as MetricRegistry};
use serde::Serialize;
use tokio::sync::RwLock as AsyncRwLock;
use crate::player::{PlayerHandle, PlayerId, Updates};
use crate::prelude::*;
use crate::repo::Storage;
use crate::Services;
/// Opaque room id
#[derive(Debug, Clone, PartialEq, Eq, Hash, Serialize)]
pub struct RoomId(Str);
impl RoomId {
pub fn from(str: impl Into<Str>) -> Result<RoomId> {
let bytes = str.into();
@ -33,28 +35,32 @@ impl RoomId {
}
/// Shared data structure for storing metadata about rooms.
#[derive(Clone)]
pub struct RoomRegistry(Arc<AsyncRwLock<RoomRegistryInner>>);
pub(crate) struct RoomRegistry(AsyncRwLock<RoomRegistryInner>);
impl RoomRegistry {
pub fn new(metrics: &mut MetricRegistry, storage: Storage) -> Result<RoomRegistry> {
pub fn new(metrics: &mut MetricRegistry) -> Result<RoomRegistry> {
let metric_active_rooms = IntGauge::new("chat_rooms_active", "Number of alive room actors")?;
metrics.register(Box::new(metric_active_rooms.clone()))?;
let inner = RoomRegistryInner {
rooms: HashMap::new(),
metric_active_rooms,
storage,
};
Ok(RoomRegistry(Arc::new(AsyncRwLock::new(inner))))
Ok(RoomRegistry(AsyncRwLock::new(inner)))
}
pub async fn get_or_create_room(&mut self, room_id: RoomId) -> Result<RoomHandle> {
pub fn shutdown(self) {
// TODO iterate over rooms and stop them
}
#[tracing::instrument(skip(self, services), name = "RoomRegistry::get_or_create_room")]
pub async fn get_or_create_room(&self, services: &Services, room_id: RoomId) -> Result<RoomHandle> {
let mut inner = self.0.write().await;
if let Some(room_handle) = inner.get_or_load_room(&room_id).await? {
if let Some(room_handle) = inner.get_or_load_room(services, &room_id).await? {
Ok(room_handle.clone())
} else {
log::debug!("Creating room {}...", &room_id.0);
let topic = "New room";
let id = inner.storage.create_new_room(&*room_id.0, &*topic).await?;
let id = services.storage.create_new_room(&*room_id.0, &*topic).await?;
let room = Room {
storage_id: id,
room_id: room_id.clone(),
@ -62,7 +68,6 @@ impl RoomRegistry {
members: HashSet::new(),
topic: topic.into(),
message_count: 0,
storage: inner.storage.clone(),
};
let room_handle = RoomHandle(Arc::new(AsyncRwLock::new(room)));
inner.rooms.insert(room_id, room_handle.clone());
@ -71,11 +76,13 @@ impl RoomRegistry {
}
}
pub async fn get_room(&self, room_id: &RoomId) -> Option<RoomHandle> {
#[tracing::instrument(skip(self, services), name = "RoomRegistry::get_room")]
pub async fn get_room(&self, services: &Services, room_id: &RoomId) -> Option<RoomHandle> {
let mut inner = self.0.write().await;
inner.get_or_load_room(room_id).await.unwrap()
inner.get_or_load_room(services, room_id).await.unwrap()
}
#[tracing::instrument(skip(self), name = "RoomRegistry::get_all_rooms")]
pub async fn get_all_rooms(&self) -> Vec<RoomInfo> {
let handles = {
let inner = self.0.read().await;
@ -93,15 +100,15 @@ impl RoomRegistry {
struct RoomRegistryInner {
rooms: HashMap<RoomId, RoomHandle>,
metric_active_rooms: IntGauge,
storage: Storage,
}
impl RoomRegistryInner {
async fn get_or_load_room(&mut self, room_id: &RoomId) -> Result<Option<RoomHandle>> {
#[tracing::instrument(skip(self, services), name = "RoomRegistryInner::get_or_load_room")]
async fn get_or_load_room(&mut self, services: &Services, room_id: &RoomId) -> Result<Option<RoomHandle>> {
if let Some(room_handle) = self.rooms.get(room_id) {
log::debug!("Room {} was loaded already", &room_id.0);
Ok(Some(room_handle.clone()))
} else if let Some(stored_room) = self.storage.retrieve_room_by_name(&*room_id.0).await? {
} else if let Some(stored_room) = services.storage.retrieve_room_by_name(&*room_id.0).await? {
log::debug!("Loading room {}...", &room_id.0);
let room = Room {
storage_id: stored_room.id,
@ -110,7 +117,6 @@ impl RoomRegistryInner {
members: HashSet::new(), // TODO load members from storage
topic: stored_room.topic.into(),
message_count: stored_room.message_count,
storage: self.storage.clone(),
};
let room_handle = RoomHandle(Arc::new(AsyncRwLock::new(room)));
self.rooms.insert(room_id.clone(), room_handle.clone());
@ -125,18 +131,25 @@ impl RoomRegistryInner {
#[derive(Clone)]
pub struct RoomHandle(Arc<AsyncRwLock<Room>>);
impl RoomHandle {
#[tracing::instrument(skip(self, player_handle), name = "RoomHandle::subscribe")]
pub async fn subscribe(&self, player_id: &PlayerId, player_handle: PlayerHandle) {
let mut lock = self.0.write().await;
tracing::info!("Adding a subscriber to a room");
lock.subscriptions.insert(player_id.clone(), player_handle);
}
pub async fn add_member(&self, player_id: &PlayerId, player_storage_id: u32) {
#[tracing::instrument(skip(self, services), name = "RoomHandle::add_member")]
pub async fn add_member(&self, services: &Services, player_id: &PlayerId, player_storage_id: u32) {
let mut lock = self.0.write().await;
tracing::info!("Adding a new member to a room");
let room_storage_id = lock.storage_id;
lock.storage.add_room_member(room_storage_id, player_storage_id).await.unwrap();
if !services.storage.is_room_member(room_storage_id, player_storage_id).await.unwrap() {
services.storage.add_room_member(room_storage_id, player_storage_id).await.unwrap();
} else {
tracing::warn!("User {:#?} has already been added to the room.", player_id);
}
lock.members.insert(player_id.clone());
let update = Updates::RoomJoined {
room_id: lock.room_id.clone(),
@ -145,16 +158,18 @@ impl RoomHandle {
lock.broadcast_update(update, player_id).await;
}
#[tracing::instrument(skip(self), name = "RoomHandle::unsubscribe")]
pub async fn unsubscribe(&self, player_id: &PlayerId) {
let mut lock = self.0.write().await;
lock.subscriptions.remove(player_id);
}
pub async fn remove_member(&self, player_id: &PlayerId, player_storage_id: u32) {
#[tracing::instrument(skip(self, services), name = "RoomHandle::remove_member")]
pub async fn remove_member(&self, services: &Services, player_id: &PlayerId, player_storage_id: u32) {
let mut lock = self.0.write().await;
tracing::info!("Removing a member from a room");
let room_storage_id = lock.storage_id;
lock.storage.remove_room_member(room_storage_id, player_storage_id).await.unwrap();
services.storage.remove_room_member(room_storage_id, player_storage_id).await.unwrap();
lock.members.remove(player_id);
let update = Updates::RoomLeft {
room_id: lock.room_id.clone(),
@ -163,14 +178,16 @@ impl RoomHandle {
lock.broadcast_update(update, player_id).await;
}
pub async fn send_message(&self, player_id: &PlayerId, body: Str) {
#[tracing::instrument(skip(self, services, body, created_at), name = "RoomHandle::send_message")]
pub async fn send_message(&self, services: &Services, player_id: &PlayerId, body: Str, created_at: DateTime<Utc>) {
let mut lock = self.0.write().await;
let res = lock.send_message(player_id, body).await;
let res = lock.send_message(services, player_id, body, created_at).await;
if let Err(err) = res {
log::warn!("Failed to send message: {err:?}");
}
}
#[tracing::instrument(skip(self), name = "RoomHandle::get_room_info")]
pub async fn get_room_info(&self) -> RoomInfo {
let lock = self.0.read().await;
RoomInfo {
@ -180,11 +197,12 @@ impl RoomHandle {
}
}
pub async fn set_topic(&self, changer_id: &PlayerId, new_topic: Str) {
#[tracing::instrument(skip(self, services, new_topic), name = "RoomHandle::set_topic")]
pub async fn set_topic(&self, services: &Services, changer_id: &PlayerId, new_topic: Str) {
let mut lock = self.0.write().await;
let storage_id = lock.storage_id;
lock.topic = new_topic.clone();
lock.storage.set_room_topic(storage_id, &new_topic).await.unwrap();
services.storage.set_room_topic(storage_id, &new_topic).await.unwrap();
let update = Updates::RoomTopicChanged {
room_id: lock.room_id.clone(),
new_topic: new_topic.clone(),
@ -205,17 +223,34 @@ struct Room {
/// The total number of messages. Used to calculate the id of the new message.
message_count: u32,
topic: Str,
storage: Storage,
}
impl Room {
async fn send_message(&mut self, author_id: &PlayerId, body: Str) -> Result<()> {
#[tracing::instrument(skip(self, services, body, created_at), name = "Room::send_message")]
async fn send_message(
&mut self,
services: &Services,
author_id: &PlayerId,
body: Str,
created_at: DateTime<Utc>,
) -> Result<()> {
tracing::info!("Adding a message to room");
self.storage.insert_message(self.storage_id, self.message_count, &body, &*author_id.as_inner()).await?;
services
.storage
.insert_room_message(
self.storage_id,
self.message_count,
&body,
&*author_id.as_inner(),
&created_at,
)
.await?;
self.message_count += 1;
let update = Updates::NewMessage {
room_id: self.room_id.clone(),
author_id: author_id.clone(),
body,
created_at,
};
self.broadcast_update(update, author_id).await;
Ok(())
@ -225,6 +260,7 @@ impl Room {
///
/// This is called after handling a client command.
/// Sending the update to the player who sent the command is handled by the player actor.
#[tracing::instrument(skip(self, update), name = "Room::broadcast_update")]
async fn broadcast_update(&self, update: Updates, except: &PlayerId) {
tracing::debug!("Broadcasting an update to {} subs", self.subscriptions.len());
for (player_id, sub) in &self.subscriptions {

View File

@ -1,5 +1,7 @@
use serde::{Deserialize, Serialize};
pub mod rooms;
#[derive(Serialize, Deserialize)]
pub struct ErrorResponse<'a> {
pub code: &'a str,
@ -11,6 +13,11 @@ pub struct CreatePlayerRequest<'a> {
pub name: &'a str,
}
#[derive(Serialize, Deserialize)]
pub struct StopPlayerRequest<'a> {
pub name: &'a str,
}
#[derive(Serialize, Deserialize)]
pub struct ChangePasswordRequest<'a> {
pub player_name: &'a str,
@ -19,6 +26,7 @@ pub struct ChangePasswordRequest<'a> {
pub mod paths {
pub const CREATE_PLAYER: &'static str = "/mgmt/create_player";
pub const STOP_PLAYER: &'static str = "/mgmt/stop_player";
pub const SET_PASSWORD: &'static str = "/mgmt/set_password";
}

View File

@ -0,0 +1,24 @@
use serde::{Deserialize, Serialize};
#[derive(Serialize, Deserialize)]
pub struct SendMessageReq<'a> {
pub room_id: &'a str,
pub author_id: &'a str,
pub message: &'a str,
}
#[derive(Serialize, Deserialize)]
pub struct SetTopicReq<'a> {
pub room_id: &'a str,
pub author_id: &'a str,
pub topic: &'a str,
}
pub mod paths {
pub const SEND_MESSAGE: &'static str = "/mgmt/rooms/send_message";
pub const SET_TOPIC: &'static str = "/mgmt/rooms/set_topic";
}
pub mod errors {
pub const ROOM_NOT_FOUND: &'static str = "room_not_found";
}

View File

@ -12,6 +12,7 @@ tokio.workspace = true
prometheus.workspace = true
futures-util.workspace = true
nonempty.workspace = true
chrono.workspace = true
bitflags = "2.4.1"
proto-irc = { path = "../proto-irc" }
sasl = { path = "../sasl" }

View File

@ -1,9 +1,10 @@
use bitflags::bitflags;
bitflags! {
#[derive(Debug)]
#[derive(Debug, Clone, Copy)]
pub struct Capabilities: u32 {
const None = 0;
const Sasl = 1 << 0;
const ServerTime = 1 << 1;
}
}

View File

@ -0,0 +1,21 @@
use std::future::Future;
use tokio::io::AsyncWrite;
use lavina_core::player::PlayerConnection;
use lavina_core::prelude::Str;
pub struct IrcConnection<'a, T: AsyncWrite + Unpin> {
pub server_name: Str,
/// client is nick of requester
pub client: Str,
pub writer: &'a mut T,
pub player_connection: &'a mut PlayerConnection,
}
pub trait Handler<T>
where
T: AsyncWrite + Unpin,
{
fn handle(&self, arg: IrcConnection<T>) -> impl Future<Output = anyhow::Result<()>>;
}

View File

@ -2,6 +2,7 @@ use std::collections::HashMap;
use std::net::SocketAddr;
use anyhow::{anyhow, Result};
use chrono::SecondsFormat;
use futures_util::future::join_all;
use nonempty::nonempty;
use nonempty::NonEmpty;
@ -13,23 +14,27 @@ use tokio::net::tcp::{ReadHalf, WriteHalf};
use tokio::net::{TcpListener, TcpStream};
use tokio::sync::mpsc::channel;
use lavina_core::auth::Verdict;
use lavina_core::player::*;
use lavina_core::prelude::*;
use lavina_core::repo::Storage;
use lavina_core::room::{RoomId, RoomInfo, RoomRegistry};
use lavina_core::room::{RoomId, RoomInfo};
use lavina_core::terminator::Terminator;
use lavina_core::LavinaCore;
use proto_irc::client::CapabilitySubcommand;
use proto_irc::client::{client_message, ClientMessage};
use proto_irc::server::CapSubBody;
use proto_irc::server::{AwayStatus, ServerMessage, ServerMessageBody};
use proto_irc::user::PrefixedNick;
use proto_irc::{Chan, Recipient};
use proto_irc::{Chan, Recipient, Tag};
use sasl::AuthBody;
mod cap;
use handler::Handler;
mod whois;
use crate::cap::Capabilities;
mod handler;
pub const APP_VERSION: &str = concat!("lavina", "_", env!("CARGO_PKG_VERSION"));
#[derive(Deserialize, Debug, Clone)]
@ -48,16 +53,15 @@ struct RegisteredUser {
*/
username: Str,
realname: Str,
enabled_capabilities: Capabilities,
}
async fn handle_socket(
config: ServerConfig,
mut stream: TcpStream,
socket_addr: &SocketAddr,
players: PlayerRegistry,
rooms: RoomRegistry,
core: LavinaCore,
termination: Deferred<()>, // TODO use it to stop the connection gracefully
mut storage: Storage,
) -> Result<()> {
log::info!("Received an IRC connection from {socket_addr}");
let (reader, writer) = stream.split();
@ -71,11 +75,11 @@ async fn handle_socket(
log::info!("Socket handling was terminated");
return Ok(())
},
registered_user = handle_registration(&mut reader, &mut writer, &mut storage, &config) =>
registered_user = handle_registration(&mut reader, &mut writer, &core, &config) =>
match registered_user {
Ok(user) => {
log::debug!("User registered");
handle_registered_socket(config, players, rooms, &mut reader, &mut writer, user).await?;
handle_registered_socket(config, &core, &mut reader, &mut writer, user).await?;
}
Err(err) => {
log::debug!("Registration failed: {err}");
@ -120,7 +124,7 @@ impl RegistrationState {
&mut self,
msg: ClientMessage,
writer: &mut BufWriter<WriteHalf<'_>>,
storage: &mut Storage,
core: &LavinaCore,
config: &ServerConfig,
) -> Result<Option<RegisteredUser>> {
match msg {
@ -136,7 +140,7 @@ impl RegistrationState {
sender: Some(config.server_name.clone().into()),
body: ServerMessageBody::Cap {
target: self.future_nickname.clone().unwrap_or_else(|| "*".into()),
subcmd: CapSubBody::Ls("sasl=PLAIN".into()),
subcmd: CapSubBody::Ls("sasl=PLAIN server-time".into()),
},
}
.write_async(writer)
@ -156,17 +160,31 @@ impl RegistrationState {
self.enabled_capabilities |= Capabilities::Sasl;
}
acked.push(cap);
} else if &*cap.name == "server-time" {
if cap.to_disable {
self.enabled_capabilities &= !Capabilities::ServerTime;
} else {
self.enabled_capabilities |= Capabilities::ServerTime;
}
acked.push(cap);
} else {
naked.push(cap);
}
}
let mut ack_body = String::new();
for cap in acked {
if let Some((first, tail)) = acked.split_first() {
if first.to_disable {
ack_body.push('-');
}
ack_body += &*first.name;
for cap in tail {
ack_body.push(' ');
if cap.to_disable {
ack_body.push('-');
}
ack_body += &*cap.name;
}
}
ServerMessage {
tags: vec![],
sender: Some(config.server_name.clone().into()),
@ -195,8 +213,9 @@ impl RegistrationState {
nickname: nickname.clone(),
username,
realname,
enabled_capabilities: self.enabled_capabilities,
};
self.finalize_auth(candidate_user, writer, storage, config).await
self.finalize_auth(candidate_user, writer, core, config).await
}
},
ClientMessage::Nick { nickname } => {
@ -208,8 +227,9 @@ impl RegistrationState {
nickname: nickname.clone(),
username: username.clone(),
realname: realname.clone(),
enabled_capabilities: self.enabled_capabilities,
};
self.finalize_auth(candidate_user, writer, storage, config).await
self.finalize_auth(candidate_user, writer, core, config).await
} else {
self.future_nickname = Some(nickname);
Ok(None)
@ -224,8 +244,9 @@ impl RegistrationState {
nickname: nickname.clone(),
username,
realname,
enabled_capabilities: self.enabled_capabilities,
};
self.finalize_auth(candidate_user, writer, storage, config).await
self.finalize_auth(candidate_user, writer, core, config).await
} else {
self.future_username = Some((username, realname));
Ok(None)
@ -256,7 +277,7 @@ impl RegistrationState {
}
} else {
let body = AuthBody::from_str(body.as_bytes())?;
if let Err(e) = auth_user(storage, &body.login, &body.password).await {
if let Err(e) = auth_user(core, &body.login, &body.password).await {
tracing::warn!("Authentication failed: {:?}", e);
let target = self.future_nickname.clone().unwrap_or_else(|| "*".into());
sasl_fail_message(config.server_name.clone(), target, "Bad credentials".into())
@ -304,7 +325,7 @@ impl RegistrationState {
&mut self,
candidate_user: RegisteredUser,
writer: &mut BufWriter<WriteHalf<'_>>,
storage: &mut Storage,
core: &LavinaCore,
config: &ServerConfig,
) -> Result<Option<RegisteredUser>> {
if self.enabled_capabilities.contains(Capabilities::Sasl)
@ -323,7 +344,7 @@ impl RegistrationState {
writer.flush().await?;
return Ok(None);
};
auth_user(storage, &*candidate_user.nickname, &*candidate_password).await?;
auth_user(core, &*candidate_user.nickname, &*candidate_password).await?;
Ok(Some(candidate_user))
}
}
@ -332,7 +353,7 @@ impl RegistrationState {
async fn handle_registration<'a>(
reader: &mut BufReader<ReadHalf<'a>>,
writer: &mut BufWriter<WriteHalf<'a>>,
storage: &mut Storage,
core: &LavinaCore,
config: &ServerConfig,
) -> Result<RegisteredUser> {
let mut buffer = vec![];
@ -368,7 +389,7 @@ async fn handle_registration<'a>(
}
};
tracing::debug!("Incoming IRC message: {msg:?}");
if let Some(user) = state.handle_msg(msg, writer, storage, config).await? {
if let Some(user) = state.handle_msg(msg, writer, core, config).await? {
break Ok(user);
}
buffer.clear();
@ -385,31 +406,19 @@ fn sasl_fail_message(sender: Str, nick: Str, text: Str) -> ServerMessage {
}
}
async fn auth_user(storage: &mut Storage, login: &str, plain_password: &str) -> Result<()> {
let stored_user = storage.retrieve_user_by_name(login).await?;
let stored_user = match stored_user {
Some(u) => u,
None => {
log::info!("User '{}' not found", login);
return Err(anyhow!("no user found"));
async fn auth_user(core: &LavinaCore, login: &str, plain_password: &str) -> Result<()> {
let verdict = core.authenticate(login, plain_password).await?;
// TODO properly map these onto protocol messages
match verdict {
Verdict::Authenticated => Ok(()),
Verdict::UserNotFound => Err(anyhow!("no user found")),
Verdict::InvalidPassword => Err(anyhow!("incorrect credentials")),
}
};
let Some(expected_password) = stored_user.password else {
log::info!("Password not defined for user '{}'", login);
return Err(anyhow!("password is not defined"));
};
if expected_password != plain_password {
log::info!("Incorrect password supplied for user '{}'", login);
return Err(anyhow!("passwords do not match"));
}
Ok(())
}
async fn handle_registered_socket<'a>(
config: ServerConfig,
mut players: PlayerRegistry,
rooms: RoomRegistry,
core: &LavinaCore,
reader: &mut BufReader<ReadHalf<'a>>,
writer: &mut BufWriter<WriteHalf<'a>>,
user: RegisteredUser,
@ -418,7 +427,7 @@ async fn handle_registered_socket<'a>(
log::info!("Handling registered user: {user:?}");
let player_id = PlayerId::from(user.nickname.clone())?;
let mut connection = players.connect_to_player(&player_id).await;
let mut connection = core.connect_to_player(&player_id).await;
let text: Str = format!("Welcome to {} Server", &config.server_name).into();
ServerMessage {
@ -492,21 +501,28 @@ async fn handle_registered_socket<'a>(
len
};
let incoming = std::str::from_utf8(&buffer[0..len-2])?;
if let HandleResult::Leave = handle_incoming_message(incoming, &config, &user, &rooms, &mut connection, writer).await? {
if let HandleResult::Leave = handle_incoming_message(incoming, &config, &user, core, &mut connection, writer).await? {
break;
}
buffer.clear();
},
update = connection.receiver.recv() => {
if let Some(update) = update {
handle_update(&config, &user, &player_id, writer, &rooms, update).await?;
} else {
match update {
Some(ConnectionMessage::Update(update)) => {
handle_update(&config, &user, &player_id, writer, core, update).await?;
}
Some(ConnectionMessage::Stop(_)) => {
tracing::debug!("Connection is being terminated");
break;
}
None => {
log::warn!("Player is terminated, must terminate the connection");
break;
}
}
}
}
}
ServerMessage {
tags: vec![],
sender: Some(config.server_name.clone()),
@ -544,14 +560,14 @@ async fn handle_update(
user: &RegisteredUser,
player_id: &PlayerId,
writer: &mut (impl AsyncWrite + Unpin),
rooms: &RoomRegistry,
core: &LavinaCore,
update: Updates,
) -> Result<()> {
log::debug!("Sending irc message to player {player_id:?} on update {update:?}");
match update {
Updates::RoomJoined { new_member_id, room_id } => {
if player_id == &new_member_id {
if let Some(room) = rooms.get_room(&room_id).await {
if let Some(room) = core.get_room(&room_id).await {
let room_info = room.get_room_info().await;
let chan = Chan::Global(room_id.as_inner().clone());
produce_on_join_cmd_messages(&config, &user, &chan, &room_info, writer).await?;
@ -587,9 +603,18 @@ async fn handle_update(
author_id,
room_id,
body,
created_at,
} => {
let mut tags = vec![];
if user.enabled_capabilities.contains(Capabilities::ServerTime) {
let tag = Tag {
key: "time".into(),
value: Some(created_at.to_rfc3339_opts(SecondsFormat::Millis, true).into()),
};
tags.push(tag);
}
ServerMessage {
tags: vec![],
tags,
sender: Some(author_id.as_inner().clone()),
body: ServerMessageBody::PrivateMessage {
target: Recipient::Chan(Chan::Global(room_id.as_inner().clone())),
@ -625,6 +650,32 @@ async fn handle_update(
.await?;
writer.flush().await?
}
Updates::NewDialogMessage {
sender,
receiver,
body,
created_at,
} => {
let mut tags = vec![];
if user.enabled_capabilities.contains(Capabilities::ServerTime) {
let tag = Tag {
key: "time".into(),
value: Some(created_at.to_rfc3339_opts(SecondsFormat::Millis, true).into()),
};
tags.push(tag);
}
ServerMessage {
tags,
sender: Some(sender.as_inner().clone()),
body: ServerMessageBody::PrivateMessage {
target: Recipient::Nick(receiver.as_inner().clone()),
body: body.clone(),
},
}
.write_async(writer)
.await?;
writer.flush().await?
}
}
Ok(())
}
@ -634,11 +685,12 @@ enum HandleResult {
Leave,
}
#[tracing::instrument(skip_all, name = "handle_incoming_message")]
async fn handle_incoming_message(
buffer: &str,
config: &ServerConfig,
user: &RegisteredUser,
rooms: &RoomRegistry,
core: &LavinaCore,
user_handle: &mut PlayerConnection,
writer: &mut (impl AsyncWrite + Unpin),
) -> Result<HandleResult> {
@ -671,6 +723,10 @@ async fn handle_incoming_message(
let room_id = RoomId::from(chan)?;
user_handle.send_message(room_id, body).await?;
}
Recipient::Nick(nick) => {
let receiver = PlayerId::from(nick)?;
user_handle.send_dialog_message(receiver, body).await?;
}
_ => log::warn!("Unsupported target type"),
},
ClientMessage::Topic { chan, topic } => {
@ -697,8 +753,6 @@ async fn handle_incoming_message(
ClientMessage::Who { target } => match &target {
Recipient::Nick(nick) => {
// TODO handle non-existing user
let mut username = format!("~{nick}");
let mut host = format!("user/{nick}");
ServerMessage {
tags: vec![],
sender: Some(config.server_name.clone()),
@ -720,7 +774,7 @@ async fn handle_incoming_message(
writer.flush().await?;
}
Recipient::Chan(Chan::Global(chan)) => {
let room = rooms.get_room(&RoomId::from(chan.clone())?).await;
let room = core.get_room(&RoomId::from(chan.clone())?).await;
if let Some(room) = room {
let room_info = room.get_room_info().await;
for member in room_info.members {
@ -750,6 +804,17 @@ async fn handle_incoming_message(
log::warn!("Local chans not supported");
}
},
ClientMessage::Whois { arg } => {
arg.handle(handler::IrcConnection {
server_name: config.server_name.clone(),
client: user.nickname.clone(),
writer,
player_connection: user_handle,
})
.await?;
writer.flush().await?;
}
ClientMessage::Mode { target } => {
match target {
Recipient::Nick(nickname) => {
@ -804,7 +869,7 @@ fn user_to_who_msg(config: &ServerConfig, requestor: &RegisteredUser, target_use
let username = format!("~{target_user_nickname}").into();
// User's host is not public, replace it with `user/<nickname>` pattern
let mut host = format!("user/{target_user_nickname}").into();
let host = format!("user/{target_user_nickname}").into();
ServerMessageBody::N352WhoReply {
client: requestor.nickname.clone(),
@ -940,13 +1005,7 @@ impl RunningServer {
}
}
pub async fn launch(
config: ServerConfig,
players: PlayerRegistry,
rooms: RoomRegistry,
metrics: MetricsRegistry,
storage: Storage,
) -> Result<RunningServer> {
pub async fn launch(config: ServerConfig, core: LavinaCore, metrics: MetricsRegistry) -> Result<RunningServer> {
log::info!("Starting IRC projection");
let (stopped_tx, mut stopped_rx) = channel(32);
let current_connections = IntGauge::new("irc_current_connections", "Open and alive TCP connections")?;
@ -984,13 +1043,11 @@ pub async fn launch(
}
let terminator = Terminator::spawn(|termination| {
let players = players.clone();
let rooms = rooms.clone();
let core = core.clone();
let current_connections_clone = current_connections.clone();
let stopped_tx = stopped_tx.clone();
let storage = storage.clone();
async move {
match handle_socket(config, stream, &socket_addr, players, rooms, termination, storage).await {
match handle_socket(config, stream, &socket_addr, core, termination).await {
Ok(_) => log::info!("Connection terminated"),
Err(err) => log::warn!("Connection failed: {err}"),
}

View File

@ -0,0 +1,67 @@
use lavina_core::{
player::{GetInfoResult, PlayerId},
prelude::Str,
};
use proto_irc::{
client::command_args::Whois,
commands::whois::{
error::{ErrNoNicknameGiven431, ErrNoSuchNick401},
response::RplEndOfWhois318,
},
response::{IrcResponseMessage, WriteResponse},
};
use tokio::io::AsyncWrite;
use crate::handler::{Handler, IrcConnection};
impl<T: AsyncWrite + Unpin> Handler<T> for Whois {
async fn handle(&self, body: IrcConnection<'_, T>) -> anyhow::Result<()> {
match self {
Whois::Nick(nick) => handle_nick_target(nick.clone(), body).await?,
Whois::TargetNick(_, nick) => handle_nick_target(nick.clone(), body).await?,
Whois::EmptyArgs => {
let IrcConnection {
server_name,
mut writer,
..
} = body;
IrcResponseMessage::empty_tags(
Some(server_name.clone()),
ErrNoNicknameGiven431::new(server_name.clone()),
)
.write_response(&mut writer)
.await?
}
}
Ok(())
}
}
async fn handle_nick_target(nick: Str, body: IrcConnection<'_, impl AsyncWrite + Unpin>) -> anyhow::Result<()> {
let IrcConnection {
server_name,
mut writer,
client,
player_connection,
} = body;
if let GetInfoResult::UserDoesntExist =
player_connection.check_user_existence(PlayerId::from(nick.clone())?).await?
{
IrcResponseMessage::empty_tags(
Some(server_name.clone()),
ErrNoSuchNick401::new(client.clone(), nick.clone()),
)
.write_response(&mut writer)
.await?
}
IrcResponseMessage::empty_tags(
Some(server_name.clone()),
RplEndOfWhois318::new(client.clone(), nick.clone()),
)
.write_response(&mut writer)
.await?;
Ok(())
}

View File

@ -1,17 +1,21 @@
use std::io::ErrorKind;
use std::net::SocketAddr;
use std::time::Duration;
use anyhow::{anyhow, Result};
use chrono::{DateTime, SecondsFormat};
use prometheus::Registry as MetricsRegistry;
use tokio::io::{AsyncReadExt, AsyncWriteExt, BufReader};
use tokio::net::tcp::{ReadHalf, WriteHalf};
use tokio::net::TcpStream;
use lavina_core::clustering::{ClusterConfig, ClusterMetadata};
use lavina_core::player::{JoinResult, PlayerId, SendMessageResult};
use lavina_core::repo::{Storage, StorageConfig};
use lavina_core::{player::PlayerRegistry, room::RoomRegistry};
use lavina_core::room::RoomId;
use lavina_core::LavinaCore;
use projection_irc::APP_VERSION;
use projection_irc::{launch, read_irc_message, RunningServer, ServerConfig};
struct TestScope<'a> {
reader: BufReader<ReadHalf<'a>>,
writer: WriteHalf<'a>,
@ -24,7 +28,7 @@ impl<'a> TestScope<'a> {
let (reader, writer) = stream.split();
let reader = BufReader::new(reader);
let buffer = vec![];
let timeout = Duration::from_millis(100);
let timeout = Duration::from_millis(1000);
TestScope {
reader,
writer,
@ -89,13 +93,15 @@ impl<'a> TestScope<'a> {
Err(_) => Ok(()),
}
}
async fn expect_cap_ls(&mut self) -> Result<()> {
self.expect(":testserver CAP * LS :sasl=PLAIN server-time").await?;
Ok(())
}
}
struct TestServer {
metrics: MetricsRegistry,
storage: Storage,
rooms: RoomRegistry,
players: PlayerRegistry,
core: LavinaCore,
server: RunningServer,
}
impl TestServer {
@ -106,60 +112,60 @@ impl TestServer {
server_name: "testserver".into(),
};
let mut metrics = MetricsRegistry::new();
let mut storage = Storage::open(StorageConfig {
let storage = Storage::open(StorageConfig {
db_path: ":memory:".into(),
})
.await?;
let rooms = RoomRegistry::new(&mut metrics, storage.clone()).unwrap();
let players = PlayerRegistry::empty(rooms.clone(), storage.clone(), &mut metrics).unwrap();
let server = launch(config, players.clone(), rooms.clone(), metrics.clone(), storage.clone()).await.unwrap();
Ok(TestServer {
metrics,
storage,
rooms,
players,
server,
})
let cluster_config = ClusterConfig {
addresses: vec![],
metadata: ClusterMetadata {
node_id: 0,
main_owner: 0,
rooms: Default::default(),
},
};
let core = LavinaCore::new(&mut metrics, cluster_config, storage).await?;
let server = launch(config, core.clone(), metrics.clone()).await.unwrap();
Ok(TestServer { core, server })
}
async fn reboot(mut self) -> Result<TestServer> {
async fn reboot(self) -> Result<TestServer> {
let config = ServerConfig {
listen_on: "127.0.0.1:0".parse().unwrap(),
server_name: "testserver".into(),
};
let TestServer {
mut metrics,
mut storage,
rooms,
mut players,
server,
} = self;
let cluster_config = ClusterConfig {
addresses: vec![],
metadata: ClusterMetadata {
node_id: 0,
main_owner: 0,
rooms: Default::default(),
},
};
let TestServer { core, server } = self;
server.terminate().await?;
players.shutdown_all().await.unwrap();
drop(players);
drop(rooms);
let storage = core.shutdown().await;
let mut metrics = MetricsRegistry::new();
let rooms = RoomRegistry::new(&mut metrics, storage.clone()).unwrap();
let players = PlayerRegistry::empty(rooms.clone(), storage.clone(), &mut metrics).unwrap();
let server = launch(config, players.clone(), rooms.clone(), metrics.clone(), storage.clone()).await.unwrap();
Ok(TestServer {
metrics,
storage,
rooms,
players,
server,
})
let core = LavinaCore::new(&mut metrics, cluster_config, storage).await?;
let server = launch(config, core.clone(), metrics.clone()).await.unwrap();
Ok(TestServer { core, server })
}
async fn shutdown(self) {
let _ = self.server.terminate().await;
let storage = self.core.shutdown().await;
let _ = storage.close().await;
}
}
#[tokio::test]
async fn scenario_basic() -> Result<()> {
let mut server = TestServer::start().await?;
let server = TestServer::start().await?;
// test scenario
server.storage.create_user("tester").await?;
server.storage.set_password("tester", "password").await?;
server.core.create_player(&PlayerId::from("tester")?).await?;
server.core.set_password("tester", "password").await?;
let mut stream = TcpStream::connect(server.server.addr).await?;
let mut s = TestScope::new(&mut stream);
@ -177,18 +183,18 @@ async fn scenario_basic() -> Result<()> {
// wrap up
server.server.terminate().await?;
server.shutdown().await;
Ok(())
}
#[tokio::test]
async fn scenario_join_and_reboot() -> Result<()> {
let mut server = TestServer::start().await?;
let server = TestServer::start().await?;
// test scenario
server.storage.create_user("tester").await?;
server.storage.set_password("tester", "password").await?;
server.core.create_player(&PlayerId::from("tester")?).await?;
server.core.set_password("tester", "password").await?;
let mut stream = TcpStream::connect(server.server.addr).await?;
let mut s = TestScope::new(&mut stream);
@ -247,18 +253,18 @@ async fn scenario_join_and_reboot() -> Result<()> {
// wrap up
server.server.terminate().await?;
server.shutdown().await;
Ok(())
}
#[tokio::test]
async fn scenario_force_join_msg() -> Result<()> {
let mut server = TestServer::start().await?;
let server = TestServer::start().await?;
// test scenario
server.storage.create_user("tester").await?;
server.storage.set_password("tester", "password").await?;
server.core.create_player(&PlayerId::from("tester")?).await?;
server.core.set_password("tester", "password").await?;
let mut stream1 = TcpStream::connect(server.server.addr).await?;
let mut s1 = TestScope::new(&mut stream1);
@ -313,20 +319,20 @@ async fn scenario_force_join_msg() -> Result<()> {
// wrap up
server.server.terminate().await?;
server.shutdown().await;
Ok(())
}
#[tokio::test]
async fn scenario_two_users() -> Result<()> {
let mut server = TestServer::start().await?;
let server = TestServer::start().await?;
// test scenario
server.storage.create_user("tester1").await?;
server.storage.set_password("tester1", "password").await?;
server.storage.create_user("tester2").await?;
server.storage.set_password("tester2", "password").await?;
server.core.create_player(&PlayerId::from("tester1")?).await?;
server.core.set_password("tester1", "password").await?;
server.core.create_player(&PlayerId::from("tester2")?).await?;
server.core.set_password("tester2", "password").await?;
let mut stream1 = TcpStream::connect(server.server.addr).await?;
let mut s1 = TestScope::new(&mut stream1);
@ -374,6 +380,18 @@ async fn scenario_two_users() -> Result<()> {
s1.expect(":tester1 PART #test").await?;
// The second user should receive the PART message
s2.expect(":tester1 PART #test").await?;
s1.send("WHOIS tester2").await?;
s1.expect(":testserver 318 tester1 tester2 :End of /WHOIS list").await?;
stream1.shutdown().await?;
s2.send("WHOIS tester3").await?;
s2.expect(":testserver 401 tester2 tester3 :No such nick/channel").await?;
s2.expect(":testserver 318 tester2 tester3 :End of /WHOIS list").await?;
stream2.shutdown().await?;
server.shutdown().await;
Ok(())
}
@ -383,12 +401,12 @@ AUTHENTICATE doc: https://modern.ircdocs.horse/#authenticate-message
*/
#[tokio::test]
async fn scenario_cap_full_negotiation() -> Result<()> {
let mut server = TestServer::start().await?;
let server = TestServer::start().await?;
// test scenario
server.storage.create_user("tester").await?;
server.storage.set_password("tester", "password").await?;
server.core.create_player(&PlayerId::from("tester")?).await?;
server.core.set_password("tester", "password").await?;
let mut stream = TcpStream::connect(server.server.addr).await?;
let mut s = TestScope::new(&mut stream);
@ -396,7 +414,7 @@ async fn scenario_cap_full_negotiation() -> Result<()> {
s.send("CAP LS 302").await?;
s.send("NICK tester").await?;
s.send("USER UserName 0 * :Real Name").await?;
s.expect(":testserver CAP * LS :sasl=PLAIN").await?;
s.expect_cap_ls().await?;
s.send("CAP REQ :sasl").await?;
s.expect(":testserver CAP tester ACK :sasl").await?;
s.send("AUTHENTICATE PLAIN").await?;
@ -417,24 +435,24 @@ async fn scenario_cap_full_negotiation() -> Result<()> {
// wrap up
server.server.terminate().await?;
server.shutdown().await;
Ok(())
}
#[tokio::test]
async fn scenario_cap_full_negotiation_nick_last() -> Result<()> {
let mut server = TestServer::start().await?;
let server = TestServer::start().await?;
// test scenario
server.storage.create_user("tester").await?;
server.storage.set_password("tester", "password").await?;
server.core.create_player(&PlayerId::from("tester")?).await?;
server.core.set_password("tester", "password").await?;
let mut stream = TcpStream::connect(server.server.addr).await?;
let mut s = TestScope::new(&mut stream);
s.send("CAP LS 302").await?;
s.expect(":testserver CAP * LS :sasl=PLAIN").await?;
s.expect_cap_ls().await?;
s.send("CAP REQ :sasl").await?;
s.expect(":testserver CAP * ACK :sasl").await?;
s.send("AUTHENTICATE PLAIN").await?;
@ -456,18 +474,18 @@ async fn scenario_cap_full_negotiation_nick_last() -> Result<()> {
// wrap up
server.server.terminate().await?;
server.shutdown().await;
Ok(())
}
#[tokio::test]
async fn scenario_cap_short_negotiation() -> Result<()> {
let mut server = TestServer::start().await?;
let server = TestServer::start().await?;
// test scenario
server.storage.create_user("tester").await?;
server.storage.set_password("tester", "password").await?;
server.core.create_player(&PlayerId::from("tester")?).await?;
server.core.set_password("tester", "password").await?;
let mut stream = TcpStream::connect(server.server.addr).await?;
let mut s = TestScope::new(&mut stream);
@ -494,18 +512,18 @@ async fn scenario_cap_short_negotiation() -> Result<()> {
// wrap up
server.server.terminate().await?;
server.shutdown().await;
Ok(())
}
#[tokio::test]
async fn scenario_cap_sasl_fail() -> Result<()> {
let mut server = TestServer::start().await?;
let server = TestServer::start().await?;
// test scenario
server.storage.create_user("tester").await?;
server.storage.set_password("tester", "password").await?;
server.core.create_player(&PlayerId::from("tester")?).await?;
server.core.set_password("tester", "password").await?;
let mut stream = TcpStream::connect(server.server.addr).await?;
let mut s = TestScope::new(&mut stream);
@ -513,7 +531,7 @@ async fn scenario_cap_sasl_fail() -> Result<()> {
s.send("CAP LS 302").await?;
s.send("NICK tester").await?;
s.send("USER UserName 0 * :Real Name").await?;
s.expect(":testserver CAP * LS :sasl=PLAIN").await?;
s.expect_cap_ls().await?;
s.send("CAP REQ :sasl").await?;
s.expect(":testserver CAP tester ACK :sasl").await?;
s.send("AUTHENTICATE SHA256").await?;
@ -538,18 +556,18 @@ async fn scenario_cap_sasl_fail() -> Result<()> {
// wrap up
server.server.terminate().await?;
server.shutdown().await;
Ok(())
}
#[tokio::test]
async fn terminate_socket_scenario() -> Result<()> {
let mut server = TestServer::start().await?;
let server = TestServer::start().await?;
// test scenario
server.storage.create_user("tester").await?;
server.storage.set_password("tester", "password").await?;
server.core.create_player(&PlayerId::from("tester")?).await?;
server.core.set_password("tester", "password").await?;
let mut stream = TcpStream::connect(server.server.addr).await?;
let mut s = TestScope::new(&mut stream);
@ -561,8 +579,142 @@ async fn terminate_socket_scenario() -> Result<()> {
s.send("AUTHENTICATE PLAIN").await?;
s.expect(":testserver AUTHENTICATE +").await?;
server.server.terminate().await?;
server.shutdown().await;
assert_eq!(stream.read_u8().await.unwrap_err().kind(), ErrorKind::UnexpectedEof);
Ok(())
}
#[tokio::test]
async fn server_time_capability() -> Result<()> {
let server = TestServer::start().await?;
// test scenario
server.core.create_player(&PlayerId::from("tester")?).await?;
server.core.set_password("tester", "password").await?;
let mut stream = TcpStream::connect(server.server.addr).await?;
let mut s = TestScope::new(&mut stream);
s.send("CAP LS 302").await?;
s.send("NICK tester").await?;
s.send("USER UserName 0 * :Real Name").await?;
s.expect_cap_ls().await?;
s.send("CAP REQ :sasl server-time").await?;
s.expect(":testserver CAP tester ACK :sasl server-time").await?;
s.send("AUTHENTICATE PLAIN").await?;
s.expect(":testserver AUTHENTICATE +").await?;
s.send("AUTHENTICATE dGVzdGVyAHRlc3RlcgBwYXNzd29yZA==").await?; // base64-encoded 'tester\x00tester\x00password'
s.expect(":testserver 900 tester tester tester :You are now logged in as tester").await?;
s.expect(":testserver 903 tester :SASL authentication successful").await?;
s.send("CAP END").await?;
s.expect_server_introduction("tester").await?;
s.expect_nothing().await?;
s.send("JOIN #test").await?;
s.expect(":tester JOIN #test").await?;
s.expect(":testserver 332 tester #test :New room").await?;
s.expect(":testserver 353 tester = #test :tester").await?;
s.expect(":testserver 366 tester #test :End of /NAMES list").await?;
server.core.create_player(&PlayerId::from("some_guy")?).await?;
let mut conn = server.core.connect_to_player(&PlayerId::from("some_guy").unwrap()).await;
let res = conn.join_room(RoomId::from("test").unwrap()).await?;
let JoinResult::Success(_) = res else {
panic!("Failed to join room");
};
s.expect(":some_guy JOIN #test").await?;
let SendMessageResult::Success(res) = conn.send_message(RoomId::from("test").unwrap(), "Hello".into()).await?
else {
panic!("Failed to send message");
};
s.expect(&format!(
"@time={} :some_guy PRIVMSG #test :Hello",
res.to_rfc3339_opts(SecondsFormat::Millis, true)
))
.await?;
// formatting check
assert_eq!(
DateTime::parse_from_rfc3339(&"2024-01-01T10:00:32.123Z").unwrap().to_rfc3339_opts(SecondsFormat::Millis, true),
"2024-01-01T10:00:32.123Z"
);
s.send("QUIT :Leaving").await?;
s.expect(":testserver ERROR :Leaving the server").await?;
s.expect_eof().await?;
stream.shutdown().await?;
// wrap up
server.shutdown().await;
Ok(())
}
#[tokio::test]
async fn scenario_two_players_dialog() -> Result<()> {
let server = TestServer::start().await?;
// test scenario
server.core.create_player(&PlayerId::from("tester1")?).await?;
server.core.set_password("tester1", "password").await?;
server.core.create_player(&PlayerId::from("tester2")?).await?;
server.core.set_password("tester2", "password").await?;
let mut stream1 = TcpStream::connect(server.server.addr).await?;
let mut s1 = TestScope::new(&mut stream1);
let mut stream2 = TcpStream::connect(server.server.addr).await?;
let mut s2 = TestScope::new(&mut stream2);
s1.send("CAP LS 302").await?;
s1.send("NICK tester1").await?;
s1.send("USER UserName 0 * :Real Name").await?;
s1.expect_cap_ls().await?;
s1.send("CAP REQ :sasl").await?;
s1.expect(":testserver CAP tester1 ACK :sasl").await?;
s1.send("AUTHENTICATE PLAIN").await?;
s1.expect(":testserver AUTHENTICATE +").await?;
s1.send("AUTHENTICATE dGVzdGVyMQB0ZXN0ZXIxAHBhc3N3b3Jk").await?; // base64-encoded 'tester1\x00tester1\x00password'
s1.expect(":testserver 900 tester1 tester1 tester1 :You are now logged in as tester1").await?;
s1.expect(":testserver 903 tester1 :SASL authentication successful").await?;
s1.send("CAP END").await?;
s1.expect_server_introduction("tester1").await?;
s1.expect_nothing().await?;
s2.send("CAP LS 302").await?;
s2.send("NICK tester2").await?;
s2.send("USER UserName 0 * :Real Name").await?;
s2.expect_cap_ls().await?;
s2.send("CAP REQ :sasl").await?;
s2.expect(":testserver CAP tester2 ACK :sasl").await?;
s2.send("AUTHENTICATE PLAIN").await?;
s2.expect(":testserver AUTHENTICATE +").await?;
s2.send("AUTHENTICATE dGVzdGVyMgB0ZXN0ZXIyAHBhc3N3b3Jk").await?; // base64-encoded 'tester2\x00tester2\x00password'
s2.expect(":testserver 900 tester2 tester2 tester2 :You are now logged in as tester2").await?;
s2.expect(":testserver 903 tester2 :SASL authentication successful").await?;
s2.send("CAP END").await?;
s2.expect_server_introduction("tester2").await?;
s2.expect_nothing().await?;
s1.send("PRIVMSG tester2 :Henlo! How are you?").await?;
s1.expect_nothing().await?;
s2.expect(":tester1 PRIVMSG tester2 :Henlo! How are you?").await?;
s2.expect_nothing().await?;
s2.send("PRIVMSG tester1 good").await?;
s2.expect_nothing().await?;
s1.expect(":tester2 PRIVMSG tester1 :good").await?;
s1.expect_nothing().await?;
stream1.shutdown().await?;
stream2.shutdown().await?;
server.shutdown().await;
Ok(())
}

View File

@ -2,32 +2,29 @@
use quick_xml::events::Event;
use lavina_core::room::{RoomId, RoomRegistry};
use proto_xmpp::bind::{BindResponse, Jid, Name, Server};
use proto_xmpp::client::{Iq, IqError, IqErrorType, IqType};
use lavina_core::room::RoomId;
use lavina_core::LavinaCore;
use proto_xmpp::bind::{BindRequest, BindResponse, Jid, Name, Server};
use proto_xmpp::client::{Iq, IqError, IqErrorCondition, IqErrorType, IqType};
use proto_xmpp::disco::{Feature, Identity, InfoQuery, Item, ItemQuery};
use proto_xmpp::mam::{Fin, Set};
use proto_xmpp::roster::RosterQuery;
use proto_xmpp::session::Session;
use proto_xmpp::xml::ToXml;
use crate::proto::IqClientBody;
use crate::XmppConnection;
use proto_xmpp::xml::ToXml;
impl<'a> XmppConnection<'a> {
pub async fn handle_iq(&self, output: &mut Vec<Event<'static>>, iq: Iq<IqClientBody>) {
match iq.body {
IqClientBody::Bind(_) => {
IqClientBody::Bind(req) => {
let req = Iq {
from: None,
id: iq.id,
to: None,
r#type: IqType::Result,
body: BindResponse(Jid {
name: Some(self.user.xmpp_name.clone()),
server: Server(self.hostname.clone()),
resource: Some(self.user.xmpp_resource.clone()),
}),
body: self.bind(&req).await,
};
req.serialize(output);
}
@ -77,7 +74,7 @@ impl<'a> XmppConnection<'a> {
}
}
IqClientBody::DiscoItem(item) => {
let response = self.disco_items(iq.to.as_ref(), &item, self.rooms).await;
let response = self.disco_items(iq.to.as_ref(), &item, self.core).await;
let req = Iq {
from: iq.to,
id: iq.id,
@ -87,6 +84,18 @@ impl<'a> XmppConnection<'a> {
};
req.serialize(output);
}
IqClientBody::MessageArchiveRequest(_) => {
let response = Iq {
from: iq.to,
id: iq.id,
to: None,
r#type: IqType::Result,
body: Fin {
set: Set { count: Some(0) },
},
};
response.serialize(output);
}
_ => {
let req = Iq {
from: None,
@ -95,6 +104,7 @@ impl<'a> XmppConnection<'a> {
r#type: IqType::Error,
body: IqError {
r#type: IqErrorType::Cancel,
condition: None,
},
};
req.serialize(output);
@ -102,6 +112,14 @@ impl<'a> XmppConnection<'a> {
}
}
pub(crate) async fn bind(&self, req: &BindRequest) -> BindResponse {
BindResponse(Jid {
name: Some(self.user.xmpp_name.clone()),
server: Server(self.hostname.clone()),
resource: Some(self.user.xmpp_resource.clone()),
})
}
async fn disco_info(&self, to: Option<&Jid>, req: &InfoQuery) -> Result<InfoQuery, IqError> {
let identity;
let feature;
@ -146,7 +164,7 @@ impl<'a> XmppConnection<'a> {
resource: None,
}) if server.0 == self.hostname_rooms => {
let room_id = RoomId::from(room_name.0.clone()).unwrap();
let Some(_) = self.rooms.get_room(&room_id).await else {
let Some(_) = self.core.get_room(&room_id).await else {
// TODO should return item-not-found
// example:
// <error type="cancel">
@ -155,6 +173,7 @@ impl<'a> XmppConnection<'a> {
// </error>
return Err(IqError {
r#type: IqErrorType::Cancel,
condition: Some(IqErrorCondition::ItemNotFound),
});
};
identity = vec![Identity {
@ -180,7 +199,7 @@ impl<'a> XmppConnection<'a> {
})
}
async fn disco_items(&self, to: Option<&Jid>, req: &ItemQuery, rooms: &RoomRegistry) -> ItemQuery {
async fn disco_items(&self, to: Option<&Jid>, req: &ItemQuery, core: &LavinaCore) -> ItemQuery {
let item = match to {
Some(Jid {
name: None,
@ -202,7 +221,7 @@ impl<'a> XmppConnection<'a> {
server,
resource: None,
}) if server.0 == self.hostname_rooms => {
let room_list = rooms.get_all_rooms().await;
let room_list = core.get_all_rooms().await;
room_list
.into_iter()
.map(|room_info| Item {

View File

@ -9,6 +9,7 @@ use std::net::SocketAddr;
use std::path::PathBuf;
use std::sync::Arc;
use anyhow::anyhow;
use futures_util::future::join_all;
use prometheus::Registry as MetricsRegistry;
use quick_xml::events::{BytesDecl, Event};
@ -21,13 +22,14 @@ use tokio::sync::mpsc::channel;
use tokio_rustls::rustls::{Certificate, PrivateKey};
use tokio_rustls::TlsAcceptor;
use lavina_core::player::{PlayerConnection, PlayerId, PlayerRegistry};
use lavina_core::auth::Verdict;
use lavina_core::player::{ConnectionMessage, PlayerConnection, PlayerId, StopReason};
use lavina_core::prelude::*;
use lavina_core::repo::Storage;
use lavina_core::room::RoomRegistry;
use lavina_core::terminator::Terminator;
use lavina_core::LavinaCore;
use proto_xmpp::bind::{Name, Resource};
use proto_xmpp::stream::*;
use proto_xmpp::streamerror::{StreamError, StreamErrorKind};
use proto_xmpp::xml::{Continuation, FromXml, Parser, ToXml};
use sasl::AuthBody;
@ -38,6 +40,9 @@ mod message;
mod presence;
mod updates;
#[cfg(test)]
mod testkit;
#[derive(Deserialize, Debug, Clone)]
pub struct ServerConfig {
pub listen_on: SocketAddr,
@ -77,13 +82,7 @@ impl RunningServer {
}
}
pub async fn launch(
config: ServerConfig,
players: PlayerRegistry,
rooms: RoomRegistry,
metrics: MetricsRegistry,
storage: Storage,
) -> Result<RunningServer> {
pub async fn launch(config: ServerConfig, core: LavinaCore, metrics: MetricsRegistry) -> Result<RunningServer> {
log::info!("Starting XMPP projection");
let certs = certs(&mut SyncBufReader::new(File::open(config.cert)?))?;
@ -122,15 +121,13 @@ pub async fn launch(
// TODO kill the older connection and restart it
continue;
}
let players = players.clone();
let rooms = rooms.clone();
let storage = storage.clone();
let core = core.clone();
let hostname = config.hostname.clone();
let terminator = Terminator::spawn(|termination| {
let stopped_tx = stopped_tx.clone();
let loaded_config = loaded_config.clone();
async move {
match handle_socket(loaded_config, stream, &socket_addr, players, rooms, storage, hostname, termination).await {
match handle_socket(loaded_config, stream, &socket_addr, core, hostname, termination).await {
Ok(_) => log::info!("Connection terminated"),
Err(err) => log::warn!("Connection failed: {err}"),
}
@ -168,9 +165,7 @@ async fn handle_socket(
cert_config: Arc<LoadedConfig>,
mut stream: TcpStream,
socket_addr: &SocketAddr,
mut players: PlayerRegistry,
rooms: RoomRegistry,
mut storage: Storage,
core: LavinaCore,
hostname: Str,
termination: Deferred<()>, // TODO use it to stop the connection gracefully
) -> Result<()> {
@ -200,21 +195,21 @@ async fn handle_socket(
pin!(termination);
select! {
biased;
_ = &mut termination =>{
_ = &mut termination => {
log::info!("Socket handling was terminated");
return Ok(())
},
authenticated = socket_auth(&mut xml_reader, &mut xml_writer, &mut reader_buf, &mut storage, &hostname) => {
authenticated = socket_auth(&mut xml_reader, &mut xml_writer, &mut reader_buf, &core, &hostname) => {
match authenticated {
Ok(authenticated) => {
let mut connection = players.connect_to_player(&authenticated.player_id).await;
let mut connection = core.connect_to_player(&authenticated.player_id).await;
socket_final(
&mut xml_reader,
&mut xml_writer,
&mut reader_buf,
&authenticated,
&mut connection,
&rooms,
&core,
&hostname,
)
.await?;
@ -271,7 +266,7 @@ async fn socket_auth(
xml_reader: &mut NsReader<(impl AsyncBufRead + Unpin)>,
xml_writer: &mut Writer<(impl AsyncWrite + Unpin)>,
reader_buf: &mut Vec<u8>,
storage: &mut Storage,
core: &LavinaCore,
hostname: &Str,
) -> Result<Authenticated> {
// TODO validate the server hostname received in the stream start
@ -296,34 +291,23 @@ async fn socket_auth(
xml_writer.get_mut().flush().await?;
let auth: proto_xmpp::sasl::Auth = proto_xmpp::sasl::Auth::parse(xml_reader, reader_buf).await?;
proto_xmpp::sasl::Success.write_xml(xml_writer).await?;
xml_writer.get_mut().flush().await?;
match AuthBody::from_str(&auth.body) {
Ok(logopass) => {
let name = &logopass.login;
let stored_user = storage.retrieve_user_by_name(name).await?;
let stored_user = match stored_user {
Some(u) => u,
None => {
log::info!("User '{}' not found", name);
return Err(fail("no user found"));
let verdict = core.authenticate(name, &logopass.password).await?;
match verdict {
Verdict::Authenticated => {
proto_xmpp::sasl::Success.write_xml(xml_writer).await?;
xml_writer.get_mut().flush().await?;
}
};
// TODO return proper XML errors to the client
if stored_user.password.is_none() {
log::info!("Password not defined for user '{}'", name);
return Err(fail("password is not defined"));
Verdict::UserNotFound | Verdict::InvalidPassword => {
proto_xmpp::sasl::Failure.write_xml(xml_writer).await?;
xml_writer.get_mut().flush().await?;
return Err(anyhow!("incorrect credentials"));
}
if stored_user.password.as_deref() != Some(&logopass.password) {
log::info!("Incorrect password supplied for user '{}'", name);
return Err(fail("passwords do not match"));
}
let name: Str = name.as_str().into();
Ok(Authenticated {
player_id: PlayerId::from(name.clone())?,
xmpp_name: Name(name.clone()),
@ -341,7 +325,7 @@ async fn socket_final(
reader_buf: &mut Vec<u8>,
authenticated: &Authenticated,
user_handle: &mut PlayerConnection,
rooms: &RoomRegistry,
core: &LavinaCore,
hostname: &Str,
) -> Result<()> {
// TODO validate the server hostname received in the stream start
@ -374,7 +358,7 @@ async fn socket_final(
let mut conn = XmppConnection {
user: authenticated,
user_handle,
rooms,
core,
hostname: hostname.clone(),
hostname_rooms: format!("rooms.{}", hostname).into(),
};
@ -406,17 +390,42 @@ async fn socket_final(
true
},
update = conn.user_handle.receiver.recv() => {
if let Some(update) = update {
match update {
Some(ConnectionMessage::Update(update)) => {
conn.handle_update(&mut events, update).await?;
for i in &events {
xml_writer.write_event_async(i).await?;
}
events.clear();
xml_writer.get_mut().flush().await?;
} else {
log::warn!("Player is terminated, must terminate the connection");
}
Some(ConnectionMessage::Stop(reason)) => {
tracing::debug!("Connection is being terminated: {reason:?}");
let kind = match reason {
StopReason::ServerShutdown => StreamErrorKind::SystemShutdown,
StopReason::InternalError => StreamErrorKind::InternalServerError,
};
StreamError { kind }.serialize(&mut events);
ServerStreamEnd.serialize(&mut events);
for i in &events {
xml_writer.write_event_async(i).await?;
}
events.clear();
xml_writer.get_mut().flush().await?;
break;
}
None => {
log::error!("Player is terminated, must terminate the connection");
StreamError { kind: StreamErrorKind::SystemShutdown }.serialize(&mut events);
ServerStreamEnd.serialize(&mut events);
for i in &events {
xml_writer.write_event_async(i).await?;
}
events.clear();
xml_writer.get_mut().flush().await?;
break;
}
}
false
}
@ -432,12 +441,13 @@ async fn socket_final(
struct XmppConnection<'a> {
user: &'a Authenticated,
user_handle: &'a mut PlayerConnection,
rooms: &'a RoomRegistry,
core: &'a LavinaCore,
hostname: Str,
hostname_rooms: Str,
}
impl<'a> XmppConnection<'a> {
#[tracing::instrument(skip(self, output, packet), name = "XmppConnection::handle_packet")]
async fn handle_packet(&mut self, output: &mut Vec<Event<'static>>, packet: ClientPacket) -> Result<bool> {
let res = match packet {
ClientPacket::Iq(iq) => {
@ -456,6 +466,7 @@ impl<'a> XmppConnection<'a> {
ServerStreamEnd.serialize(output);
true
}
ClientPacket::Eos => true,
};
Ok(res)
}

View File

@ -1,5 +1,6 @@
//! Handling of all client2server message stanzas
use lavina_core::player::PlayerId;
use quick_xml::events::Event;
use lavina_core::prelude::*;
@ -40,6 +41,9 @@ impl<'a> XmppConnection<'a> {
}
.serialize(output);
Ok(())
} else if server.0.as_ref() == &*self.hostname && m.r#type == MessageType::Chat {
self.user_handle.send_dialog_message(PlayerId::from(name.0.clone())?, m.body.clone()).await?;
Ok(())
} else {
todo!()
}

View File

@ -1,11 +1,12 @@
//! Handling of all client2server presence stanzas
use anyhow::Result;
use quick_xml::events::Event;
use lavina_core::prelude::*;
use lavina_core::room::RoomId;
use proto_xmpp::bind::{Jid, Name, Server};
use proto_xmpp::client::Presence;
use proto_xmpp::muc::XUser;
use proto_xmpp::xml::{Ignore, ToXml};
use crate::XmppConnection;
@ -14,7 +15,7 @@ impl<'a> XmppConnection<'a> {
pub async fn handle_presence(&mut self, output: &mut Vec<Event<'static>>, p: Presence<Ignore>) -> Result<()> {
match p.to {
None => {
self.self_presence(output).await;
self.self_presence(output, p.r#type.as_deref()).await;
}
Some(Jid {
name: Some(name),
@ -22,7 +23,8 @@ impl<'a> XmppConnection<'a> {
// resources in MUCs are members' personas not implemented (yet?)
resource: Some(_),
}) if server.0 == self.hostname_rooms => {
self.muc_presence(name, output).await?;
let response = self.muc_presence(&name).await?;
response.serialize(output);
}
_ => {
// TODO other presence cases
@ -33,7 +35,12 @@ impl<'a> XmppConnection<'a> {
Ok(())
}
async fn self_presence(&mut self, output: &mut Vec<Event<'static>>) {
async fn self_presence(&mut self, output: &mut Vec<Event<'static>>, r#type: Option<&str>) {
match r#type {
Some("unavailable") => {
// do not print anything
}
None => {
let response = Presence::<()> {
to: Some(Jid {
name: Some(self.user.xmpp_name.clone()),
@ -49,11 +56,15 @@ impl<'a> XmppConnection<'a> {
};
response.serialize(output);
}
_ => todo!(),
}
}
async fn muc_presence(&mut self, name: Name, output: &mut Vec<Event<'static>>) -> Result<()> {
// todo: return Presence and serialize on the outside.
async fn muc_presence(&mut self, name: &Name) -> Result<(Presence<XUser>)> {
let a = self.user_handle.join_room(RoomId::from(name.0.clone())?).await?;
// TODO handle bans
let response = Presence::<()> {
let response = Presence {
to: Some(Jid {
name: Some(self.user.xmpp_name.clone()),
server: Server(self.hostname.clone()),
@ -64,9 +75,109 @@ impl<'a> XmppConnection<'a> {
server: Server(self.hostname_rooms.clone()),
resource: Some(self.user.xmpp_muc_name.clone()),
}),
custom: vec![XUser],
..Default::default()
};
response.serialize(output);
Ok(response)
}
}
// todo: set up so that the user has been previously joined.
// todo: first call to muc_presence is OK, next one is OK too.
#[cfg(test)]
mod tests {
use crate::testkit::{expect_user_authenticated, TestServer};
use crate::Authenticated;
use anyhow::Result;
use lavina_core::player::PlayerId;
use proto_xmpp::bind::{Jid, Name, Resource, Server};
use proto_xmpp::client::Presence;
use proto_xmpp::muc::XUser;
#[tokio::test]
async fn test_muc_joining() -> Result<()> {
let server = TestServer::start().await.unwrap();
server.core.create_player(&PlayerId::from("tester")?).await?;
let player_id = PlayerId::from("tester").unwrap();
let user = Authenticated {
player_id,
xmpp_name: Name("tester".into()),
xmpp_resource: Resource("tester".into()),
xmpp_muc_name: Resource("tester".into()),
};
let mut player_conn = server.core.connect_to_player(&user.player_id).await;
let mut conn = expect_user_authenticated(&server, &user, &mut player_conn).await.unwrap();
let response = conn.muc_presence(&user.xmpp_name).await.unwrap();
let expected = Presence {
to: Some(Jid {
name: Some(conn.user.xmpp_name.clone()),
server: Server(conn.hostname.clone()),
resource: Some(conn.user.xmpp_resource.clone()),
}),
from: Some(Jid {
name: Some(user.xmpp_name.clone()),
server: Server(conn.hostname_rooms.clone()),
resource: Some(conn.user.xmpp_muc_name.clone()),
}),
custom: vec![XUser],
..Default::default()
};
assert_eq!(expected, response);
server.shutdown().await.unwrap();
Ok(())
}
// Test that joining a room second time after a server restart,
// i.e. in-memory cache of memberships is cleaned, does not cause any issues.
#[tokio::test]
async fn test_muc_joining_twice() -> Result<()> {
let server = TestServer::start().await.unwrap();
server.core.create_player(&PlayerId::from("tester")?).await?;
let player_id = PlayerId::from("tester").unwrap();
let user = Authenticated {
player_id,
xmpp_name: Name("tester".into()),
xmpp_resource: Resource("tester".into()),
xmpp_muc_name: Resource("tester".into()),
};
let mut player_conn = server.core.connect_to_player(&user.player_id).await;
let mut conn = expect_user_authenticated(&server, &user, &mut player_conn).await.unwrap();
let response = conn.muc_presence(&user.xmpp_name).await.unwrap();
let expected = Presence::<()> {
to: Some(Jid {
name: Some(conn.user.xmpp_name.clone()),
server: Server(conn.hostname.clone()),
resource: Some(conn.user.xmpp_resource.clone()),
}),
from: Some(Jid {
name: Some(user.xmpp_name.clone()),
server: Server(conn.hostname_rooms.clone()),
resource: Some(conn.user.xmpp_muc_name.clone()),
}),
..Default::default()
};
assert_eq!(expected, response);
drop(conn);
let server = server.reboot().await.unwrap();
let mut player_conn = server.core.connect_to_player(&user.player_id).await;
let mut conn = expect_user_authenticated(&server, &user, &mut player_conn).await.unwrap();
let response = conn.muc_presence(&user.xmpp_name).await.unwrap();
assert_eq!(expected, response);
server.shutdown().await.unwrap();
Ok(())
}
}

View File

@ -7,6 +7,7 @@ use lavina_core::prelude::*;
use proto_xmpp::bind::BindRequest;
use proto_xmpp::client::{Iq, Message, Presence};
use proto_xmpp::disco::{InfoQuery, ItemQuery};
use proto_xmpp::mam::MessageArchiveRequest;
use proto_xmpp::roster::RosterQuery;
use proto_xmpp::session::Session;
use proto_xmpp::xml::*;
@ -18,6 +19,7 @@ pub enum IqClientBody {
Roster(RosterQuery),
DiscoInfo(InfoQuery),
DiscoItem(ItemQuery),
MessageArchiveRequest(MessageArchiveRequest),
Unknown(Ignore),
}
@ -25,7 +27,7 @@ impl FromXml for IqClientBody {
type P = impl Parser<Output = Result<Self>>;
fn parse() -> Self::P {
|(namespace, event): (ResolveResult<'static>, &'static Event<'static>)| -> Result<Self> {
|(mut namespace, mut event): (ResolveResult<'static>, &'static Event<'static>)| -> Result<Self> {
let bytes = match event {
Event::Start(bytes) => bytes,
Event::Empty(bytes) => bytes,
@ -38,6 +40,7 @@ impl FromXml for IqClientBody {
RosterQuery,
InfoQuery,
ItemQuery,
MessageArchiveRequest,
{
delegate_parsing!(Ignore, namespace, event).into()
}
@ -52,13 +55,14 @@ pub enum ClientPacket {
Message(Message<Ignore>),
Presence(Presence<Ignore>),
StreamEnd,
Eos,
}
impl FromXml for ClientPacket {
type P = impl Parser<Output = Result<Self>>;
fn parse() -> Self::P {
|(namespace, event): (ResolveResult<'static>, &'static Event<'static>)| -> Result<Self> {
|(mut namespace, mut event): (ResolveResult<'static>, &'static Event<'static>)| -> Result<Self> {
match event {
Event::Start(bytes) | Event::Empty(bytes) => {
let name = bytes.name();
@ -83,6 +87,7 @@ impl FromXml for ClientPacket {
return Err(anyhow!("Unexpected XML event: {event:?}"));
}
}
Event::Eof => Ok(ClientPacket::Eos),
_ => {
return Err(anyhow!("Unexpected XML event: {event:?}"));
}

View File

@ -0,0 +1,78 @@
use prometheus::Registry as MetricsRegistry;
use crate::{Authenticated, XmppConnection};
use lavina_core::clustering::{ClusterConfig, ClusterMetadata};
use lavina_core::player::PlayerConnection;
use lavina_core::repo::{Storage, StorageConfig};
use lavina_core::LavinaCore;
use proto_xmpp::bind::{BindRequest, BindResponse, Jid, Name, Resource, Server};
pub(crate) struct TestServer {
pub core: LavinaCore,
}
impl TestServer {
pub async fn start() -> anyhow::Result<TestServer> {
let _ = tracing_subscriber::fmt::try_init();
let mut metrics = MetricsRegistry::new();
let storage = Storage::open(StorageConfig {
db_path: ":memory:".into(),
})
.await?;
let cluster_config = ClusterConfig {
metadata: ClusterMetadata {
node_id: 0,
main_owner: 0,
rooms: Default::default(),
},
addresses: vec![],
};
let core = LavinaCore::new(&mut metrics, cluster_config, storage).await?;
Ok(TestServer { core })
}
pub async fn reboot(self) -> anyhow::Result<TestServer> {
let storage = self.core.shutdown().await;
let mut metrics = MetricsRegistry::new();
let cluster_config = ClusterConfig {
metadata: ClusterMetadata {
node_id: 0,
main_owner: 0,
rooms: Default::default(),
},
addresses: vec![],
};
let core = LavinaCore::new(&mut metrics, cluster_config, storage).await?;
Ok(TestServer { core })
}
pub async fn shutdown(self) -> anyhow::Result<()> {
let storage = self.core.shutdown().await;
storage.close().await;
Ok(())
}
}
pub async fn expect_user_authenticated<'a>(
server: &'a TestServer,
user: &'a Authenticated,
conn: &'a mut PlayerConnection,
) -> anyhow::Result<XmppConnection<'a>> {
let conn = XmppConnection {
user: &user,
user_handle: conn,
core: &server.core,
hostname: "localhost".into(),
hostname_rooms: "rooms.localhost".into(),
};
let result = conn.bind(&BindRequest(Resource("whatever".into()))).await;
let expected = BindResponse(Jid {
name: Some(Name("tester".into())),
server: Server("localhost".into()),
resource: Some(Resource("tester".into())),
});
assert_eq!(expected, result);
Ok(conn)
}

View File

@ -17,6 +17,7 @@ impl<'a> XmppConnection<'a> {
room_id,
author_id,
body,
created_at: _,
} => {
Message::<()> {
to: Some(Jid {
@ -38,6 +39,34 @@ impl<'a> XmppConnection<'a> {
}
.serialize(output);
}
Updates::NewDialogMessage {
sender,
receiver,
body,
created_at: _,
} => {
if receiver == self.user.player_id {
Message::<()> {
to: Some(Jid {
name: Some(self.user.xmpp_name.clone()),
server: Server(self.hostname.clone()),
resource: Some(self.user.xmpp_resource.clone()),
}),
from: Some(Jid {
name: Some(Name(sender.as_inner().clone())),
server: Server(self.hostname.clone()),
resource: Some(Resource(sender.into_inner())),
}),
id: None,
r#type: MessageType::Chat,
lang: None,
subject: None,
body: body.into(),
custom: vec![],
}
.serialize(output);
}
}
_ => {}
}
Ok(())

View File

@ -1,4 +1,5 @@
use std::io::ErrorKind;
use std::str::from_utf8;
use std::sync::Arc;
use std::time::Duration;
@ -6,6 +7,7 @@ use anyhow::Result;
use assert_matches::*;
use prometheus::Registry as MetricsRegistry;
use quick_xml::events::Event;
use quick_xml::name::LocalName;
use quick_xml::NsReader;
use tokio::io::{AsyncBufReadExt, AsyncReadExt, AsyncWriteExt, BufReader};
use tokio::io::{ReadHalf as GenericReadHalf, WriteHalf as GenericWriteHalf};
@ -16,11 +18,15 @@ use tokio_rustls::rustls::client::ServerCertVerifier;
use tokio_rustls::rustls::{ClientConfig, ServerName};
use tokio_rustls::TlsConnector;
use lavina_core::player::PlayerRegistry;
use lavina_core::clustering::{ClusterConfig, ClusterMetadata};
use lavina_core::player::PlayerId;
use lavina_core::repo::{Storage, StorageConfig};
use lavina_core::room::RoomRegistry;
use lavina_core::LavinaCore;
use projection_xmpp::{launch, RunningServer, ServerConfig};
use proto_xmpp::xml::{Continuation, FromXml, Parser};
fn element_name<'a>(local_name: &LocalName<'a>) -> &'a str {
from_utf8(local_name.into_inner()).unwrap()
}
pub async fn read_irc_message(reader: &mut BufReader<ReadHalf<'_>>, buf: &mut Vec<u8>) -> Result<usize> {
let mut size = 0;
@ -55,19 +61,13 @@ impl<'a> TestScope<'a> {
Ok(event)
}
async fn read<T: FromXml>(&mut self) -> Result<T> {
self.buffer.clear();
let (ns, event) = self.reader.read_resolved_event_into_async(&mut self.buffer).await?;
let mut parser: Continuation<_, std::result::Result<T, anyhow::Error>> = T::parse().consume(ns, &event);
loop {
match parser {
Continuation::Final(res) => return Ok(res?),
Continuation::Continue(next) => {
let (ns, event) = self.reader.read_resolved_event_into_async(&mut self.buffer).await?;
parser = next.consume(ns, &event);
}
}
}
async fn expect_starttls_required(&mut self) -> Result<()> {
assert_matches!(self.next_xml_event().await?, Event::Start(b) => assert_eq!(element_name(&b.local_name()), "features"));
assert_matches!(self.next_xml_event().await?, Event::Start(b) => assert_eq!(element_name(&b.local_name()), "starttls"));
assert_matches!(self.next_xml_event().await?, Event::Empty(b) => assert_eq!(element_name(&b.local_name()), "required"));
assert_matches!(self.next_xml_event().await?, Event::End(b) => assert_eq!(element_name(&b.local_name()), "starttls"));
assert_matches!(self.next_xml_event().await?, Event::End(b) => assert_eq!(element_name(&b.local_name()), "features"));
Ok(())
}
}
@ -82,7 +82,7 @@ impl<'a> TestScopeTls<'a> {
fn new(stream: &'a mut TlsStream<TcpStream>, buffer: Vec<u8>) -> TestScopeTls<'a> {
let (reader, writer) = tokio::io::split(stream);
let reader = NsReader::from_reader(BufReader::new(reader));
let timeout = Duration::from_millis(100);
let timeout = Duration::from_millis(500);
TestScopeTls {
reader,
@ -98,6 +98,24 @@ impl<'a> TestScopeTls<'a> {
Ok(())
}
async fn expect_auth_mechanisms(&mut self) -> Result<()> {
assert_matches!(self.next_xml_event().await?, Event::Start(b) => assert_eq!(element_name(&b.local_name()), "features"));
assert_matches!(self.next_xml_event().await?, Event::Start(b) => assert_eq!(element_name(&b.local_name()), "mechanisms"));
assert_matches!(self.next_xml_event().await?, Event::Start(b) => assert_eq!(element_name(&b.local_name()), "mechanism"));
assert_matches!(self.next_xml_event().await?, Event::Text(b) => assert_eq!(&*b, b"PLAIN"));
assert_matches!(self.next_xml_event().await?, Event::End(b) => assert_eq!(element_name(&b.local_name()), "mechanism"));
assert_matches!(self.next_xml_event().await?, Event::End(b) => assert_eq!(element_name(&b.local_name()), "mechanisms"));
assert_matches!(self.next_xml_event().await?, Event::End(b) => assert_eq!(element_name(&b.local_name()), "features"));
Ok(())
}
async fn expect_bind_feature(&mut self) -> Result<()> {
assert_matches!(self.next_xml_event().await?, Event::Start(b) => assert_eq!(element_name(&b.local_name()), "features"));
assert_matches!(self.next_xml_event().await?, Event::Empty(b) => assert_eq!(element_name(&b.local_name()), "bind"));
assert_matches!(self.next_xml_event().await?, Event::End(b) => assert_eq!(element_name(&b.local_name()), "features"));
Ok(())
}
async fn next_xml_event(&mut self) -> Result<Event<'_>> {
self.buffer.clear();
let event = self.reader.read_event_into_async(&mut self.buffer);
@ -107,6 +125,7 @@ impl<'a> TestScopeTls<'a> {
}
struct IgnoreCertVerification;
impl ServerCertVerifier for IgnoreCertVerification {
fn verify_server_cert(
&self,
@ -122,12 +141,10 @@ impl ServerCertVerifier for IgnoreCertVerification {
}
struct TestServer {
metrics: MetricsRegistry,
storage: Storage,
rooms: RoomRegistry,
players: PlayerRegistry,
core: LavinaCore,
server: RunningServer,
}
impl TestServer {
async fn start() -> Result<TestServer> {
let _ = tracing_subscriber::fmt::try_init();
@ -138,31 +155,39 @@ impl TestServer {
hostname: "localhost".into(),
};
let mut metrics = MetricsRegistry::new();
let mut storage = Storage::open(StorageConfig {
let storage = Storage::open(StorageConfig {
db_path: ":memory:".into(),
})
.await?;
let rooms = RoomRegistry::new(&mut metrics, storage.clone()).unwrap();
let players = PlayerRegistry::empty(rooms.clone(), storage.clone(), &mut metrics).unwrap();
let server = launch(config, players.clone(), rooms.clone(), metrics.clone(), storage.clone()).await.unwrap();
Ok(TestServer {
metrics,
storage,
rooms,
players,
server,
})
let cluster_config = ClusterConfig {
addresses: vec![],
metadata: ClusterMetadata {
node_id: 0,
main_owner: 0,
rooms: Default::default(),
},
};
let core = LavinaCore::new(&mut metrics, cluster_config, storage).await?;
let server = launch(config, core.clone(), metrics.clone()).await.unwrap();
Ok(TestServer { core, server })
}
async fn shutdown(self) -> Result<()> {
self.server.terminate().await?;
let storage = self.core.shutdown().await;
storage.close().await;
Ok(())
}
}
#[tokio::test]
async fn scenario_basic() -> Result<()> {
let mut server = TestServer::start().await?;
let server = TestServer::start().await?;
// test scenario
server.storage.create_user("tester").await?;
server.storage.set_password("tester", "password").await?;
server.core.create_player(&PlayerId::from("tester")?).await?;
server.core.set_password("tester", "password").await?;
let mut stream = TcpStream::connect(server.server.addr).await?;
let mut s = TestScope::new(&mut stream);
@ -171,14 +196,10 @@ async fn scenario_basic() -> Result<()> {
s.send(r#"<?xml version="1.0"?>"#).await?;
s.send(r#"<stream:stream xmlns:stream="http://etherx.jabber.org/streams" to="127.0.0.1" xml:lang="en" xmlns:xml="http://www.w3.org/XML/1998/namespace" xmlns="jabber:client" version="1.0">"#).await?;
assert_matches!(s.next_xml_event().await?, Event::Decl(_) => {});
assert_matches!(s.next_xml_event().await?, Event::Start(b) => assert_eq!(b.local_name().into_inner(), b"stream"));
assert_matches!(s.next_xml_event().await?, Event::Start(b) => assert_eq!(b.local_name().into_inner(), b"features"));
assert_matches!(s.next_xml_event().await?, Event::Start(b) => assert_eq!(b.local_name().into_inner(), b"starttls"));
assert_matches!(s.next_xml_event().await?, Event::Empty(b) => assert_eq!(b.local_name().into_inner(), b"required"));
assert_matches!(s.next_xml_event().await?, Event::End(b) => assert_eq!(b.local_name().into_inner(), b"starttls"));
assert_matches!(s.next_xml_event().await?, Event::End(b) => assert_eq!(b.local_name().into_inner(), b"features"));
assert_matches!(s.next_xml_event().await?, Event::Start(b) => assert_eq!(element_name(&b.local_name()), "stream"));
s.expect_starttls_required().await?;
s.send(r#"<starttls/>"#).await?;
assert_matches!(s.next_xml_event().await?, Event::Empty(b) => assert_eq!(b.local_name().into_inner(), b"proceed"));
assert_matches!(s.next_xml_event().await?, Event::Empty(b) => assert_eq!(element_name(&b.local_name()), "proceed"));
let buffer = s.buffer;
tracing::info!("TLS feature negotiation complete");
@ -197,24 +218,99 @@ async fn scenario_basic() -> Result<()> {
s.send(r#"<?xml version="1.0"?>"#).await?;
s.send(r#"<stream:stream xmlns:stream="http://etherx.jabber.org/streams" to="127.0.0.1" xml:lang="en" xmlns:xml="http://www.w3.org/XML/1998/namespace" xmlns="jabber:client" version="1.0">"#).await?;
assert_matches!(s.next_xml_event().await?, Event::Decl(_) => {});
assert_matches!(s.next_xml_event().await?, Event::Start(b) => assert_eq!(b.local_name().into_inner(), b"stream"));
assert_matches!(s.next_xml_event().await?, Event::Start(b) => assert_eq!(element_name(&b.local_name()), "stream"));
s.expect_auth_mechanisms().await?;
// base64-encoded "\x00tester\x00password"
s.send(r#"<auth xmlns="urn:ietf:params:xml:ns:xmpp-sasl" mechanism="PLAIN">AHRlc3RlcgBwYXNzd29yZA==</auth>"#)
.await?;
assert_matches!(s.next_xml_event().await?, Event::Empty(b) => assert_eq!(element_name(&b.local_name()), "success"));
s.send(r#"<?xml version="1.0"?>"#).await?;
s.send(r#"<stream:stream xmlns:stream="http://etherx.jabber.org/streams" to="127.0.0.1" xml:lang="en" xmlns:xml="http://www.w3.org/XML/1998/namespace" xmlns="jabber:client" version="1.0">"#).await?;
assert_matches!(s.next_xml_event().await?, Event::Decl(_) => {});
assert_matches!(s.next_xml_event().await?, Event::Start(b) => assert_eq!(element_name(&b.local_name()), "stream"));
s.expect_bind_feature().await?;
s.send(r#"<iq id="bind_1" type="set"><bind xmlns="urn:ietf:params:xml:ns:xmpp-bind"><resource>kek</resource></bind></iq>"#).await?;
assert_matches!(s.next_xml_event().await?, Event::Start(b) => assert_eq!(element_name(&b.local_name()), "iq"));
assert_matches!(s.next_xml_event().await?, Event::Start(b) => assert_eq!(element_name(&b.local_name()), "bind"));
assert_matches!(s.next_xml_event().await?, Event::Start(b) => assert_eq!(element_name(&b.local_name()), "jid"));
assert_matches!(s.next_xml_event().await?, Event::Text(b) => assert_eq!(&*b, b"tester@localhost/tester"));
assert_matches!(s.next_xml_event().await?, Event::End(b) => assert_eq!(element_name(&b.local_name()), "jid"));
assert_matches!(s.next_xml_event().await?, Event::End(b) => assert_eq!(element_name(&b.local_name()), "bind"));
assert_matches!(s.next_xml_event().await?, Event::End(b) => assert_eq!(element_name(&b.local_name()), "iq"));
s.send(r#"<presence xmlns="jabber:client" type="unavailable"><status>Logged out</status></presence>"#).await?;
stream.shutdown().await?;
// wrap up
server.server.terminate().await?;
server.shutdown().await?;
Ok(())
}
#[tokio::test]
async fn scenario_wrong_password() -> Result<()> {
let server = TestServer::start().await?;
// test scenario
server.core.create_player(&PlayerId::from("tester")?).await?;
server.core.set_password("tester", "password").await?;
let mut stream = TcpStream::connect(server.server.addr).await?;
let mut s = TestScope::new(&mut stream);
tracing::info!("TCP connection established");
s.send(r#"<?xml version="1.0"?>"#).await?;
s.send(r#"<stream:stream xmlns:stream="http://etherx.jabber.org/streams" to="127.0.0.1" xml:lang="en" xmlns:xml="http://www.w3.org/XML/1998/namespace" xmlns="jabber:client" version="1.0">"#).await?;
assert_matches!(s.next_xml_event().await?, Event::Decl(_) => {});
assert_matches!(s.next_xml_event().await?, Event::Start(b) => assert_eq!(element_name(&b.local_name()), "stream"));
s.expect_starttls_required().await?;
s.send(r#"<starttls/>"#).await?;
assert_matches!(s.next_xml_event().await?, Event::Empty(b) => assert_eq!(element_name(&b.local_name()), "proceed"));
let buffer = s.buffer;
tracing::info!("TLS feature negotiation complete");
let connector = TlsConnector::from(Arc::new(
ClientConfig::builder()
.with_safe_defaults()
.with_custom_certificate_verifier(Arc::new(IgnoreCertVerification))
.with_no_client_auth(),
));
tracing::info!("Initiating TLS connection...");
let mut stream = connector.connect(ServerName::IpAddress(server.server.addr.ip()), stream).await?;
tracing::info!("TLS connection established");
let mut s = TestScopeTls::new(&mut stream, buffer);
s.send(r#"<?xml version="1.0"?>"#).await?;
s.send(r#"<stream:stream xmlns:stream="http://etherx.jabber.org/streams" to="127.0.0.1" xml:lang="en" xmlns:xml="http://www.w3.org/XML/1998/namespace" xmlns="jabber:client" version="1.0">"#).await?;
assert_matches!(s.next_xml_event().await?, Event::Decl(_) => {});
assert_matches!(s.next_xml_event().await?, Event::Start(b) => assert_eq!(element_name(&b.local_name()), "stream"));
s.expect_auth_mechanisms().await?;
// base64-encoded "\x00tester\x00password2"
s.send(r#"<auth xmlns="urn:ietf:params:xml:ns:xmpp-sasl" mechanism="PLAIN">AHRlc3RlcgBwYXNzd29yZDI=</auth>"#)
.await?;
assert_matches!(s.next_xml_event().await?, Event::Start(b) => assert_eq!(element_name(&b.local_name()), "failure"));
assert_matches!(s.next_xml_event().await?, Event::Empty(b) => assert_eq!(element_name(&b.local_name()), "not-authorized"));
assert_matches!(s.next_xml_event().await?, Event::End(b) => assert_eq!(element_name(&b.local_name()), "failure"));
let _ = stream.shutdown().await;
// wrap up
server.shutdown().await?;
Ok(())
}
#[tokio::test]
async fn scenario_basic_without_headers() -> Result<()> {
let mut server = TestServer::start().await?;
let server = TestServer::start().await?;
// test scenario
server.storage.create_user("tester").await?;
server.storage.set_password("tester", "password").await?;
server.core.create_player(&PlayerId::from("tester")?).await?;
server.core.set_password("tester", "password").await?;
let mut stream = TcpStream::connect(server.server.addr).await?;
let mut s = TestScope::new(&mut stream);
@ -222,14 +318,10 @@ async fn scenario_basic_without_headers() -> Result<()> {
s.send(r#"<stream:stream xmlns:stream="http://etherx.jabber.org/streams" to="127.0.0.1" xml:lang="en" xmlns:xml="http://www.w3.org/XML/1998/namespace" xmlns="jabber:client" version="1.0">"#).await?;
assert_matches!(s.next_xml_event().await?, Event::Decl(_) => {});
assert_matches!(s.next_xml_event().await?, Event::Start(b) => assert_eq!(b.local_name().into_inner(), b"stream"));
assert_matches!(s.next_xml_event().await?, Event::Start(b) => assert_eq!(b.local_name().into_inner(), b"features"));
assert_matches!(s.next_xml_event().await?, Event::Start(b) => assert_eq!(b.local_name().into_inner(), b"starttls"));
assert_matches!(s.next_xml_event().await?, Event::Empty(b) => assert_eq!(b.local_name().into_inner(), b"required"));
assert_matches!(s.next_xml_event().await?, Event::End(b) => assert_eq!(b.local_name().into_inner(), b"starttls"));
assert_matches!(s.next_xml_event().await?, Event::End(b) => assert_eq!(b.local_name().into_inner(), b"features"));
assert_matches!(s.next_xml_event().await?, Event::Start(b) => assert_eq!(element_name(&b.local_name()), "stream"));
s.expect_starttls_required().await?;
s.send(r#"<starttls/>"#).await?;
assert_matches!(s.next_xml_event().await?, Event::Empty(b) => assert_eq!(b.local_name().into_inner(), b"proceed"));
assert_matches!(s.next_xml_event().await?, Event::Empty(b) => assert_eq!(element_name(&b.local_name()), "proceed"));
let buffer = s.buffer;
tracing::info!("TLS feature negotiation complete");
@ -247,24 +339,24 @@ async fn scenario_basic_without_headers() -> Result<()> {
s.send(r#"<stream:stream xmlns:stream="http://etherx.jabber.org/streams" to="127.0.0.1" xml:lang="en" xmlns:xml="http://www.w3.org/XML/1998/namespace" xmlns="jabber:client" version="1.0">"#).await?;
assert_matches!(s.next_xml_event().await?, Event::Decl(_) => {});
assert_matches!(s.next_xml_event().await?, Event::Start(b) => assert_eq!(b.local_name().into_inner(), b"stream"));
assert_matches!(s.next_xml_event().await?, Event::Start(b) => assert_eq!(element_name(&b.local_name()), "stream"));
stream.shutdown().await?;
// wrap up
server.server.terminate().await?;
server.shutdown().await?;
Ok(())
}
#[tokio::test]
async fn terminate_socket() -> Result<()> {
let mut server = TestServer::start().await?;
let server = TestServer::start().await?;
// test scenario
server.storage.create_user("tester").await?;
server.storage.set_password("tester", "password").await?;
server.core.create_player(&PlayerId::from("tester")?).await?;
server.core.set_password("tester", "password").await?;
let mut stream = TcpStream::connect(server.server.addr).await?;
let mut s = TestScope::new(&mut stream);
@ -274,14 +366,10 @@ async fn terminate_socket() -> Result<()> {
s.send(r#"<stream:stream xmlns:stream="http://etherx.jabber.org/streams" to="127.0.0.1" xml:lang="en" xmlns:xml="http://www.w3.org/XML/1998/namespace" xmlns="jabber:client" version="1.0">"#).await?;
assert_matches!(s.next_xml_event().await?, Event::Decl(_) => {});
assert_matches!(s.next_xml_event().await?, Event::Start(b) => assert_eq!(b.local_name().into_inner(), b"stream"));
assert_matches!(s.next_xml_event().await?, Event::Start(b) => assert_eq!(b.local_name().into_inner(), b"features"));
assert_matches!(s.next_xml_event().await?, Event::Start(b) => assert_eq!(b.local_name().into_inner(), b"starttls"));
assert_matches!(s.next_xml_event().await?, Event::Empty(b) => assert_eq!(b.local_name().into_inner(), b"required"));
assert_matches!(s.next_xml_event().await?, Event::End(b) => assert_eq!(b.local_name().into_inner(), b"starttls"));
assert_matches!(s.next_xml_event().await?, Event::End(b) => assert_eq!(b.local_name().into_inner(), b"features"));
assert_matches!(s.next_xml_event().await?, Event::Start(b) => assert_eq!(element_name(&b.local_name()), "stream"));
s.expect_starttls_required().await?;
s.send(r#"<starttls/>"#).await?;
assert_matches!(s.next_xml_event().await?, Event::Empty(b) => assert_eq!(b.local_name().into_inner(), b"proceed"));
assert_matches!(s.next_xml_event().await?, Event::Empty(b) => assert_eq!(element_name(&b.local_name()), "proceed"));
let connector = TlsConnector::from(Arc::new(
ClientConfig::builder()
@ -294,9 +382,95 @@ async fn terminate_socket() -> Result<()> {
let mut stream = connector.connect(ServerName::IpAddress(server.server.addr.ip()), stream).await?;
tracing::info!("TLS connection established");
server.server.terminate().await?;
server.shutdown().await?;
assert_eq!(stream.read_u8().await.unwrap_err().kind(), ErrorKind::UnexpectedEof);
Ok(())
}
#[tokio::test]
async fn test_message_archive_request() -> Result<()> {
let server = TestServer::start().await?;
// test scenario
server.core.create_player(&PlayerId::from("tester")?).await?;
server.core.set_password("tester", "password").await?;
let mut stream = TcpStream::connect(server.server.addr).await?;
let mut s = TestScope::new(&mut stream);
tracing::info!("TCP connection established");
s.send(r#"<?xml version="1.0"?>"#).await?;
s.send(r#"<stream:stream xmlns:stream="http://etherx.jabber.org/streams" to="127.0.0.1" xml:lang="en" xmlns:xml="http://www.w3.org/XML/1998/namespace" xmlns="jabber:client" version="1.0">"#).await?;
assert_matches!(s.next_xml_event().await?, Event::Decl(_) => {});
assert_matches!(s.next_xml_event().await?, Event::Start(b) => assert_eq!(element_name(&b.local_name()), "stream"));
s.expect_starttls_required().await?;
s.send(r#"<starttls/>"#).await?;
assert_matches!(s.next_xml_event().await?, Event::Empty(b) => assert_eq!(element_name(&b.local_name()), "proceed"));
let buffer = s.buffer;
tracing::info!("TLS feature negotiation complete");
let connector = TlsConnector::from(Arc::new(
ClientConfig::builder()
.with_safe_defaults()
.with_custom_certificate_verifier(Arc::new(IgnoreCertVerification))
.with_no_client_auth(),
));
tracing::info!("Initiating TLS connection...");
let mut stream = connector.connect(ServerName::IpAddress(server.server.addr.ip()), stream).await?;
tracing::info!("TLS connection established");
let mut s = TestScopeTls::new(&mut stream, buffer);
s.send(r#"<?xml version="1.0"?>"#).await?;
s.send(r#"<stream:stream xmlns:stream="http://etherx.jabber.org/streams" to="127.0.0.1" xml:lang="en" xmlns:xml="http://www.w3.org/XML/1998/namespace" xmlns="jabber:client" version="1.0">"#).await?;
assert_matches!(s.next_xml_event().await?, Event::Decl(_) => {});
assert_matches!(s.next_xml_event().await?, Event::Start(b) => assert_eq!(element_name(&b.local_name()), "stream"));
s.expect_auth_mechanisms().await?;
// base64-encoded "\x00tester\x00password"
s.send(r#"<auth xmlns="urn:ietf:params:xml:ns:xmpp-sasl" mechanism="PLAIN">AHRlc3RlcgBwYXNzd29yZA==</auth>"#)
.await?;
assert_matches!(s.next_xml_event().await?, Event::Empty(b) => assert_eq!(element_name(&b.local_name()), "success"));
s.send(r#"<?xml version="1.0"?>"#).await?;
s.send(r#"<stream:stream xmlns:stream="http://etherx.jabber.org/streams" to="127.0.0.1" xml:lang="en" xmlns:xml="http://www.w3.org/XML/1998/namespace" xmlns="jabber:client" version="1.0">"#).await?;
assert_matches!(s.next_xml_event().await?, Event::Decl(_) => {});
assert_matches!(s.next_xml_event().await?, Event::Start(b) => assert_eq!(element_name(&b.local_name()), "stream"));
s.expect_bind_feature().await?;
s.send(r#"<iq id="bind_1" type="set"><bind xmlns="urn:ietf:params:xml:ns:xmpp-bind"><resource>kek</resource></bind></iq>"#).await?;
assert_matches!(s.next_xml_event().await?, Event::Start(b) => assert_eq!(element_name(&b.local_name()), "iq"));
assert_matches!(s.next_xml_event().await?, Event::Start(b) => assert_eq!(element_name(&b.local_name()), "bind"));
assert_matches!(s.next_xml_event().await?, Event::Start(b) => assert_eq!(element_name(&b.local_name()), "jid"));
assert_matches!(s.next_xml_event().await?, Event::Text(b) => assert_eq!(&*b, b"tester@localhost/tester"));
assert_matches!(s.next_xml_event().await?, Event::End(b) => assert_eq!(element_name(&b.local_name()), "jid"));
assert_matches!(s.next_xml_event().await?, Event::End(b) => assert_eq!(element_name(&b.local_name()), "bind"));
assert_matches!(s.next_xml_event().await?, Event::End(b) => assert_eq!(element_name(&b.local_name()), "iq"));
s.send(r#"<iq type='get' id='juliet1'><query xmlns='urn:xmpp:mam:2' queryid='f27'/></iq>"#).await?;
assert_matches!(s.next_xml_event().await?, Event::Start(b) => {
assert_eq!(element_name(&b.local_name()), "iq")
});
assert_matches!(s.next_xml_event().await?, Event::Start(b) => {
assert_eq!(element_name(&b.local_name()), "fin")
});
assert_matches!(s.next_xml_event().await?, Event::Start(b) => {
assert_eq!(element_name(&b.local_name()), "set")
});
assert_matches!(s.next_xml_event().await?, Event::Start(b) => {
assert_eq!(element_name(&b.local_name()), "count")
});
assert_matches!(s.next_xml_event().await?, Event::Text(b) => {
assert_eq!(&*b, b"0")
});
s.send(r#"<presence xmlns="jabber:client" type="unavailable"><status>Logged out</status></presence>"#).await?;
stream.shutdown().await?;
// wrap up
server.shutdown().await?;
Ok(())
}

View File

@ -1,9 +1,9 @@
use super::*;
use anyhow::{anyhow, Result};
use nom::combinator::{all_consuming, opt};
use nonempty::NonEmpty;
use super::*;
/// Client-to-server command.
#[derive(Clone, Debug, PartialEq, Eq)]
pub enum ClientMessage {
@ -42,6 +42,10 @@ pub enum ClientMessage {
Who {
target: Recipient, // aka mask
},
/// WHOIS [<target>] <nick>
Whois {
arg: command_args::Whois,
},
/// `TOPIC <chan> :<topic>`
Topic {
chan: Chan,
@ -63,6 +67,17 @@ pub enum ClientMessage {
Authenticate(Str),
}
pub mod command_args {
use crate::prelude::Str;
#[derive(Clone, Debug, PartialEq, Eq)]
pub enum Whois {
Nick(Str),
TargetNick(Str, Str),
EmptyArgs,
}
}
pub fn client_message(input: &str) -> Result<ClientMessage> {
let res = all_consuming(alt((
client_message_capability,
@ -74,6 +89,7 @@ pub fn client_message(input: &str) -> Result<ClientMessage> {
client_message_join,
client_message_mode,
client_message_who,
client_message_whois,
client_message_topic,
client_message_part,
client_message_privmsg,
@ -177,6 +193,31 @@ fn client_message_who(input: &str) -> IResult<&str, ClientMessage> {
Ok((input, ClientMessage::Who { target }))
}
fn client_message_whois(input: &str) -> IResult<&str, ClientMessage> {
let (input, _) = tag("WHOIS ")(input)?;
let args: Vec<_> = input.split_whitespace().collect();
match args.as_slice()[..] {
[nick] => Ok((
"",
ClientMessage::Whois {
arg: command_args::Whois::Nick(nick.into()),
},
)),
[target, nick, ..] => Ok((
"",
ClientMessage::Whois {
arg: command_args::Whois::TargetNick(target.into(), nick.into()),
},
)),
[] => Ok((
"",
ClientMessage::Whois {
arg: command_args::Whois::EmptyArgs,
},
)),
}
}
fn client_message_topic(input: &str) -> IResult<&str, ClientMessage> {
let (input, _) = tag("TOPIC ")(input)?;
let (input, chan) = chan(input)?;
@ -311,6 +352,7 @@ mod test {
use nonempty::nonempty;
use super::*;
#[test]
fn test_client_message_cap_ls() {
let input = "CAP LS 302";
@ -360,6 +402,66 @@ mod test {
assert_matches!(result, Ok(result) => assert_eq!(expected, result));
}
#[test]
fn test_client_message_whois() {
let test_user = "WHOIS val";
let test_user_user = "WHOIS val val";
let test_server_user = "WHOIS com.test.server user";
let test_user_server = "WHOIS user com.test.server";
let test_users_list = "WHOIS user_1,user_2,user_3";
let test_server_users_list = "WHOIS com.test.server user_1,user_2,user_3";
let test_more_than_two_params = "WHOIS test.server user_1,user_2,user_3 whatever spam";
let test_none_none_params = "WHOIS ";
let res_one_arg = client_message(test_user);
let res_user_user = client_message(test_user_user);
let res_server_user = client_message(test_server_user);
let res_user_server = client_message(test_user_server);
let res_users_list = client_message(test_users_list);
let res_server_users_list = client_message(test_server_users_list);
let res_more_than_two_params = client_message(test_more_than_two_params);
let res_none_none_params = client_message(test_none_none_params);
let expected_arg = ClientMessage::Whois {
arg: command_args::Whois::Nick("val".into()),
};
let expected_user_user = ClientMessage::Whois {
arg: command_args::Whois::TargetNick("val".into(), "val".into()),
};
let expected_server_user = ClientMessage::Whois {
arg: command_args::Whois::TargetNick("com.test.server".into(), "user".into()),
};
let expected_user_server = ClientMessage::Whois {
arg: command_args::Whois::TargetNick("user".into(), "com.test.server".into()),
};
let expected_user_list = ClientMessage::Whois {
arg: command_args::Whois::Nick("user_1,user_2,user_3".into()),
};
let expected_server_user_list = ClientMessage::Whois {
arg: command_args::Whois::TargetNick("com.test.server".into(), "user_1,user_2,user_3".into()),
};
let expected_more_than_two_params = ClientMessage::Whois {
arg: command_args::Whois::TargetNick("test.server".into(), "user_1,user_2,user_3".into()),
};
let expected_none_none_params = ClientMessage::Whois {
arg: command_args::Whois::EmptyArgs,
};
assert_matches!(res_one_arg, Ok(result) => assert_eq!(expected_arg, result));
assert_matches!(res_user_user, Ok(result) => assert_eq!(expected_user_user, result));
assert_matches!(res_server_user, Ok(result) => assert_eq!(expected_server_user, result));
assert_matches!(res_user_server, Ok(result) => assert_eq!(expected_user_server, result));
assert_matches!(res_users_list, Ok(result) => assert_eq!(expected_user_list, result));
assert_matches!(res_server_users_list, Ok(result) => assert_eq!(expected_server_user_list, result));
assert_matches!(res_more_than_two_params, Ok(result) => assert_eq!(expected_more_than_two_params, result));
assert_matches!(res_none_none_params, Ok(result) => assert_eq!(expected_none_none_params, result))
}
#[test]
fn test_client_message_user() {
let input = "USER SomeNick 8 * :Real Name";
let expected = ClientMessage::User {

View File

@ -0,0 +1 @@
pub mod whois;

View File

@ -0,0 +1,67 @@
use tokio::io::{AsyncWrite, AsyncWriteExt};
use crate::{prelude::Str, response::WriteResponse};
/// ErrNoSuchNick401
pub struct ErrNoSuchNick401 {
client: Str,
nick: Str,
}
impl ErrNoSuchNick401 {
pub fn new(client: Str, nick: Str) -> Self {
ErrNoSuchNick401 { client, nick }
}
}
/// ErrNoSuchServer402
struct ErrNoSuchServer402 {
client: Str,
/// target parameter in WHOIS
/// example: `/whois <target> <nick>`
server_name: Str,
}
/// ErrNoNicknameGiven431
pub struct ErrNoNicknameGiven431 {
client: Str,
}
impl ErrNoNicknameGiven431 {
pub fn new(client: Str) -> Self {
ErrNoNicknameGiven431 { client }
}
}
impl WriteResponse for ErrNoSuchNick401 {
async fn write_response(&self, writer: &mut (impl AsyncWrite + Unpin)) -> std::io::Result<()> {
writer.write_all(b"401 ").await?;
writer.write_all(self.client.as_bytes()).await?;
writer.write_all(b" ").await?;
writer.write_all(self.nick.as_bytes()).await?;
writer.write_all(b" :").await?;
writer.write_all("No such nick/channel".as_bytes()).await?;
Ok(())
}
}
impl WriteResponse for ErrNoNicknameGiven431 {
async fn write_response(&self, writer: &mut (impl AsyncWrite + Unpin)) -> std::io::Result<()> {
writer.write_all(b"431").await?;
writer.write_all(self.client.as_bytes()).await?;
writer.write_all(b" :").await?;
writer.write_all("No nickname given".as_bytes()).await?;
Ok(())
}
}
impl WriteResponse for ErrNoSuchServer402 {
async fn write_response(&self, writer: &mut (impl AsyncWrite + Unpin)) -> std::io::Result<()> {
writer.write_all(b"402 ").await?;
writer.write_all(self.client.as_bytes()).await?;
writer.write_all(b" ").await?;
writer.write_all(self.server_name.as_bytes()).await?;
writer.write_all(b" :").await?;
writer.write_all("No such server".as_bytes()).await?;
Ok(())
}
}

View File

@ -0,0 +1,2 @@
pub mod error;
pub mod response;

View File

@ -0,0 +1,24 @@
use tokio::io::{AsyncWrite, AsyncWriteExt};
use crate::{prelude::Str, response::WriteResponse};
pub struct RplEndOfWhois318 {
client: Str,
nick: Str,
}
impl RplEndOfWhois318 {
pub fn new(client: Str, nick: Str) -> Self {
RplEndOfWhois318 { client, nick }
}
}
impl WriteResponse for RplEndOfWhois318 {
async fn write_response(&self, writer: &mut (impl AsyncWrite + Unpin)) -> std::io::Result<()> {
writer.write_all(b"318 ").await?;
writer.write_all(self.client.as_bytes()).await?;
writer.write_all(b" ").await?;
writer.write_all(self.nick.as_bytes()).await?;
writer.write_all(b" :").await?;
writer.write_all("End of /WHOIS list".as_bytes()).await?;
Ok(())
}
}

View File

@ -1,6 +1,8 @@
//! Client-to-Server IRC protocol.
pub mod client;
pub mod commands;
mod prelude;
pub mod response;
pub mod server;
#[cfg(test)]
mod testkit;
@ -18,8 +20,19 @@ use tokio::io::{AsyncWrite, AsyncWriteExt};
/// Single message tag value.
#[derive(Clone, Debug, PartialEq, Eq)]
pub struct Tag {
key: Str,
value: Option<u8>,
pub key: Str,
pub value: Option<Str>,
}
impl Tag {
pub async fn write_async(&self, writer: &mut (impl AsyncWrite + Unpin)) -> std::io::Result<()> {
writer.write_all(self.key.as_bytes()).await?;
if let Some(value) = &self.value {
writer.write_all(b"=").await?;
writer.write_all(value.as_bytes()).await?;
}
Ok(())
}
}
fn receiver(input: &str) -> IResult<&str, &str> {

View File

@ -0,0 +1,47 @@
use std::future::Future;
use tokio::io::{AsyncWrite, AsyncWriteExt};
use crate::prelude::Str;
use crate::Tag;
pub trait WriteResponse {
fn write_response(&self, writer: &mut (impl AsyncWrite + Unpin)) -> impl Future<Output = std::io::Result<()>>;
}
/// Server-to-client enum agnostic message
#[derive(Clone, Debug, PartialEq, Eq)]
pub struct IrcResponseMessage<T> {
/// Optional tags section, prefixed with `@`
pub tags: Vec<Tag>,
/// Optional server name, prefixed with `:`.
pub sender: Option<Str>,
pub body: T,
}
impl<T> IrcResponseMessage<T> {
pub fn empty_tags(sender: Option<Str>, body: T) -> Self {
IrcResponseMessage {
tags: vec![],
sender,
body,
}
}
pub fn new(tags: Vec<Tag>, sender: Option<Str>, body: T) -> Self {
IrcResponseMessage { tags, sender, body }
}
}
impl<T: WriteResponse> WriteResponse for IrcResponseMessage<T> {
async fn write_response(&self, writer: &mut (impl AsyncWrite + Unpin)) -> std::io::Result<()> {
if let Some(sender) = &self.sender {
writer.write_all(b":").await?;
writer.write_all(sender.as_bytes()).await?;
writer.write_all(b" ").await?;
}
self.body.write_response(writer).await?;
writer.write_all(b"\r\n").await?;
Ok(())
}
}

View File

@ -1,12 +1,11 @@
use std::sync::Arc;
use nonempty::NonEmpty;
use tokio::io::AsyncWrite;
use tokio::io::AsyncWriteExt;
use super::*;
use crate::user::PrefixedNick;
use super::*;
/// Server-to-client message.
#[derive(Clone, Debug, PartialEq, Eq)]
pub struct ServerMessage {
@ -19,6 +18,13 @@ pub struct ServerMessage {
impl ServerMessage {
pub async fn write_async(&self, writer: &mut (impl AsyncWrite + Unpin)) -> std::io::Result<()> {
if !self.tags.is_empty() {
for tag in &self.tags {
writer.write_all(b"@").await?;
tag.write_async(writer).await?;
writer.write_all(b" ").await?;
}
}
match &self.sender {
Some(ref sender) => {
writer.write_all(b":").await?;
@ -107,6 +113,12 @@ pub enum ServerMessageBody {
/// Usually `b"End of WHO list"`
msg: Str,
},
N318EndOfWhois {
client: Str,
nick: Str,
/// Usually `b"End of /WHOIS list"`
msg: Str,
},
N332Topic {
client: Str,
chat: Chan,
@ -136,6 +148,10 @@ pub enum ServerMessageBody {
client: Str,
chan: Chan,
},
N431ErrNoNicknameGiven {
client: Str,
message: Str,
},
N474BannedFromChan {
client: Str,
chan: Chan,
@ -273,6 +289,14 @@ impl ServerMessageBody {
writer.write_all(b" :").await?;
writer.write_all(msg.as_bytes()).await?;
}
ServerMessageBody::N318EndOfWhois { client, nick, msg } => {
writer.write_all(b"318 ").await?;
writer.write_all(client.as_bytes()).await?;
writer.write_all(b" ").await?;
writer.write_all(nick.as_bytes()).await?;
writer.write_all(b" :").await?;
writer.write_all(msg.as_bytes()).await?;
}
ServerMessageBody::N332Topic { client, chat, topic } => {
writer.write_all(b"332 ").await?;
writer.write_all(client.as_bytes()).await?;
@ -335,6 +359,12 @@ impl ServerMessageBody {
chan.write_async(writer).await?;
writer.write_all(b" :End of /NAMES list").await?;
}
ServerMessageBody::N431ErrNoNicknameGiven { client, message } => {
writer.write_all(b"431").await?;
writer.write_all(client.as_bytes()).await?;
writer.write_all(b" :").await?;
writer.write_all(message.as_bytes()).await?;
}
ServerMessageBody::N474BannedFromChan { client, chan, message } => {
writer.write_all(b"474 ").await?;
writer.write_all(client.as_bytes()).await?;
@ -463,9 +493,10 @@ fn server_message_body_cap(input: &str) -> IResult<&str, ServerMessageBody> {
mod test {
use assert_matches::*;
use super::*;
use crate::testkit::*;
use super::*;
#[test]
fn test_server_message_notice() {
let input = "NOTICE * :*** Looking up your hostname...\r\n";

View File

@ -74,15 +74,15 @@ impl Jid {
pub struct BindRequest(pub Resource);
impl FromXmlTag for BindRequest {
const NS: &'static str = XMLNS;
const NAME: &'static str = "bind";
const NS: &'static str = XMLNS;
}
impl FromXml for BindRequest {
type P = impl Parser<Output = Result<Self>>;
fn parse() -> Self::P {
|(namespace, event): (ResolveResult<'static>, &'static Event<'static>)| -> Result<Self> {
|(mut namespace, mut event): (ResolveResult<'static>, &'static Event<'static>)| -> Result<Self> {
let mut resource: Option<Str> = None;
let Event::Start(bytes) = event else {
return Err(anyhow!("Unexpected XML event: {event:?}"));
@ -97,15 +97,15 @@ impl FromXml for BindRequest {
return Err(anyhow!("Incorrect namespace"));
}
loop {
let (namespace, event) = yield;
(namespace, event) = yield;
match event {
Event::Start(bytes) if bytes.name().0 == b"resource" => {
let (namespace, event) = yield;
(namespace, event) = yield;
let Event::Text(text) = event else {
return Err(anyhow!("Unexpected XML event: {event:?}"));
};
resource = Some(std::str::from_utf8(&*text)?.into());
let (namespace, event) = yield;
(namespace, event) = yield;
let Event::End(bytes) = event else {
return Err(anyhow!("Unexpected XML event: {event:?}"));
};
@ -127,12 +127,16 @@ impl FromXml for BindRequest {
}
}
#[derive(PartialEq, Eq, Debug)]
pub struct BindResponse(pub Jid);
impl ToXml for BindResponse {
fn serialize(&self, events: &mut Vec<Event<'static>>) {
events.extend_from_slice(&[
Event::Start(BytesStart::new(r#"bind xmlns="urn:ietf:params:xml:ns:xmpp-bind""#)),
Event::Start(BytesStart::from_content(
r#"bind xmlns="urn:ietf:params:xml:ns:xmpp-bind""#,
4,
)),
Event::Start(BytesStart::new(r#"jid"#)),
Event::Text(BytesText::new(self.0.to_string().as_str()).into_owned()),
Event::End(BytesEnd::new("jid")),

View File

@ -260,6 +260,18 @@ impl MessageType {
/// https://xmpp.org/rfcs/rfc6120.html#stanzas-error
pub struct IqError {
pub r#type: IqErrorType,
pub condition: Option<IqErrorCondition>,
}
pub enum IqErrorCondition {
ItemNotFound,
}
impl IqErrorCondition {
pub fn as_str(&self) -> &'static str {
match self {
IqErrorCondition::ItemNotFound => "item-not-found",
}
}
}
pub enum IqErrorType {
@ -289,8 +301,19 @@ impl IqErrorType {
impl ToXml for IqError {
fn serialize(&self, events: &mut Vec<Event<'static>>) {
let bytes = BytesStart::new(format!(r#"error xmlns="{}" type="{}""#, XMLNS, self.r#type.as_str()));
match self.condition {
None => {
events.push(Event::Empty(bytes));
}
Some(IqErrorCondition::ItemNotFound) => {
events.push(Event::Start(bytes));
let bytes2 = BytesStart::new(r#"item-not-found xmlns="urn:ietf:params:xml:ns:xmpp-stanzas""#);
events.push(Event::Empty(bytes2));
let bytes = BytesEnd::new("error");
events.push(Event::End(bytes));
}
}
}
}
#[derive(PartialEq, Eq, Debug)]
@ -378,7 +401,7 @@ impl<T: FromXml> Parser for IqParser<T> {
}
},
IqParserInner::Final(state) => {
if let Event::End(ref bytes) = event {
if let Event::End(_) = event {
let id = fail_fast!(state.id.ok_or_else(|| ffail!("No id provided")));
let r#type = fail_fast!(state.r#type.ok_or_else(|| ffail!("No type provided")));
let body = fail_fast!(state.body.ok_or_else(|| ffail!("No body provided")));
@ -464,6 +487,7 @@ impl<T: ToXml> ToXml for Iq<T> {
#[derive(PartialEq, Eq, Debug)]
pub struct Presence<T> {
pub id: Option<String>,
pub to: Option<Jid>,
pub from: Option<Jid>,
pub priority: Option<PresencePriority>,
@ -476,6 +500,7 @@ pub struct Presence<T> {
impl<T> Default for Presence<T> {
fn default() -> Self {
Self {
id: Default::default(),
to: Default::default(),
from: Default::default(),
priority: Default::default(),
@ -528,7 +553,7 @@ impl<T: FromXml> FromXml for Presence<T> {
type P = impl Parser<Output = Result<Presence<T>>>;
fn parse() -> Self::P {
|(namespace, event): (ResolveResult<'static>, &'static Event<'static>)| -> Result<Self> {
|(mut namespace, mut event): (ResolveResult<'static>, &'static Event<'static>)| -> Result<Self> {
let (bytes, end) = match event {
Event::Start(bytes) => (bytes, false),
Event::Empty(bytes) => (bytes, true),
@ -550,6 +575,10 @@ impl<T: FromXml> FromXml for Presence<T> {
let s = std::str::from_utf8(&attr.value)?;
p.r#type = Some(s.into());
}
b"id" => {
let s = std::str::from_utf8(&attr.value)?;
p.r#type = Option::from(s.to_string());
}
_ => {}
}
}
@ -557,37 +586,37 @@ impl<T: FromXml> FromXml for Presence<T> {
return Ok(p);
}
loop {
let (namespace, event) = yield;
(namespace, event) = yield;
match event {
Event::Start(bytes) => match bytes.name().0 {
b"show" => {
let (_, event) = yield;
(namespace, event) = yield;
let Event::Text(bytes) = event else {
return Err(ffail!("Unexpected XML event: {event:?}"));
};
let i = PresenceShow::from_str(bytes)?;
p.show = Some(i);
let (_, event) = yield;
(namespace, event) = yield;
let Event::End(_) = event else {
return Err(ffail!("Unexpected XML event: {event:?}"));
};
}
b"status" => {
let (_, event) = yield;
(namespace, event) = yield;
let Event::Text(bytes) = event else {
return Err(ffail!("Unexpected XML event: {event:?}"));
};
let s = std::str::from_utf8(bytes)?;
p.status.push(s.to_string());
let (_, event) = yield;
(namespace, event) = yield;
let Event::End(_) = event else {
return Err(ffail!("Unexpected XML event: {event:?}"));
};
}
b"priority" => {
let (_, event) = yield;
(namespace, event) = yield;
let Event::Text(bytes) = event else {
return Err(ffail!("Unexpected XML event: {event:?}"));
};
@ -595,7 +624,7 @@ impl<T: FromXml> FromXml for Presence<T> {
let i = s.parse()?;
p.priority = Some(PresencePriority(i));
let (_, event) = yield;
(namespace, event) = yield;
let Event::End(_) = event else {
return Err(ffail!("Unexpected XML event: {event:?}"));
};
@ -637,6 +666,12 @@ impl<T: ToXml> ToXml for Presence<T> {
value: from.to_string().as_bytes().into(),
}]);
}
if let Some(ref id) = self.id {
start.extend_attributes([Attribute {
key: QName(b"id"),
value: id.to_string().as_bytes().into(),
}]);
}
events.push(Event::Start(start));
if let Some(ref priority) = self.priority {
let s = priority.0.to_string();
@ -646,6 +681,9 @@ impl<T: ToXml> ToXml for Presence<T> {
Event::End(BytesEnd::new("priority")),
]);
}
for c in &self.custom {
c.serialize(events);
}
events.push(Event::End(BytesEnd::new("presence")));
}
}
@ -658,7 +696,7 @@ mod tests {
#[tokio::test]
async fn parse_message() {
let input = r#"<message id="aacea" type="chat" to="nikita@vlnv.dev"><subject>daa</subject><body>bbb</body><unknown-stuff></unknown-stuff></message>"#;
let input = r#"<message id="aacea" type="chat" to="chelik@xmpp.ru"><subject>daa</subject><body>bbb</body><unknown-stuff></unknown-stuff></message>"#;
let result: Message<Ignore> = crate::xml::parse(input).unwrap();
assert_eq!(
result,
@ -666,8 +704,8 @@ mod tests {
from: None,
id: Some("aacea".to_string()),
to: Some(Jid {
name: Some(Name("nikita".into())),
server: Server("vlnv.dev".into()),
name: Some(Name("chelik".into())),
server: Server("xmpp.ru".into()),
resource: None
}),
r#type: MessageType::Chat,
@ -681,7 +719,7 @@ mod tests {
#[tokio::test]
async fn parse_message_empty_custom() {
let input = r#"<message id="aacea" type="chat" to="nikita@vlnv.dev"><subject>daa</subject><body>bbb</body><unknown-stuff/></message>"#;
let input = r#"<message id="aacea" type="chat" to="chelik@xmpp.ru"><subject>daa</subject><body>bbb</body><unknown-stuff/></message>"#;
let result: Message<Ignore> = crate::xml::parse(input).unwrap();
assert_eq!(
result,
@ -689,8 +727,8 @@ mod tests {
from: None,
id: Some("aacea".to_string()),
to: Some(Jid {
name: Some(Name("nikita".into())),
server: Server("vlnv.dev".into()),
name: Some(Name("chelik".into())),
server: Server("xmpp.ru".into()),
resource: None
}),
r#type: MessageType::Chat,

View File

@ -21,7 +21,7 @@ impl FromXml for InfoQuery {
type P = impl Parser<Output = Result<Self>>;
fn parse() -> Self::P {
|(namespace, event): (ResolveResult<'static>, &'static Event<'static>)| -> Result<Self> {
|(mut namespace, mut event): (ResolveResult<'static>, &'static Event<'static>)| -> Result<Self> {
let mut node = None;
let mut identity = vec![];
let mut feature = vec![];
@ -48,7 +48,7 @@ impl FromXml for InfoQuery {
});
}
loop {
let (namespace, event) = yield;
(namespace, event) = yield;
let bytes = match event {
Event::Start(bytes) => bytes,
Event::Empty(bytes) => bytes,
@ -141,7 +141,7 @@ impl FromXml for Identity {
type P = impl Parser<Output = Result<Self>>;
fn parse() -> Self::P {
|(namespace, event): (ResolveResult<'static>, &'static Event<'static>)| -> Result<Self> {
|(mut namespace, mut event): (ResolveResult<'static>, &'static Event<'static>)| -> Result<Self> {
let mut category = None;
let mut name = None;
let mut r#type = None;
@ -179,8 +179,8 @@ impl FromXml for Identity {
return Ok(item);
}
let (namespace, event) = yield;
let Event::End(bytes) = event else {
(namespace, event) = yield;
let Event::End(_) = event else {
return Err(ffail!("Unexpected XML event: {event:?}"));
};
Ok(item)
@ -209,7 +209,7 @@ impl FromXml for Feature {
type P = impl Parser<Output = Result<Self>>;
fn parse() -> Self::P {
|(namespace, event): (ResolveResult<'static>, &'static Event<'static>)| -> Result<Self> {
|(mut namespace, mut event): (ResolveResult<'static>, &'static Event<'static>)| -> Result<Self> {
let mut var = None;
let (bytes, end) = match event {
Event::Start(bytes) => (bytes, false),
@ -234,8 +234,8 @@ impl FromXml for Feature {
return Ok(item);
}
let (namespace, event) = yield;
let Event::End(bytes) = event else {
(namespace, event) = yield;
let Event::End(_) = event else {
return Err(ffail!("Unexpected XML event: {event:?}"));
};
Ok(item)
@ -258,9 +258,9 @@ impl FromXml for ItemQuery {
type P = impl Parser<Output = Result<Self>>;
fn parse() -> Self::P {
|(namespace, event): (ResolveResult<'static>, &'static Event<'static>)| -> Result<Self> {
|(mut namespace, mut event): (ResolveResult<'static>, &'static Event<'static>)| -> Result<Self> {
let mut item = vec![];
let (bytes, end) = match event {
let (_, end) = match event {
Event::Start(bytes) => (bytes, false),
Event::Empty(bytes) => (bytes, true),
_ => return Err(ffail!("Unexpected XML event: {event:?}")),
@ -269,7 +269,7 @@ impl FromXml for ItemQuery {
return Ok(ItemQuery { item });
}
loop {
let (namespace, event) = yield;
(namespace, event) = yield;
let bytes = match event {
Event::Start(bytes) => bytes,
Event::Empty(bytes) => bytes,
@ -296,7 +296,7 @@ impl FromXmlTag for ItemQuery {
impl ToXml for ItemQuery {
fn serialize(&self, events: &mut Vec<Event<'static>>) {
let mut bytes = BytesStart::new(format!(r#"query xmlns="{}""#, XMLNS_ITEM));
let bytes = BytesStart::new(format!(r#"query xmlns="{}""#, XMLNS_ITEM));
let empty = self.item.is_empty();
if empty {
events.push(Event::Empty(bytes));
@ -342,7 +342,7 @@ impl FromXml for Item {
type P = impl Parser<Output = Result<Self>>;
fn parse() -> Self::P {
|(namespace, event): (ResolveResult<'static>, &'static Event<'static>)| -> Result<Self> {
|(_, mut event): (ResolveResult<'static>, &'static Event<'static>)| -> Result<Self> {
let mut jid = None;
let mut name = None;
let mut node = None;
@ -378,8 +378,8 @@ impl FromXml for Item {
return Ok(item);
}
let (namespace, event) = yield;
let Event::End(bytes) = event else {
(_, event) = yield;
let Event::End(_) = event else {
return Err(ffail!("Unexpected XML event: {event:?}"));
};
Ok(item)

View File

@ -3,6 +3,7 @@
pub mod bind;
pub mod client;
pub mod disco;
pub mod mam;
pub mod muc;
mod prelude;
pub mod roster;
@ -10,6 +11,7 @@ pub mod sasl;
pub mod session;
pub mod stanzaerror;
pub mod stream;
pub mod streamerror;
pub mod tls;
pub mod xml;

View File

@ -0,0 +1,224 @@
use anyhow::{anyhow, Result};
use quick_xml::events::{BytesEnd, BytesStart, BytesText, Event};
use quick_xml::name::{Namespace, ResolveResult};
use crate::xml::*;
pub const MAM_XMLNS: &'static str = "urn:xmpp:mam:2";
pub const DATA_XMLNS: &'static str = "jabber:x:data";
pub const RESULT_SET_XMLNS: &'static str = "http://jabber.org/protocol/rsm";
#[derive(PartialEq, Eq, Debug, Clone)]
pub struct MessageArchiveRequest {
pub x: Option<X>,
}
#[derive(PartialEq, Eq, Debug, Clone)]
pub struct X {
pub fields: Vec<Field>,
}
#[derive(PartialEq, Eq, Debug, Clone)]
pub struct Field {
pub values: Vec<String>,
}
// Message archive response styled as a result set.
#[derive(PartialEq, Eq, Debug, Clone)]
pub struct Fin {
pub set: Set,
}
#[derive(PartialEq, Eq, Debug, Clone)]
pub struct Set {
pub count: Option<i32>,
}
impl ToXml for Fin {
fn serialize(&self, events: &mut Vec<Event<'static>>) {
let fin_bytes = BytesStart::new(format!(r#"fin xmlns="{}" complete=True"#, MAM_XMLNS));
let set_bytes = BytesStart::new(format!(r#"set xmlns="{}""#, RESULT_SET_XMLNS));
events.push(Event::Start(fin_bytes));
events.push(Event::Start(set_bytes));
if let &Some(count) = &self.set.count {
events.push(Event::Start(BytesStart::new("count")));
events.push(Event::Text(BytesText::new(count.to_string().as_str()).into_owned()));
events.push(Event::End(BytesEnd::new("count")));
}
events.push(Event::End(BytesEnd::new("set")));
events.push(Event::End(BytesEnd::new("fin")));
}
}
impl FromXmlTag for X {
const NAME: &'static str = "x";
const NS: &'static str = DATA_XMLNS;
}
impl FromXmlTag for MessageArchiveRequest {
const NAME: &'static str = "query";
const NS: &'static str = MAM_XMLNS;
}
impl FromXml for X {
type P = impl Parser<Output = Result<Self>>;
fn parse() -> Self::P {
|(mut namespace, mut event): (ResolveResult<'static>, &'static Event<'static>)| -> Result<Self> {
println!("X::parse {:?}", event);
let bytes = match event {
Event::Start(bytes) if bytes.name().0 == X::NAME.as_bytes() => bytes,
Event::Empty(bytes) if bytes.name().0 == X::NAME.as_bytes() => return Ok(X { fields: vec![] }),
_ => return Err(anyhow!("Unexpected XML event: {event:?}")),
};
let mut fields = vec![];
loop {
(namespace, event) = yield;
match event {
Event::Start(_) => {
// start of <field>
let mut values = vec![];
loop {
(namespace, event) = yield;
match event {
Event::Start(bytes) if bytes.name().0 == b"value" => {
// start of <value>
}
Event::End(bytes) if bytes.name().0 == b"field" => {
// end of </field>
break;
}
_ => return Err(anyhow!("Unexpected XML event: {event:?}")),
}
(namespace, event) = yield;
let text: String = match event {
Event::Text(bytes) => {
// text inside <value></value>
String::from_utf8(bytes.to_vec())?
}
_ => return Err(anyhow!("Unexpected XML event: {event:?}")),
};
(namespace, event) = yield;
match event {
Event::End(bytes) if bytes.name().0 == b"value" => {
// end of </value>
}
_ => return Err(anyhow!("Unexpected XML event: {event:?}")),
}
values.push(text);
}
fields.push(Field { values })
}
Event::End(bytes) if bytes.name().0 == X::NAME.as_bytes() => {
// end of <x/>
return Ok(X { fields });
}
_ => return Err(anyhow!("Unexpected XML event: {event:?}")),
}
}
}
}
}
impl FromXml for MessageArchiveRequest {
type P = impl Parser<Output = Result<Self>>;
fn parse() -> Self::P {
|(mut namespace, mut event): (ResolveResult<'static>, &'static Event<'static>)| -> Result<Self> {
println!("MessageArchiveRequest::parse {:?}", event);
let bytes = match event {
Event::Empty(_) => return Ok(MessageArchiveRequest { x: None }),
Event::Start(bytes) => bytes,
_ => return Err(anyhow!("Unexpected XML event: {event:?}")),
};
if bytes.name().0 != MessageArchiveRequest::NAME.as_bytes() {
return Err(anyhow!("Unexpected XML tag: {:?}", bytes.name()));
}
let ResolveResult::Bound(Namespace(ns)) = namespace else {
return Err(anyhow!("No namespace provided"));
};
if ns != MAM_XMLNS.as_bytes() {
return Err(anyhow!("Incorrect namespace"));
}
(namespace, event) = yield;
match event {
Event::End(bytes) if bytes.name().0 == MessageArchiveRequest::NAME.as_bytes() => {
Ok(MessageArchiveRequest { x: None })
}
Event::Start(bytes) | Event::Empty(bytes) if bytes.name().0 == X::NAME.as_bytes() => {
let x = delegate_parsing!(X, namespace, event)?;
Ok(MessageArchiveRequest { x: Some(x) })
}
_ => Err(anyhow!("Unexpected XML event: {event:?}")),
}
}
}
}
impl MessageArchiveRequest {}
#[cfg(test)]
mod tests {
use super::*;
use crate::bind::{Jid, Name, Server};
use crate::client::{Iq, IqType};
#[test]
fn test_parse_archive_query() {
let input = r#"<iq to='pubsub.shakespeare.lit' type='set' id='juliet1'><query xmlns='urn:xmpp:mam:2' queryid='f28'/></iq>"#;
let result: Iq<MessageArchiveRequest> = parse(input).unwrap();
assert_eq!(
result,
Iq {
from: None,
id: "juliet1".to_string(),
to: Option::from(Jid {
name: None,
server: Server("pubsub.shakespeare.lit".into()),
resource: None,
}),
r#type: IqType::Set,
body: MessageArchiveRequest { x: None },
}
);
}
#[test]
fn test_parse_query_messages_from_jid() {
let input = r#"<iq type='set' id='juliet1'><query xmlns='urn:xmpp:mam:2'><x xmlns='jabber:x:data' type='submit'><field var='FORM_TYPE' type='hidden'><value>value1</value></field><field var='with'><value>juliet@capulet.lit</value></field></x></query></iq>"#;
let result: Iq<MessageArchiveRequest> = parse(input).unwrap();
assert_eq!(
result,
Iq {
from: None,
id: "juliet1".to_string(),
to: None,
r#type: IqType::Set,
body: MessageArchiveRequest {
x: Some(X {
fields: vec![
Field {
values: vec!["value1".to_string()],
},
Field {
values: vec!["juliet@capulet.lit".to_string()],
},
]
})
},
}
);
}
#[test]
fn test_parse_query_messages_from_jid_with_unclosed_tag() {
let input = r#"<iq type='set' id='juliet1'><query xmlns='urn:xmpp:mam:2'><x xmlns='jabber:x:data' type='submit'><field var='FORM_TYPE' type='hidden'><value>value1</value></field><field var='with'><value>juliet@capulet.lit</value></field></query></iq>"#;
assert!(parse::<Iq<MessageArchiveRequest>>(input).is_err())
}
}

View File

@ -1,12 +1,13 @@
#![allow(unused_variables)]
use quick_xml::events::Event;
use quick_xml::events::{BytesEnd, BytesStart, Event};
use quick_xml::name::ResolveResult;
use crate::xml::*;
use anyhow::{anyhow, Result};
pub const XMLNS: &'static str = "http://jabber.org/protocol/muc";
pub const XMLNS_USER: &'static str = "http://jabber.org/protocol/muc#user";
#[derive(PartialEq, Eq, Debug, Default)]
pub struct History {
@ -19,7 +20,7 @@ impl FromXml for History {
type P = impl Parser<Output = Result<Self>>;
fn parse() -> Self::P {
|(namespace, event): (ResolveResult<'static>, &'static Event<'static>)| -> Result<Self> {
|(mut namespace, mut event): (ResolveResult<'static>, &'static Event<'static>)| -> Result<Self> {
let mut history = History::default();
let (bytes, end) = match event {
Event::Start(bytes) if bytes.name().0 == Self::NAME.as_bytes() => (bytes, false),
@ -51,7 +52,7 @@ impl FromXml for History {
return Ok(history);
}
let (namespace, event) = yield;
(namespace, event) = yield;
let Event::End(bytes) = event else {
return Err(anyhow!("Unexpected XML event: {event:?}"));
};
@ -73,17 +74,17 @@ impl FromXml for Password {
type P = impl Parser<Output = Result<Self>>;
fn parse() -> Self::P {
|(namespace, event): (ResolveResult<'static>, &'static Event<'static>)| -> Result<Self> {
|(mut namespace, mut event): (ResolveResult<'static>, &'static Event<'static>)| -> Result<Self> {
let bytes = match event {
Event::Start(bytes) if bytes.name().0 == Self::NAME.as_bytes() => bytes,
_ => return Err(anyhow!("Unexpected XML event: {event:?}")),
};
let (namespace, event) = yield;
(namespace, event) = yield;
let Event::Text(bytes) = event else {
return Err(anyhow!("Unexpected XML event: {event:?}"));
};
let s = std::str::from_utf8(bytes)?.to_string();
let (namespace, event) = yield;
(namespace, event) = yield;
let Event::End(bytes) = event else {
return Err(anyhow!("Unexpected XML event: {event:?}"));
};
@ -108,7 +109,7 @@ impl FromXml for X {
type P = impl Parser<Output = Result<Self>>;
fn parse() -> Self::P {
|(namespace, event): (ResolveResult<'static>, &'static Event<'static>)| -> Result<Self> {
|(mut namespace, mut event): (ResolveResult<'static>, &'static Event<'static>)| -> Result<Self> {
let mut res = X::default();
let (_, end) = match event {
Event::Start(bytes) => (bytes, false),
@ -120,7 +121,7 @@ impl FromXml for X {
}
loop {
let (namespace, event) = yield;
(namespace, event) = yield;
let bytes = match event {
Event::Start(bytes) => bytes,
Event::Empty(bytes) => bytes,
@ -143,6 +144,32 @@ impl FromXml for X {
}
}
#[derive(Debug, PartialEq, Eq)]
pub struct XUser;
impl ToXml for XUser {
fn serialize(&self, output: &mut Vec<Event<'static>>) {
let mut tag = BytesStart::new("x");
tag.push_attribute(("xmlns", XMLNS_USER));
output.push(Event::Start(tag));
let mut meg = BytesStart::new("item");
meg.push_attribute(("affiliation", "owner"));
meg.push_attribute(("role", "moderator"));
meg.push_attribute(("jid", "sauer@localhost"));
output.push(Event::Empty(meg));
let mut veg = BytesStart::new("status");
veg.push_attribute(("code", "100"));
output.push(Event::Empty(veg));
let mut veg = BytesStart::new("status");
veg.push_attribute(("code", "110"));
output.push(Event::Empty(veg));
output.push(Event::End(BytesEnd::new("x")));
}
}
#[cfg(test)]
mod test {
use super::*;

View File

@ -2,7 +2,7 @@ use quick_xml::events::{BytesStart, Event};
use crate::xml::*;
use anyhow::{anyhow, Result};
use quick_xml::name::ResolveResult;
use quick_xml::name::{Namespace, ResolveResult};
pub const XMLNS: &'static str = "jabber:iq:roster";
@ -14,6 +14,9 @@ impl FromXml for RosterQuery {
fn parse() -> Self::P {
|(mut namespace, mut event): (ResolveResult<'static>, &'static Event<'static>)| -> Result<Self> {
let ResolveResult::Bound(Namespace(ns)) = namespace else {
return Err(anyhow!("No namespace provided"));
};
match event {
Event::Start(_) => (),
Event::Empty(_) => return Ok(RosterQuery),
@ -38,3 +41,39 @@ impl ToXml for RosterQuery {
events.push(Event::Empty(BytesStart::new(format!(r#"query xmlns="{}""#, XMLNS))));
}
}
#[cfg(test)]
mod tests {
use super::*;
use crate::bind::{Jid, Name, Resource, Server};
use crate::client::{Iq, IqType};
#[test]
fn test_parse() {
let input =
r#"<iq from='juliet@example.com/balcony' id='bv1bs71f' type='get'><query xmlns='jabber:iq:roster'/></iq>"#;
let result: Iq<RosterQuery> = parse(input).unwrap();
assert_eq!(
result,
Iq {
from: Option::from(Jid {
name: Option::from(Name("juliet".into())),
server: Server("example.com".into()),
resource: Option::from(Resource("balcony".into())),
}),
id: "bv1bs71f".to_string(),
to: None,
r#type: IqType::Get,
body: RosterQuery,
}
)
}
#[test]
fn test_missing_namespace() {
let input = r#"<iq from='juliet@example.com/balcony' id='bv1bs71f' type='get'><query/></iq>"#;
assert!(parse::<Iq<RosterQuery>>(input).is_err());
}
}

View File

@ -1,7 +1,7 @@
use std::borrow::Borrow;
use anyhow::{anyhow, Result};
use quick_xml::events::{BytesStart, Event};
use quick_xml::events::{BytesEnd, BytesStart, Event};
use quick_xml::{NsReader, Writer};
use tokio::io::{AsyncBufRead, AsyncWrite};
@ -74,3 +74,16 @@ impl Success {
Ok(())
}
}
pub struct Failure;
impl Failure {
pub async fn write_xml(&self, writer: &mut Writer<impl AsyncWrite + Unpin>) -> Result<()> {
let event = BytesStart::new(r#"failure xmlns="urn:ietf:params:xml:ns:xmpp-sasl""#);
writer.write_event_async(Event::Start(event)).await?;
let event = BytesStart::new(r#"not-authorized"#);
writer.write_event_async(Event::Empty(event)).await?;
let event = BytesEnd::new(r#"failure"#);
writer.write_event_async(Event::End(event)).await?;
Ok(())
}
}

View File

@ -170,14 +170,14 @@ mod test {
#[tokio::test]
async fn client_stream_start_correct_parse() {
let input = r###"<stream:stream xmlns:stream="http://etherx.jabber.org/streams" to="vlnv.dev" version="1.0" xmlns="jabber:client" xml:lang="en" xmlns:xml="http://www.w3.org/XML/1998/namespace">"###;
let input = r###"<stream:stream xmlns:stream="http://etherx.jabber.org/streams" to="xmpp.ru" version="1.0" xmlns="jabber:client" xml:lang="en" xmlns:xml="http://www.w3.org/XML/1998/namespace">"###;
let mut reader = NsReader::from_reader(input.as_bytes());
let mut buf = vec![];
let res = ClientStreamStart::parse(&mut reader, &mut buf).await.unwrap();
assert_eq!(
res,
ClientStreamStart {
to: "vlnv.dev".to_owned(),
to: "xmpp.ru".to_owned(),
lang: Some("en".to_owned()),
version: "1.0".to_owned()
}
@ -187,12 +187,12 @@ mod test {
#[tokio::test]
async fn server_stream_start_write() {
let input = ServerStreamStart {
from: "vlnv.dev".to_owned(),
from: "xmpp.ru".to_owned(),
lang: "en".to_owned(),
id: "stream_id".to_owned(),
version: "1.0".to_owned(),
};
let expected = r###"<stream:stream from="vlnv.dev" version="1.0" xmlns="jabber:client" xmlns:stream="http://etherx.jabber.org/streams" xml:lang="en" id="stream_id">"###;
let expected = r###"<stream:stream from="xmpp.ru" version="1.0" xmlns="jabber:client" xmlns:stream="http://etherx.jabber.org/streams" xml:lang="en" id="stream_id">"###;
let mut output: Vec<u8> = vec![];
let mut writer = Writer::new(&mut output);
input.write_xml(&mut writer).await.unwrap();

View File

@ -0,0 +1,41 @@
use crate::xml::ToXml;
use quick_xml::events::{BytesEnd, BytesStart, Event};
/// Stream error condition
///
/// [Spec](https://xmpp.org/rfcs/rfc6120.html#streams-error-conditions).
pub enum StreamErrorKind {
/// The server has experienced a misconfiguration or other internal error that prevents it from servicing the stream.
InternalServerError,
/// The server is being shut down and all active streams are being closed.
SystemShutdown,
}
impl StreamErrorKind {
pub fn from_str(s: &str) -> Option<Self> {
match s {
"internal-server-error" => Some(Self::InternalServerError),
"system-shutdown" => Some(Self::SystemShutdown),
_ => None,
}
}
pub fn as_str(&self) -> &'static str {
match self {
Self::InternalServerError => "internal-server-error",
Self::SystemShutdown => "system-shutdown",
}
}
}
pub struct StreamError {
pub kind: StreamErrorKind,
}
impl ToXml for StreamError {
fn serialize(&self, events: &mut Vec<Event<'static>>) {
events.push(Event::Start(BytesStart::new("stream:error")));
events.push(Event::Empty(BytesStart::new(format!(
r#"{} xmlns="urn:ietf:params:xml:ns:xmpp-streams""#,
self.kind.as_str()
))));
events.push(Event::End(BytesEnd::new("stream:error")));
}
}

View File

@ -10,9 +10,38 @@ use anyhow::Result;
mod ignore;
pub use ignore::Ignore;
/// Types which can be parsed from an XML input stream.
///
/// Example:
/// ```
/// #![feature(type_alias_impl_trait)]
/// #![feature(impl_trait_in_assoc_type)]
/// #![feature(coroutines)]
/// # use proto_xmpp::xml::FromXml;
/// # use quick_xml::events::Event;
/// # use quick_xml::name::ResolveResult;
/// # use proto_xmpp::xml::Parser;
/// # use anyhow::Result;
///
/// struct MyStruct;
/// impl FromXml for MyStruct {
/// type P = impl Parser<Output = Result<Self>>;
///
/// fn parse() -> Self::P {
/// |(mut namespace, mut event): (ResolveResult<'static>, &'static Event<'static>)| -> Result<Self> {
/// (namespace, event) = yield;
/// Ok(MyStruct)
/// }
/// }
/// }
/// ```
pub trait FromXml: Sized {
/// The type of parser instances.
///
/// If the result type of the [parse] is anonymous, this type member can be defined by using `impl Trait`.
type P: Parser<Output = Result<Self>>;
/// Creates a new instance of a parser with an initial state.
fn parse() -> Self::P;
}
@ -25,9 +54,18 @@ pub trait FromXmlTag: FromXml {
const NS: &'static str;
}
/// A stateful parser instance which consumes XML events until the parsing is complete.
///
/// Usually implemented with the experimental coroutine syntax, which yields to consume the next XML event,
/// and returns the final result when the parsing is done.
pub trait Parser: Sized {
type Output;
/// Advance the parsing by one XML event.
///
/// This method consumes `self`, but if the parsing is incomplete,
/// it will return the next state of the parser in the returned result.
/// Otherwise, it will return the final result of parsing.
fn consume<'a>(self: Self, namespace: ResolveResult, event: &Event<'a>) -> Continuation<Self, Self::Output>;
}
@ -50,8 +88,11 @@ where
}
}
/// The result of a single parser iteration.
pub enum Continuation<Parser, Res> {
/// The parsing is complete and the final result is available.
Final(Res),
/// The parsing is not complete and more XML events are required.
Continue(Parser),
}
@ -89,8 +130,8 @@ macro_rules! delegate_parsing {
Continuation::Final(Ok(res)) => break Ok(res.into()),
Continuation::Final(Err(err)) => break Err(err),
Continuation::Continue(p) => {
let (namespace, event) = yield;
parser = p.consume(namespace, event);
($namespace, $event) = yield;
parser = p.consume($namespace, $event);
}
}
}

View File

@ -79,10 +79,6 @@ mod test {
fn test_fail_if_size_less_then_3() {
let orig = b"login\x00pass";
let encoded = general_purpose::STANDARD.encode(orig);
let expected = AuthBody {
login: "login".to_string(),
password: "pass".to_string(),
};
let result = AuthBody::from_str(encoded.as_bytes());
assert!(result.is_err());
@ -92,10 +88,6 @@ mod test {
fn test_fail_if_size_greater_then_3() {
let orig = b"first\x00login\x00pass\x00other";
let encoded = general_purpose::STANDARD.encode(orig);
let expected = AuthBody {
login: "login".to_string(),
password: "pass".to_string(),
};
let result = AuthBody::from_str(encoded.as_bytes());
assert!(result.is_err());

View File

@ -23,6 +23,11 @@ hostname = "localhost"
[storage]
db_path = "db.sqlite"
[tracing]
# otlp grpc endpoint
endpoint = "http://jaeger:4317"
service_name = "lavina"
```
## With Docker Compose
@ -41,6 +46,15 @@ services:
- '5222:5222' # xmpp
- '6667:6667' # irc non-tls
- '127.0.0.1:1380:8080' # management http (private)
# if you want to observe traces
jaeger:
image: "jaegertracing/all-in-one:1.56"
ports:
- "16686:16686" # web ui
- "4317:4317" # grpc ingest endpoint
environment:
- COLLECTOR_OTLP_ENABLED=true
- SPAN_STORAGE_TYPE=memory
```
## With Cargo

View File

@ -12,13 +12,16 @@ use prometheus::{Encoder, Registry as MetricsRegistry, TextEncoder};
use serde::{Deserialize, Serialize};
use tokio::net::TcpListener;
use lavina_core::auth::UpdatePasswordResult;
use lavina_core::player::{PlayerId, SendMessageResult};
use lavina_core::prelude::*;
use lavina_core::repo::Storage;
use lavina_core::room::RoomRegistry;
use lavina_core::room::RoomId;
use lavina_core::terminator::Terminator;
use lavina_core::LavinaCore;
use mgmt_api::*;
mod clustering;
type HttpResult<T> = std::result::Result<T, Infallible>;
#[derive(Deserialize, Debug)]
@ -26,24 +29,18 @@ pub struct ServerConfig {
pub listen_on: SocketAddr,
}
pub async fn launch(
config: ServerConfig,
metrics: MetricsRegistry,
rooms: RoomRegistry,
storage: Storage,
) -> Result<Terminator> {
pub async fn launch(config: ServerConfig, metrics: MetricsRegistry, core: LavinaCore) -> Result<Terminator> {
log::info!("Starting the http service");
let listener = TcpListener::bind(config.listen_on).await?;
log::debug!("Listener started");
let terminator = Terminator::spawn(|rx| main_loop(listener, metrics, rooms, storage, rx.map(|_| ())));
let terminator = Terminator::spawn(|rx| main_loop(listener, metrics, core, rx.map(|_| ())));
Ok(terminator)
}
async fn main_loop(
listener: TcpListener,
metrics: MetricsRegistry,
rooms: RoomRegistry,
storage: Storage,
core: LavinaCore,
termination: impl Future<Output = ()>,
) -> Result<()> {
pin!(termination);
@ -55,13 +52,10 @@ async fn main_loop(
let (stream, _) = result?;
let stream = TokioIo::new(stream);
let metrics = metrics.clone();
let rooms = rooms.clone();
let storage = storage.clone();
let core = core.clone();
tokio::task::spawn(async move {
let registry = metrics.clone();
let rooms = rooms.clone();
let storage = storage.clone();
let server = http1::Builder::new().serve_connection(stream, service_fn(move |r| route(registry.clone(), rooms.clone(), storage.clone(), r)));
let svc_fn = service_fn(|r| route(&metrics, &core, r));
let server = http1::Builder::new().serve_connection(stream, svc_fn);
if let Err(err) = server.await {
tracing::error!("Error serving connection: {:?}", err);
}
@ -73,90 +67,138 @@ async fn main_loop(
Ok(())
}
#[tracing::instrument(skip_all)]
async fn route(
registry: MetricsRegistry,
rooms: RoomRegistry,
storage: Storage,
registry: &MetricsRegistry,
core: &LavinaCore,
request: Request<hyper::body::Incoming>,
) -> HttpResult<Response<Full<Bytes>>> {
propagade_span_from_headers(&request);
let res = match (request.method(), request.uri().path()) {
(&Method::GET, "/metrics") => endpoint_metrics(registry),
(&Method::GET, "/rooms") => endpoint_rooms(rooms).await,
(&Method::POST, paths::CREATE_PLAYER) => endpoint_create_player(request, storage).await.or5xx(),
(&Method::POST, paths::SET_PASSWORD) => endpoint_set_password(request, storage).await.or5xx(),
_ => not_found(),
(&Method::GET, "/rooms") => endpoint_rooms(core).await,
(&Method::POST, paths::CREATE_PLAYER) => endpoint_create_player(request, core).await.or5xx(),
(&Method::POST, paths::STOP_PLAYER) => endpoint_stop_player(request, core).await.or5xx(),
(&Method::POST, paths::SET_PASSWORD) => endpoint_set_password(request, core).await.or5xx(),
(&Method::POST, rooms::paths::SEND_MESSAGE) => endpoint_send_room_message(request, core).await.or5xx(),
(&Method::POST, rooms::paths::SET_TOPIC) => endpoint_set_room_topic(request, core).await.or5xx(),
_ => clustering::route(core, request).await.unwrap_or_else(endpoint_not_found),
};
Ok(res)
}
fn endpoint_metrics(registry: MetricsRegistry) -> Response<Full<Bytes>> {
fn endpoint_metrics(registry: &MetricsRegistry) -> Response<Full<Bytes>> {
let mf = registry.gather();
let mut buffer = vec![];
TextEncoder.encode(&mf, &mut buffer).expect("write to vec cannot fail");
Response::new(Full::new(Bytes::from(buffer)))
}
async fn endpoint_rooms(rooms: RoomRegistry) -> Response<Full<Bytes>> {
#[tracing::instrument(skip_all)]
async fn endpoint_rooms(core: &LavinaCore) -> Response<Full<Bytes>> {
// TODO introduce management API types independent from core-domain types
// TODO remove `Serialize` implementations from all core-domain types
let room_list = rooms.get_all_rooms().await.to_body();
let room_list = core.get_all_rooms().await.to_body();
Response::new(room_list)
}
#[tracing::instrument(skip_all)]
async fn endpoint_create_player(
request: Request<hyper::body::Incoming>,
mut storage: Storage,
core: &LavinaCore,
) -> Result<Response<Full<Bytes>>> {
let str = request.collect().await?.to_bytes();
let Ok(res) = serde_json::from_slice::<CreatePlayerRequest>(&str[..]) else {
let payload = ErrorResponse {
code: errors::MALFORMED_REQUEST,
message: "The request payload contains incorrect JSON value",
}
.to_body();
let mut response = Response::new(payload);
*response.status_mut() = StatusCode::BAD_REQUEST;
return Ok(response);
return Ok(malformed_request());
};
storage.create_user(&res.name).await?;
core.create_player(&PlayerId::from(res.name)?).await?;
log::info!("Player {} created", res.name);
let mut response = Response::new(Full::<Bytes>::default());
*response.status_mut() = StatusCode::CREATED;
Ok(response)
}
#[tracing::instrument(skip_all)]
async fn endpoint_stop_player(
request: Request<hyper::body::Incoming>,
core: &LavinaCore,
) -> Result<Response<Full<Bytes>>> {
let str = request.collect().await?.to_bytes();
let Ok(res) = serde_json::from_slice::<StopPlayerRequest>(&str[..]) else {
return Ok(malformed_request());
};
let Ok(player_id) = PlayerId::from(res.name) else {
return Ok(player_not_found());
};
let Some(()) = core.stop_player(&player_id).await? else {
return Ok(player_not_found());
};
Ok(empty_204_request())
}
#[tracing::instrument(skip_all)]
async fn endpoint_set_password(
request: Request<hyper::body::Incoming>,
mut storage: Storage,
core: &LavinaCore,
) -> Result<Response<Full<Bytes>>> {
let str = request.collect().await?.to_bytes();
let Ok(res) = serde_json::from_slice::<ChangePasswordRequest>(&str[..]) else {
let payload = ErrorResponse {
code: errors::MALFORMED_REQUEST,
message: "The request payload contains incorrect JSON value",
}
.to_body();
let mut response = Response::new(payload);
*response.status_mut() = StatusCode::BAD_REQUEST;
return Ok(response);
return Ok(malformed_request());
};
let Some(_) = storage.set_password(&res.player_name, &res.password).await? else {
let payload = ErrorResponse {
code: errors::PLAYER_NOT_FOUND,
message: "No such player exists",
let verdict = core.set_password(&res.player_name, &res.password).await?;
match verdict {
UpdatePasswordResult::PasswordUpdated => {}
UpdatePasswordResult::UserNotFound => {
return Ok(player_not_found());
}
.to_body();
let mut response = Response::new(payload);
*response.status_mut() = StatusCode::UNPROCESSABLE_ENTITY;
return Ok(response);
};
log::info!("Password changed for player {}", res.player_name);
let mut response = Response::new(Full::<Bytes>::default());
*response.status_mut() = StatusCode::NO_CONTENT;
Ok(response)
}
Ok(empty_204_request())
}
pub fn not_found() -> Response<Full<Bytes>> {
#[tracing::instrument(skip_all)]
async fn endpoint_send_room_message(
request: Request<hyper::body::Incoming>,
core: &LavinaCore,
) -> Result<Response<Full<Bytes>>> {
let str = request.collect().await?.to_bytes();
let Ok(req) = serde_json::from_slice::<rooms::SendMessageReq>(&str[..]) else {
return Ok(malformed_request());
};
let Ok(room_id) = RoomId::from(req.room_id) else {
return Ok(room_not_found());
};
let Ok(player_id) = PlayerId::from(req.author_id) else {
return Ok(player_not_found());
};
let mut player = core.connect_to_player(&player_id).await;
let res = player.send_message(room_id, req.message.into()).await?;
match res {
SendMessageResult::NoSuchRoom => Ok(room_not_found()),
SendMessageResult::Success(_) => Ok(empty_204_request()),
}
}
#[tracing::instrument(skip_all)]
async fn endpoint_set_room_topic(
request: Request<hyper::body::Incoming>,
core: &LavinaCore,
) -> Result<Response<Full<Bytes>>> {
let str = request.collect().await?.to_bytes();
let Ok(req) = serde_json::from_slice::<rooms::SetTopicReq>(&str[..]) else {
return Ok(malformed_request());
};
let Ok(room_id) = RoomId::from(req.room_id) else {
return Ok(room_not_found());
};
let Ok(player_id) = PlayerId::from(req.author_id) else {
return Ok(player_not_found());
};
let mut player = core.connect_to_player(&player_id).await;
player.change_topic(room_id, req.topic.into()).await?;
Ok(empty_204_request())
}
fn endpoint_not_found() -> Response<Full<Bytes>> {
let payload = ErrorResponse {
code: errors::INVALID_PATH,
message: "The path does not exist",
@ -168,25 +210,64 @@ pub fn not_found() -> Response<Full<Bytes>> {
response
}
fn player_not_found() -> Response<Full<Bytes>> {
let payload = ErrorResponse {
code: errors::PLAYER_NOT_FOUND,
message: "No such player exists",
}
.to_body();
let mut response = Response::new(payload);
*response.status_mut() = StatusCode::UNPROCESSABLE_ENTITY;
response
}
fn room_not_found() -> Response<Full<Bytes>> {
let payload = ErrorResponse {
code: rooms::errors::ROOM_NOT_FOUND,
message: "No such room exists",
}
.to_body();
let mut response = Response::new(payload);
*response.status_mut() = StatusCode::UNPROCESSABLE_ENTITY;
response
}
fn malformed_request() -> Response<Full<Bytes>> {
let payload = ErrorResponse {
code: errors::MALFORMED_REQUEST,
message: "The request payload contains incorrect JSON value",
}
.to_body();
let mut response = Response::new(payload);
*response.status_mut() = StatusCode::BAD_REQUEST;
return response;
}
fn empty_204_request() -> Response<Full<Bytes>> {
let mut response = Response::new(Full::<Bytes>::default());
*response.status_mut() = StatusCode::NO_CONTENT;
response
}
trait Or5xx {
fn or5xx(self) -> Response<Full<Bytes>>;
}
impl Or5xx for Result<Response<Full<Bytes>>> {
fn or5xx(self) -> Response<Full<Bytes>> {
match self {
Ok(e) => e,
Err(e) => {
self.unwrap_or_else(|e| {
let mut response = Response::new(Full::new(e.to_string().into()));
*response.status_mut() = StatusCode::INTERNAL_SERVER_ERROR;
response
}
}
})
}
}
trait ToBody {
fn to_body(&self) -> Full<Bytes>;
}
impl<T> ToBody for T
where
T: Serialize,
@ -197,3 +278,24 @@ where
Full::new(Bytes::from(buffer))
}
}
fn propagade_span_from_headers<T>(req: &Request<T>) {
use opentelemetry::propagation::Extractor;
use tracing::Span;
use tracing_opentelemetry::OpenTelemetrySpanExt;
struct HttpReqExtractor<'a, T> {
req: &'a Request<T>,
}
impl<'a, T> Extractor for HttpReqExtractor<'a, T> {
fn get(&self, key: &str) -> Option<&str> {
self.req.headers().get(key).and_then(|v| v.to_str().ok())
}
fn keys(&self) -> Vec<&str> {
self.req.headers().keys().map(|k| k.as_str()).collect()
}
}
let ctx = opentelemetry::global::get_text_map_propagator(|pp| pp.extract(&HttpReqExtractor { req }));
Span::current().set_parent(ctx);
}

72
src/http/clustering.rs Normal file
View File

@ -0,0 +1,72 @@
use http_body_util::{BodyExt, Full};
use hyper::body::Bytes;
use hyper::{Method, Request, Response};
use super::Or5xx;
use crate::http::{empty_204_request, malformed_request, player_not_found, room_not_found};
use lavina_core::clustering::room::{paths, JoinRoomReq, SendMessageReq};
use lavina_core::player::PlayerId;
use lavina_core::room::RoomId;
use lavina_core::LavinaCore;
// TODO move this into core
pub async fn route(core: &LavinaCore, request: Request<hyper::body::Incoming>) -> Option<Response<Full<Bytes>>> {
match (request.method(), request.uri().path()) {
(&Method::POST, paths::JOIN) => Some(endpoint_cluster_join_room(request, core).await.or5xx()),
(&Method::POST, paths::ADD_MESSAGE) => Some(endpoint_cluster_add_message(request, core).await.or5xx()),
_ => None,
}
}
#[tracing::instrument(skip_all, name = "endpoint_cluster_join_room")]
async fn endpoint_cluster_join_room(
request: Request<hyper::body::Incoming>,
core: &LavinaCore,
) -> lavina_core::prelude::Result<Response<Full<Bytes>>> {
let str = request.collect().await?.to_bytes();
let Ok(req) = serde_json::from_slice::<JoinRoomReq>(&str[..]) else {
return Ok(malformed_request());
};
tracing::info!("Incoming request: {:?}", &req);
let Ok(room_id) = RoomId::from(req.room_id) else {
dbg!(&req.room_id);
return Ok(room_not_found());
};
let Ok(player_id) = PlayerId::from(req.player_id) else {
dbg!(&req.player_id);
return Ok(player_not_found());
};
core.cluster_join_room(room_id, &player_id).await?;
Ok(empty_204_request())
}
#[tracing::instrument(skip_all, name = "endpoint_cluster_add_message")]
async fn endpoint_cluster_add_message(
request: Request<hyper::body::Incoming>,
core: &LavinaCore,
) -> lavina_core::prelude::Result<Response<Full<Bytes>>> {
let str = request.collect().await?.to_bytes();
let Ok(req) = serde_json::from_slice::<SendMessageReq>(&str[..]) else {
return Ok(malformed_request());
};
tracing::info!("Incoming request: {:?}", &req);
let Ok(created_at) = chrono::DateTime::parse_from_rfc3339(req.created_at) else {
dbg!(&req.created_at);
return Ok(malformed_request());
};
let Ok(room_id) = RoomId::from(req.room_id) else {
dbg!(&req.room_id);
return Ok(room_not_found());
};
let Ok(player_id) = PlayerId::from(req.player_id) else {
dbg!(&req.player_id);
return Ok(player_not_found());
};
let res = core.cluster_send_room_message(room_id, &player_id, req.message.into(), created_at.to_utc()).await?;
if let Some(_) = res {
Ok(empty_204_request())
} else {
Ok(room_not_found())
}
}

View File

@ -6,13 +6,23 @@ use std::path::Path;
use clap::Parser;
use figment::providers::Format;
use figment::{providers::Toml, Figment};
use opentelemetry::global::set_text_map_propagator;
use opentelemetry::KeyValue;
use opentelemetry_otlp::WithExportConfig;
use opentelemetry_sdk::propagation::TraceContextPropagator;
use opentelemetry_sdk::trace::{BatchConfig, RandomIdGenerator, Sampler};
use opentelemetry_sdk::{runtime, Resource};
use opentelemetry_semantic_conventions::resource::SERVICE_NAME;
use opentelemetry_semantic_conventions::SCHEMA_URL;
use prometheus::Registry as MetricsRegistry;
use serde::Deserialize;
use tracing_opentelemetry::OpenTelemetryLayer;
use tracing_subscriber::fmt::Subscriber;
use tracing_subscriber::prelude::*;
use lavina_core::player::PlayerRegistry;
use lavina_core::prelude::*;
use lavina_core::repo::Storage;
use lavina_core::room::RoomRegistry;
use lavina_core::LavinaCore;
#[derive(Deserialize, Debug)]
struct ServerConfig {
@ -20,6 +30,14 @@ struct ServerConfig {
irc: projection_irc::ServerConfig,
xmpp: projection_xmpp::ServerConfig,
storage: lavina_core::repo::StorageConfig,
cluster: lavina_core::clustering::ClusterConfig,
tracing: Option<TracingConfig>,
}
#[derive(Deserialize, Debug)]
struct TracingConfig {
endpoint: String,
service_name: String,
}
#[derive(Parser)]
@ -37,9 +55,9 @@ fn load_config() -> Result<ServerConfig> {
#[tokio::main]
async fn main() -> Result<()> {
set_up_logging()?;
let sleep = ctrl_c()?;
let config = load_config()?;
set_up_logging(&config.tracing)?;
tracing::info!("Booting up");
tracing::info!("Loaded config: {config:?}");
@ -48,28 +66,15 @@ async fn main() -> Result<()> {
irc: irc_config,
xmpp: xmpp_config,
storage: storage_config,
cluster: cluster_config,
tracing: _,
} = config;
let mut metrics = MetricsRegistry::new();
let storage = Storage::open(storage_config).await?;
let rooms = RoomRegistry::new(&mut metrics, storage.clone())?;
let mut players = PlayerRegistry::empty(rooms.clone(), storage.clone(), &mut metrics)?;
let telemetry_terminator = http::launch(telemetry_config, metrics.clone(), rooms.clone(), storage.clone()).await?;
let irc = projection_irc::launch(
irc_config,
players.clone(),
rooms.clone(),
metrics.clone(),
storage.clone(),
)
.await?;
let xmpp = projection_xmpp::launch(
xmpp_config,
players.clone(),
rooms.clone(),
metrics.clone(),
storage.clone(),
)
.await?;
let core = LavinaCore::new(&mut metrics, cluster_config, storage).await?;
let telemetry_terminator = http::launch(telemetry_config, metrics.clone(), core.clone()).await?;
let irc = projection_irc::launch(irc_config, core.clone(), metrics.clone()).await?;
let xmpp = projection_xmpp::launch(xmpp_config, core.clone(), metrics.clone()).await?;
tracing::info!("Started");
sleep.await;
@ -78,10 +83,8 @@ async fn main() -> Result<()> {
xmpp.terminate().await?;
irc.terminate().await?;
telemetry_terminator.terminate().await?;
players.shutdown_all().await?;
drop(players);
drop(rooms);
storage.close().await?;
let storage = core.shutdown().await;
storage.close().await;
tracing::info!("Shutdown complete");
Ok(())
}
@ -106,7 +109,45 @@ fn ctrl_c() -> Result<impl Future<Output = ()>> {
Ok(recv(chan))
}
fn set_up_logging() -> Result<()> {
tracing_subscriber::fmt::init();
fn set_up_logging(tracing_config: &Option<TracingConfig>) -> Result<()> {
let subscriber = tracing_subscriber::registry().with(tracing_subscriber::fmt::layer());
let targets = {
use std::{env, str::FromStr};
use tracing_subscriber::filter::Targets;
match env::var("RUST_LOG") {
Ok(var) => Targets::from_str(&var)
.map_err(|e| {
eprintln!("Ignoring `RUST_LOG={:?}`: {}", var, e);
})
.unwrap_or_default(),
Err(env::VarError::NotPresent) => Targets::new().with_default(Subscriber::DEFAULT_MAX_LEVEL),
Err(e) => {
eprintln!("Ignoring `RUST_LOG`: {}", e);
Targets::new().with_default(Subscriber::DEFAULT_MAX_LEVEL)
}
}
};
if let Some(config) = tracing_config {
let trace_config = opentelemetry_sdk::trace::Config::default()
.with_sampler(Sampler::ParentBased(Box::new(Sampler::TraceIdRatioBased(1.0))))
.with_id_generator(RandomIdGenerator::default())
.with_resource(Resource::from_schema_url(
[KeyValue::new(SERVICE_NAME, config.service_name.to_string())],
SCHEMA_URL,
));
let trace_exporter = opentelemetry_otlp::new_exporter().tonic().with_endpoint(&config.endpoint);
let tracer = opentelemetry_otlp::new_pipeline()
.tracing()
.with_trace_config(trace_config)
.with_batch_config(BatchConfig::default())
.with_exporter(trace_exporter)
.install_batch(runtime::Tokio)?;
let subscriber = subscriber.with(OpenTelemetryLayer::new(tracer));
set_text_map_propagator(TraceContextPropagator::new());
targets.with_subscriber(subscriber).try_init()?;
} else {
targets.with_subscriber(subscriber).try_init()?;
}
Ok(())
}