forked from lavina/lavina
1
0
Fork 0

irc: implement server-time capability for incoming messages (#52)

Spec: https://ircv3.net/specs/extensions/server-time
Reviewed-on: lavina/lavina#52
This commit is contained in:
Nikita Vilunov 2024-04-21 21:00:44 +00:00
parent ddb348bee9
commit 12d30ca5c2
13 changed files with 183 additions and 27 deletions

1
Cargo.lock generated
View File

@ -1263,6 +1263,7 @@ version = "0.0.2-dev"
dependencies = [ dependencies = [
"anyhow", "anyhow",
"bitflags 2.5.0", "bitflags 2.5.0",
"chrono",
"futures-util", "futures-util",
"lavina-core", "lavina-core",
"nonempty", "nonempty",

View File

@ -31,6 +31,7 @@ base64 = "0.22.0"
lavina-core = { path = "crates/lavina-core" } lavina-core = { path = "crates/lavina-core" }
tracing-subscriber = "0.3.16" tracing-subscriber = "0.3.16"
sasl = { path = "crates/sasl" } sasl = { path = "crates/sasl" }
chrono = "0.4.37"
[package] [package]
name = "lavina" name = "lavina"

View File

@ -10,4 +10,4 @@ serde.workspace = true
tokio.workspace = true tokio.workspace = true
tracing.workspace = true tracing.workspace = true
prometheus.workspace = true prometheus.workspace = true
chrono = "0.4.37" chrono.workspace = true

View File

@ -10,6 +10,7 @@
use std::collections::{HashMap, HashSet}; use std::collections::{HashMap, HashSet};
use std::sync::Arc; use std::sync::Arc;
use chrono::{DateTime, Utc};
use prometheus::{IntGauge, Registry as MetricsRegistry}; use prometheus::{IntGauge, Registry as MetricsRegistry};
use serde::Serialize; use serde::Serialize;
use tokio::sync::mpsc::{channel, Receiver, Sender}; use tokio::sync::mpsc::{channel, Receiver, Sender};
@ -57,7 +58,7 @@ pub struct PlayerConnection {
} }
impl PlayerConnection { impl PlayerConnection {
/// Handled in [Player::send_message]. /// Handled in [Player::send_message].
pub async fn send_message(&mut self, room_id: RoomId, body: Str) -> Result<()> { pub async fn send_message(&mut self, room_id: RoomId, body: Str) -> Result<SendMessageResult> {
let (promise, deferred) = oneshot(); let (promise, deferred) = oneshot();
let cmd = ClientCommand::SendMessage { room_id, body, promise }; let cmd = ClientCommand::SendMessage { room_id, body, promise };
self.player_handle.send(ActorCommand::ClientCommand(cmd, self.connection_id.clone())).await; self.player_handle.send(ActorCommand::ClientCommand(cmd, self.connection_id.clone())).await;
@ -163,7 +164,7 @@ pub enum ClientCommand {
SendMessage { SendMessage {
room_id: RoomId, room_id: RoomId,
body: Str, body: Str,
promise: Promise<()>, promise: Promise<SendMessageResult>,
}, },
ChangeTopic { ChangeTopic {
room_id: RoomId, room_id: RoomId,
@ -181,6 +182,11 @@ pub enum JoinResult {
Banned, Banned,
} }
pub enum SendMessageResult {
Success(DateTime<Utc>),
NoSuchRoom,
}
/// Player update event type which is sent to a player actor and from there to a connection handler. /// Player update event type which is sent to a player actor and from there to a connection handler.
#[derive(Clone, Debug)] #[derive(Clone, Debug)]
pub enum Updates { pub enum Updates {
@ -192,6 +198,7 @@ pub enum Updates {
room_id: RoomId, room_id: RoomId,
author_id: PlayerId, author_id: PlayerId,
body: Str, body: Str,
created_at: DateTime<Utc>,
}, },
RoomJoined { RoomJoined {
room_id: RoomId, room_id: RoomId,
@ -367,8 +374,8 @@ impl Player {
let _ = promise.send(()); let _ = promise.send(());
} }
ClientCommand::SendMessage { room_id, body, promise } => { ClientCommand::SendMessage { room_id, body, promise } => {
self.send_message(connection_id, room_id, body).await; let result = self.send_message(connection_id, room_id, body).await;
let _ = promise.send(()); let _ = promise.send(result);
} }
ClientCommand::ChangeTopic { ClientCommand::ChangeTopic {
room_id, room_id,
@ -425,18 +432,21 @@ impl Player {
self.broadcast_update(update, connection_id).await; self.broadcast_update(update, connection_id).await;
} }
async fn send_message(&mut self, connection_id: ConnectionId, room_id: RoomId, body: Str) { async fn send_message(&mut self, connection_id: ConnectionId, room_id: RoomId, body: Str) -> SendMessageResult {
let Some(room) = self.my_rooms.get(&room_id) else { let Some(room) = self.my_rooms.get(&room_id) else {
tracing::info!("no room found"); tracing::info!("no room found");
return; return SendMessageResult::NoSuchRoom;
}; };
room.send_message(&self.player_id, body.clone()).await; let created_at = chrono::Utc::now();
room.send_message(&self.player_id, body.clone(), created_at.clone()).await;
let update = Updates::NewMessage { let update = Updates::NewMessage {
room_id, room_id,
author_id: self.player_id.clone(), author_id: self.player_id.clone(),
body, body,
created_at,
}; };
self.broadcast_update(update, connection_id).await; self.broadcast_update(update, connection_id).await;
SendMessageResult::Success(created_at)
} }
async fn change_topic(&mut self, connection_id: ConnectionId, room_id: RoomId, new_topic: Str) { async fn change_topic(&mut self, connection_id: ConnectionId, room_id: RoomId, new_topic: Str) {

View File

@ -4,6 +4,7 @@ use std::str::FromStr;
use std::sync::Arc; use std::sync::Arc;
use anyhow::anyhow; use anyhow::anyhow;
use chrono::{DateTime, Utc};
use serde::Deserialize; use serde::Deserialize;
use sqlx::sqlite::SqliteConnectOptions; use sqlx::sqlite::SqliteConnectOptions;
use sqlx::{ConnectOptions, Connection, FromRow, Sqlite, SqliteConnection, Transaction}; use sqlx::{ConnectOptions, Connection, FromRow, Sqlite, SqliteConnection, Transaction};
@ -80,7 +81,14 @@ impl Storage {
Ok(id) Ok(id)
} }
pub async fn insert_message(&mut self, room_id: u32, id: u32, content: &str, author_id: &str) -> Result<()> { pub async fn insert_message(
&mut self,
room_id: u32,
id: u32,
content: &str,
author_id: &str,
created_at: &DateTime<Utc>,
) -> Result<()> {
let mut executor = self.conn.lock().await; let mut executor = self.conn.lock().await;
let res: Option<(u32,)> = sqlx::query_as("select id from users where name = ?;") let res: Option<(u32,)> = sqlx::query_as("select id from users where name = ?;")
.bind(author_id) .bind(author_id)
@ -98,7 +106,7 @@ impl Storage {
.bind(id) .bind(id)
.bind(content) .bind(content)
.bind(author_id) .bind(author_id)
.bind(chrono::Utc::now().to_string()) .bind(created_at.to_string())
.bind(room_id) .bind(room_id)
.execute(&mut *executor) .execute(&mut *executor)
.await?; .await?;

View File

@ -2,6 +2,7 @@
use std::collections::HashSet; use std::collections::HashSet;
use std::{collections::HashMap, hash::Hash, sync::Arc}; use std::{collections::HashMap, hash::Hash, sync::Arc};
use chrono::{DateTime, Utc};
use prometheus::{IntGauge, Registry as MetricRegistry}; use prometheus::{IntGauge, Registry as MetricRegistry};
use serde::Serialize; use serde::Serialize;
use tokio::sync::RwLock as AsyncRwLock; use tokio::sync::RwLock as AsyncRwLock;
@ -163,9 +164,9 @@ impl RoomHandle {
lock.broadcast_update(update, player_id).await; lock.broadcast_update(update, player_id).await;
} }
pub async fn send_message(&self, player_id: &PlayerId, body: Str) { pub async fn send_message(&self, player_id: &PlayerId, body: Str, created_at: DateTime<Utc>) {
let mut lock = self.0.write().await; let mut lock = self.0.write().await;
let res = lock.send_message(player_id, body).await; let res = lock.send_message(player_id, body, created_at).await;
if let Err(err) = res { if let Err(err) = res {
log::warn!("Failed to send message: {err:?}"); log::warn!("Failed to send message: {err:?}");
} }
@ -208,14 +209,23 @@ struct Room {
storage: Storage, storage: Storage,
} }
impl Room { impl Room {
async fn send_message(&mut self, author_id: &PlayerId, body: Str) -> Result<()> { async fn send_message(&mut self, author_id: &PlayerId, body: Str, created_at: DateTime<Utc>) -> Result<()> {
tracing::info!("Adding a message to room"); tracing::info!("Adding a message to room");
self.storage.insert_message(self.storage_id, self.message_count, &body, &*author_id.as_inner()).await?; self.storage
.insert_message(
self.storage_id,
self.message_count,
&body,
&*author_id.as_inner(),
&created_at,
)
.await?;
self.message_count += 1; self.message_count += 1;
let update = Updates::NewMessage { let update = Updates::NewMessage {
room_id: self.room_id.clone(), room_id: self.room_id.clone(),
author_id: author_id.clone(), author_id: author_id.clone(),
body, body,
created_at,
}; };
self.broadcast_update(update, author_id).await; self.broadcast_update(update, author_id).await;
Ok(()) Ok(())

View File

@ -12,6 +12,7 @@ tokio.workspace = true
prometheus.workspace = true prometheus.workspace = true
futures-util.workspace = true futures-util.workspace = true
nonempty.workspace = true nonempty.workspace = true
chrono.workspace = true
bitflags = "2.4.1" bitflags = "2.4.1"
proto-irc = { path = "../proto-irc" } proto-irc = { path = "../proto-irc" }
sasl = { path = "../sasl" } sasl = { path = "../sasl" }

View File

@ -1,9 +1,10 @@
use bitflags::bitflags; use bitflags::bitflags;
bitflags! { bitflags! {
#[derive(Debug)] #[derive(Debug, Clone, Copy)]
pub struct Capabilities: u32 { pub struct Capabilities: u32 {
const None = 0; const None = 0;
const Sasl = 1 << 0; const Sasl = 1 << 0;
const ServerTime = 1 << 1;
} }
} }

View File

@ -2,6 +2,7 @@ use std::collections::HashMap;
use std::net::SocketAddr; use std::net::SocketAddr;
use anyhow::{anyhow, Result}; use anyhow::{anyhow, Result};
use chrono::SecondsFormat;
use futures_util::future::join_all; use futures_util::future::join_all;
use nonempty::nonempty; use nonempty::nonempty;
use nonempty::NonEmpty; use nonempty::NonEmpty;
@ -24,7 +25,7 @@ use proto_irc::client::{client_message, ClientMessage};
use proto_irc::server::CapSubBody; use proto_irc::server::CapSubBody;
use proto_irc::server::{AwayStatus, ServerMessage, ServerMessageBody}; use proto_irc::server::{AwayStatus, ServerMessage, ServerMessageBody};
use proto_irc::user::PrefixedNick; use proto_irc::user::PrefixedNick;
use proto_irc::{Chan, Recipient}; use proto_irc::{Chan, Recipient, Tag};
use sasl::AuthBody; use sasl::AuthBody;
mod cap; mod cap;
@ -49,6 +50,7 @@ struct RegisteredUser {
*/ */
username: Str, username: Str,
realname: Str, realname: Str,
enabled_capabilities: Capabilities,
} }
async fn handle_socket( async fn handle_socket(
@ -136,7 +138,7 @@ impl RegistrationState {
sender: Some(config.server_name.clone().into()), sender: Some(config.server_name.clone().into()),
body: ServerMessageBody::Cap { body: ServerMessageBody::Cap {
target: self.future_nickname.clone().unwrap_or_else(|| "*".into()), target: self.future_nickname.clone().unwrap_or_else(|| "*".into()),
subcmd: CapSubBody::Ls("sasl=PLAIN".into()), subcmd: CapSubBody::Ls("sasl=PLAIN server-time".into()),
}, },
} }
.write_async(writer) .write_async(writer)
@ -156,16 +158,30 @@ impl RegistrationState {
self.enabled_capabilities |= Capabilities::Sasl; self.enabled_capabilities |= Capabilities::Sasl;
} }
acked.push(cap); acked.push(cap);
} else if &*cap.name == "server-time" {
if cap.to_disable {
self.enabled_capabilities &= !Capabilities::ServerTime;
} else {
self.enabled_capabilities |= Capabilities::ServerTime;
}
acked.push(cap);
} else { } else {
naked.push(cap); naked.push(cap);
} }
} }
let mut ack_body = String::new(); let mut ack_body = String::new();
for cap in acked { if let Some((first, tail)) = acked.split_first() {
if cap.to_disable { if first.to_disable {
ack_body.push('-'); ack_body.push('-');
} }
ack_body += &*cap.name; ack_body += &*first.name;
for cap in tail {
ack_body.push(' ');
if cap.to_disable {
ack_body.push('-');
}
ack_body += &*cap.name;
}
} }
ServerMessage { ServerMessage {
tags: vec![], tags: vec![],
@ -195,6 +211,7 @@ impl RegistrationState {
nickname: nickname.clone(), nickname: nickname.clone(),
username, username,
realname, realname,
enabled_capabilities: self.enabled_capabilities,
}; };
self.finalize_auth(candidate_user, writer, storage, config).await self.finalize_auth(candidate_user, writer, storage, config).await
} }
@ -208,6 +225,7 @@ impl RegistrationState {
nickname: nickname.clone(), nickname: nickname.clone(),
username: username.clone(), username: username.clone(),
realname: realname.clone(), realname: realname.clone(),
enabled_capabilities: self.enabled_capabilities,
}; };
self.finalize_auth(candidate_user, writer, storage, config).await self.finalize_auth(candidate_user, writer, storage, config).await
} else { } else {
@ -224,6 +242,7 @@ impl RegistrationState {
nickname: nickname.clone(), nickname: nickname.clone(),
username, username,
realname, realname,
enabled_capabilities: self.enabled_capabilities,
}; };
self.finalize_auth(candidate_user, writer, storage, config).await self.finalize_auth(candidate_user, writer, storage, config).await
} else { } else {
@ -587,9 +606,18 @@ async fn handle_update(
author_id, author_id,
room_id, room_id,
body, body,
created_at,
} => { } => {
let mut tags = vec![];
if user.enabled_capabilities.contains(Capabilities::ServerTime) {
let tag = Tag {
key: "time".into(),
value: Some(created_at.to_rfc3339_opts(SecondsFormat::Millis, true).into()),
};
tags.push(tag);
}
ServerMessage { ServerMessage {
tags: vec![], tags,
sender: Some(author_id.as_inner().clone()), sender: Some(author_id.as_inner().clone()),
body: ServerMessageBody::PrivateMessage { body: ServerMessageBody::PrivateMessage {
target: Recipient::Chan(Chan::Global(room_id.as_inner().clone())), target: Recipient::Chan(Chan::Global(room_id.as_inner().clone())),

View File

@ -1,17 +1,20 @@
use std::io::ErrorKind; use std::io::ErrorKind;
use std::net::SocketAddr;
use std::time::Duration; use std::time::Duration;
use anyhow::{anyhow, Result}; use anyhow::{anyhow, Result};
use chrono::{DateTime, SecondsFormat};
use prometheus::Registry as MetricsRegistry; use prometheus::Registry as MetricsRegistry;
use tokio::io::{AsyncReadExt, AsyncWriteExt, BufReader}; use tokio::io::{AsyncReadExt, AsyncWriteExt, BufReader};
use tokio::net::tcp::{ReadHalf, WriteHalf}; use tokio::net::tcp::{ReadHalf, WriteHalf};
use tokio::net::TcpStream; use tokio::net::TcpStream;
use lavina_core::player::{JoinResult, PlayerId, SendMessageResult};
use lavina_core::repo::{Storage, StorageConfig}; use lavina_core::repo::{Storage, StorageConfig};
use lavina_core::room::RoomId;
use lavina_core::LavinaCore; use lavina_core::LavinaCore;
use projection_irc::APP_VERSION; use projection_irc::APP_VERSION;
use projection_irc::{launch, read_irc_message, RunningServer, ServerConfig}; use projection_irc::{launch, read_irc_message, RunningServer, ServerConfig};
struct TestScope<'a> { struct TestScope<'a> {
reader: BufReader<ReadHalf<'a>>, reader: BufReader<ReadHalf<'a>>,
writer: WriteHalf<'a>, writer: WriteHalf<'a>,
@ -89,6 +92,11 @@ impl<'a> TestScope<'a> {
Err(_) => Ok(()), Err(_) => Ok(()),
} }
} }
async fn expect_cap_ls(&mut self) -> Result<()> {
self.expect(":testserver CAP * LS :sasl=PLAIN server-time").await?;
Ok(())
}
} }
struct TestServer { struct TestServer {
@ -388,7 +396,7 @@ async fn scenario_cap_full_negotiation() -> Result<()> {
s.send("CAP LS 302").await?; s.send("CAP LS 302").await?;
s.send("NICK tester").await?; s.send("NICK tester").await?;
s.send("USER UserName 0 * :Real Name").await?; s.send("USER UserName 0 * :Real Name").await?;
s.expect(":testserver CAP * LS :sasl=PLAIN").await?; s.expect_cap_ls().await?;
s.send("CAP REQ :sasl").await?; s.send("CAP REQ :sasl").await?;
s.expect(":testserver CAP tester ACK :sasl").await?; s.expect(":testserver CAP tester ACK :sasl").await?;
s.send("AUTHENTICATE PLAIN").await?; s.send("AUTHENTICATE PLAIN").await?;
@ -426,7 +434,7 @@ async fn scenario_cap_full_negotiation_nick_last() -> Result<()> {
let mut s = TestScope::new(&mut stream); let mut s = TestScope::new(&mut stream);
s.send("CAP LS 302").await?; s.send("CAP LS 302").await?;
s.expect(":testserver CAP * LS :sasl=PLAIN").await?; s.expect_cap_ls().await?;
s.send("CAP REQ :sasl").await?; s.send("CAP REQ :sasl").await?;
s.expect(":testserver CAP * ACK :sasl").await?; s.expect(":testserver CAP * ACK :sasl").await?;
s.send("AUTHENTICATE PLAIN").await?; s.send("AUTHENTICATE PLAIN").await?;
@ -505,7 +513,7 @@ async fn scenario_cap_sasl_fail() -> Result<()> {
s.send("CAP LS 302").await?; s.send("CAP LS 302").await?;
s.send("NICK tester").await?; s.send("NICK tester").await?;
s.send("USER UserName 0 * :Real Name").await?; s.send("USER UserName 0 * :Real Name").await?;
s.expect(":testserver CAP * LS :sasl=PLAIN").await?; s.expect_cap_ls().await?;
s.send("CAP REQ :sasl").await?; s.send("CAP REQ :sasl").await?;
s.expect(":testserver CAP tester ACK :sasl").await?; s.expect(":testserver CAP tester ACK :sasl").await?;
s.send("AUTHENTICATE SHA256").await?; s.send("AUTHENTICATE SHA256").await?;
@ -558,3 +566,72 @@ async fn terminate_socket_scenario() -> Result<()> {
Ok(()) Ok(())
} }
#[tokio::test]
async fn server_time_capability() -> Result<()> {
let mut server = TestServer::start().await?;
// test scenario
server.storage.create_user("tester").await?;
server.storage.set_password("tester", "password").await?;
let mut stream = TcpStream::connect(server.server.addr).await?;
let mut s = TestScope::new(&mut stream);
s.send("CAP LS 302").await?;
s.send("NICK tester").await?;
s.send("USER UserName 0 * :Real Name").await?;
s.expect_cap_ls().await?;
s.send("CAP REQ :sasl server-time").await?;
s.expect(":testserver CAP tester ACK :sasl server-time").await?;
s.send("AUTHENTICATE PLAIN").await?;
s.expect(":testserver AUTHENTICATE +").await?;
s.send("AUTHENTICATE dGVzdGVyAHRlc3RlcgBwYXNzd29yZA==").await?; // base64-encoded 'tester\x00tester\x00password'
s.expect(":testserver 900 tester tester tester :You are now logged in as tester").await?;
s.expect(":testserver 903 tester :SASL authentication successful").await?;
s.send("CAP END").await?;
s.expect_server_introduction("tester").await?;
s.expect_nothing().await?;
s.send("JOIN #test").await?;
s.expect(":tester JOIN #test").await?;
s.expect(":testserver 332 tester #test :New room").await?;
s.expect(":testserver 353 tester = #test :tester").await?;
s.expect(":testserver 366 tester #test :End of /NAMES list").await?;
server.storage.create_user("some_guy").await?;
let mut conn = server.core.players.connect_to_player(&PlayerId::from("some_guy").unwrap()).await;
let res = conn.join_room(RoomId::from("test").unwrap()).await?;
let JoinResult::Success(_) = res else {
panic!("Failed to join room");
};
s.expect(":some_guy JOIN #test").await?;
let SendMessageResult::Success(res) = conn.send_message(RoomId::from("test").unwrap(), "Hello".into()).await?
else {
panic!("Failed to send message");
};
s.expect(&format!(
"@time={} :some_guy PRIVMSG #test :Hello",
res.to_rfc3339_opts(SecondsFormat::Millis, true)
))
.await?;
// formatting check
assert_eq!(
DateTime::parse_from_rfc3339(&"2024-01-01T10:00:32.123Z").unwrap().to_rfc3339_opts(SecondsFormat::Millis, true),
"2024-01-01T10:00:32.123Z"
);
s.send("QUIT :Leaving").await?;
s.expect(":testserver ERROR :Leaving the server").await?;
s.expect_eof().await?;
stream.shutdown().await?;
// wrap up
server.server.terminate().await?;
Ok(())
}

View File

@ -17,6 +17,7 @@ impl<'a> XmppConnection<'a> {
room_id, room_id,
author_id, author_id,
body, body,
created_at: _,
} => { } => {
Message::<()> { Message::<()> {
to: Some(Jid { to: Some(Jid {

View File

@ -18,8 +18,19 @@ use tokio::io::{AsyncWrite, AsyncWriteExt};
/// Single message tag value. /// Single message tag value.
#[derive(Clone, Debug, PartialEq, Eq)] #[derive(Clone, Debug, PartialEq, Eq)]
pub struct Tag { pub struct Tag {
key: Str, pub key: Str,
value: Option<u8>, pub value: Option<Str>,
}
impl Tag {
pub async fn write_async(&self, writer: &mut (impl AsyncWrite + Unpin)) -> std::io::Result<()> {
writer.write_all(self.key.as_bytes()).await?;
if let Some(value) = &self.value {
writer.write_all(b"=").await?;
writer.write_all(value.as_bytes()).await?;
}
Ok(())
}
} }
fn receiver(input: &str) -> IResult<&str, &str> { fn receiver(input: &str) -> IResult<&str, &str> {

View File

@ -19,6 +19,13 @@ pub struct ServerMessage {
impl ServerMessage { impl ServerMessage {
pub async fn write_async(&self, writer: &mut (impl AsyncWrite + Unpin)) -> std::io::Result<()> { pub async fn write_async(&self, writer: &mut (impl AsyncWrite + Unpin)) -> std::io::Result<()> {
if !self.tags.is_empty() {
for tag in &self.tags {
writer.write_all(b"@").await?;
tag.write_async(writer).await?;
writer.write_all(b" ").await?;
}
}
match &self.sender { match &self.sender {
Some(ref sender) => { Some(ref sender) => {
writer.write_all(b":").await?; writer.write_all(b":").await?;