refactor: renames and split room.rs
This commit is contained in:
parent
92e59f14e0
commit
025b64befc
67 changed files with 278 additions and 45801 deletions
|
@ -1,153 +0,0 @@
|
|||
use crate::{utils, Error, Result};
|
||||
use ruma::{
|
||||
api::client::error::ErrorKind,
|
||||
events::{AnyEphemeralRoomEvent, RoomAccountDataEventType},
|
||||
serde::Raw,
|
||||
RoomId, UserId,
|
||||
};
|
||||
use serde::{de::DeserializeOwned, Serialize};
|
||||
use std::{collections::HashMap, sync::Arc};
|
||||
|
||||
use super::abstraction::Tree;
|
||||
|
||||
pub struct AccountData {
|
||||
pub(super) roomuserdataid_accountdata: Arc<dyn Tree>, // RoomUserDataId = Room + User + Count + Type
|
||||
pub(super) roomusertype_roomuserdataid: Arc<dyn Tree>, // RoomUserType = Room + User + Type
|
||||
}
|
||||
|
||||
impl AccountData {
|
||||
/// Places one event in the account data of the user and removes the previous entry.
|
||||
#[tracing::instrument(skip(self, room_id, user_id, event_type, data, globals))]
|
||||
pub fn update<T: Serialize>(
|
||||
&self,
|
||||
room_id: Option<&RoomId>,
|
||||
user_id: &UserId,
|
||||
event_type: RoomAccountDataEventType,
|
||||
data: &T,
|
||||
globals: &super::globals::Globals,
|
||||
) -> Result<()> {
|
||||
let mut prefix = room_id
|
||||
.map(|r| r.to_string())
|
||||
.unwrap_or_default()
|
||||
.as_bytes()
|
||||
.to_vec();
|
||||
prefix.push(0xff);
|
||||
prefix.extend_from_slice(user_id.as_bytes());
|
||||
prefix.push(0xff);
|
||||
|
||||
let mut roomuserdataid = prefix.clone();
|
||||
roomuserdataid.extend_from_slice(&globals.next_count()?.to_be_bytes());
|
||||
roomuserdataid.push(0xff);
|
||||
roomuserdataid.extend_from_slice(event_type.to_string().as_bytes());
|
||||
|
||||
let mut key = prefix;
|
||||
key.extend_from_slice(event_type.to_string().as_bytes());
|
||||
|
||||
let json = serde_json::to_value(data).expect("all types here can be serialized"); // TODO: maybe add error handling
|
||||
if json.get("type").is_none() || json.get("content").is_none() {
|
||||
return Err(Error::BadRequest(
|
||||
ErrorKind::InvalidParam,
|
||||
"Account data doesn't have all required fields.",
|
||||
));
|
||||
}
|
||||
|
||||
self.roomuserdataid_accountdata.insert(
|
||||
&roomuserdataid,
|
||||
&serde_json::to_vec(&json).expect("to_vec always works on json values"),
|
||||
)?;
|
||||
|
||||
let prev = self.roomusertype_roomuserdataid.get(&key)?;
|
||||
|
||||
self.roomusertype_roomuserdataid
|
||||
.insert(&key, &roomuserdataid)?;
|
||||
|
||||
// Remove old entry
|
||||
if let Some(prev) = prev {
|
||||
self.roomuserdataid_accountdata.remove(&prev)?;
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Searches the account data for a specific kind.
|
||||
#[tracing::instrument(skip(self, room_id, user_id, kind))]
|
||||
pub fn get<T: DeserializeOwned>(
|
||||
&self,
|
||||
room_id: Option<&RoomId>,
|
||||
user_id: &UserId,
|
||||
kind: RoomAccountDataEventType,
|
||||
) -> Result<Option<T>> {
|
||||
let mut key = room_id
|
||||
.map(|r| r.to_string())
|
||||
.unwrap_or_default()
|
||||
.as_bytes()
|
||||
.to_vec();
|
||||
key.push(0xff);
|
||||
key.extend_from_slice(user_id.as_bytes());
|
||||
key.push(0xff);
|
||||
key.extend_from_slice(kind.to_string().as_bytes());
|
||||
|
||||
self.roomusertype_roomuserdataid
|
||||
.get(&key)?
|
||||
.and_then(|roomuserdataid| {
|
||||
self.roomuserdataid_accountdata
|
||||
.get(&roomuserdataid)
|
||||
.transpose()
|
||||
})
|
||||
.transpose()?
|
||||
.map(|data| {
|
||||
serde_json::from_slice(&data)
|
||||
.map_err(|_| Error::bad_database("could not deserialize"))
|
||||
})
|
||||
.transpose()
|
||||
}
|
||||
|
||||
/// Returns all changes to the account data that happened after `since`.
|
||||
#[tracing::instrument(skip(self, room_id, user_id, since))]
|
||||
pub fn changes_since(
|
||||
&self,
|
||||
room_id: Option<&RoomId>,
|
||||
user_id: &UserId,
|
||||
since: u64,
|
||||
) -> Result<HashMap<RoomAccountDataEventType, Raw<AnyEphemeralRoomEvent>>> {
|
||||
let mut userdata = HashMap::new();
|
||||
|
||||
let mut prefix = room_id
|
||||
.map(|r| r.to_string())
|
||||
.unwrap_or_default()
|
||||
.as_bytes()
|
||||
.to_vec();
|
||||
prefix.push(0xff);
|
||||
prefix.extend_from_slice(user_id.as_bytes());
|
||||
prefix.push(0xff);
|
||||
|
||||
// Skip the data that's exactly at since, because we sent that last time
|
||||
let mut first_possible = prefix.clone();
|
||||
first_possible.extend_from_slice(&(since + 1).to_be_bytes());
|
||||
|
||||
for r in self
|
||||
.roomuserdataid_accountdata
|
||||
.iter_from(&first_possible, false)
|
||||
.take_while(move |(k, _)| k.starts_with(&prefix))
|
||||
.map(|(k, v)| {
|
||||
Ok::<_, Error>((
|
||||
RoomAccountDataEventType::try_from(
|
||||
utils::string_from_bytes(k.rsplit(|&b| b == 0xff).next().ok_or_else(
|
||||
|| Error::bad_database("RoomUserData ID in db is invalid."),
|
||||
)?)
|
||||
.map_err(|_| Error::bad_database("RoomUserData ID in db is invalid."))?,
|
||||
)
|
||||
.map_err(|_| Error::bad_database("RoomUserData ID in db is invalid."))?,
|
||||
serde_json::from_slice::<Raw<AnyEphemeralRoomEvent>>(&v).map_err(|_| {
|
||||
Error::bad_database("Database contains invalid account data.")
|
||||
})?,
|
||||
))
|
||||
})
|
||||
{
|
||||
let (kind, data) = r?;
|
||||
userdata.insert(kind, data);
|
||||
}
|
||||
|
||||
Ok(userdata)
|
||||
}
|
||||
}
|
File diff suppressed because it is too large
Load diff
|
@ -1,88 +0,0 @@
|
|||
use crate::{utils, Error, Result};
|
||||
use std::{
|
||||
collections::HashMap,
|
||||
sync::{Arc, RwLock},
|
||||
};
|
||||
|
||||
use super::abstraction::Tree;
|
||||
|
||||
pub struct Appservice {
|
||||
pub(super) cached_registrations: Arc<RwLock<HashMap<String, serde_yaml::Value>>>,
|
||||
pub(super) id_appserviceregistrations: Arc<dyn Tree>,
|
||||
}
|
||||
|
||||
impl Appservice {
|
||||
/// Registers an appservice and returns the ID to the caller
|
||||
///
|
||||
pub fn register_appservice(&self, yaml: serde_yaml::Value) -> Result<String> {
|
||||
// TODO: Rumaify
|
||||
let id = yaml.get("id").unwrap().as_str().unwrap();
|
||||
self.id_appserviceregistrations.insert(
|
||||
id.as_bytes(),
|
||||
serde_yaml::to_string(&yaml).unwrap().as_bytes(),
|
||||
)?;
|
||||
self.cached_registrations
|
||||
.write()
|
||||
.unwrap()
|
||||
.insert(id.to_owned(), yaml.to_owned());
|
||||
|
||||
Ok(id.to_owned())
|
||||
}
|
||||
|
||||
/// Remove an appservice registration
|
||||
///
|
||||
/// # Arguments
|
||||
///
|
||||
/// * `service_name` - the name you send to register the service previously
|
||||
pub fn unregister_appservice(&self, service_name: &str) -> Result<()> {
|
||||
self.id_appserviceregistrations
|
||||
.remove(service_name.as_bytes())?;
|
||||
self.cached_registrations
|
||||
.write()
|
||||
.unwrap()
|
||||
.remove(service_name);
|
||||
Ok(())
|
||||
}
|
||||
|
||||
pub fn get_registration(&self, id: &str) -> Result<Option<serde_yaml::Value>> {
|
||||
self.cached_registrations
|
||||
.read()
|
||||
.unwrap()
|
||||
.get(id)
|
||||
.map_or_else(
|
||||
|| {
|
||||
self.id_appserviceregistrations
|
||||
.get(id.as_bytes())?
|
||||
.map(|bytes| {
|
||||
serde_yaml::from_slice(&bytes).map_err(|_| {
|
||||
Error::bad_database(
|
||||
"Invalid registration bytes in id_appserviceregistrations.",
|
||||
)
|
||||
})
|
||||
})
|
||||
.transpose()
|
||||
},
|
||||
|r| Ok(Some(r.clone())),
|
||||
)
|
||||
}
|
||||
|
||||
pub fn iter_ids(&self) -> Result<impl Iterator<Item = Result<String>> + '_> {
|
||||
Ok(self.id_appserviceregistrations.iter().map(|(id, _)| {
|
||||
utils::string_from_bytes(&id)
|
||||
.map_err(|_| Error::bad_database("Invalid id bytes in id_appserviceregistrations."))
|
||||
}))
|
||||
}
|
||||
|
||||
pub fn all(&self) -> Result<Vec<(String, serde_yaml::Value)>> {
|
||||
self.iter_ids()?
|
||||
.filter_map(|id| id.ok())
|
||||
.map(move |id| {
|
||||
Ok((
|
||||
id.clone(),
|
||||
self.get_registration(&id)?
|
||||
.expect("iter_ids only returns appservices that exist"),
|
||||
))
|
||||
})
|
||||
.collect()
|
||||
}
|
||||
}
|
|
@ -1,420 +0,0 @@
|
|||
use crate::{database::Config, server_server::FedDest, utils, Error, Result};
|
||||
use ruma::{
|
||||
api::{
|
||||
client::sync::sync_events,
|
||||
federation::discovery::{ServerSigningKeys, VerifyKey},
|
||||
},
|
||||
DeviceId, EventId, MilliSecondsSinceUnixEpoch, RoomId, RoomVersionId, ServerName,
|
||||
ServerSigningKeyId, UserId,
|
||||
};
|
||||
use std::{
|
||||
collections::{BTreeMap, HashMap},
|
||||
fs,
|
||||
future::Future,
|
||||
net::{IpAddr, SocketAddr},
|
||||
path::PathBuf,
|
||||
sync::{Arc, Mutex, RwLock},
|
||||
time::{Duration, Instant},
|
||||
};
|
||||
use tokio::sync::{broadcast, watch::Receiver, Mutex as TokioMutex, Semaphore};
|
||||
use tracing::error;
|
||||
use trust_dns_resolver::TokioAsyncResolver;
|
||||
|
||||
use super::abstraction::Tree;
|
||||
|
||||
pub const COUNTER: &[u8] = b"c";
|
||||
|
||||
type WellKnownMap = HashMap<Box<ServerName>, (FedDest, String)>;
|
||||
type TlsNameMap = HashMap<String, (Vec<IpAddr>, u16)>;
|
||||
type RateLimitState = (Instant, u32); // Time if last failed try, number of failed tries
|
||||
type SyncHandle = (
|
||||
Option<String>, // since
|
||||
Receiver<Option<Result<sync_events::v3::Response>>>, // rx
|
||||
);
|
||||
|
||||
pub struct Globals {
|
||||
pub actual_destination_cache: Arc<RwLock<WellKnownMap>>, // actual_destination, host
|
||||
pub tls_name_override: Arc<RwLock<TlsNameMap>>,
|
||||
pub(super) globals: Arc<dyn Tree>,
|
||||
pub config: Config,
|
||||
keypair: Arc<ruma::signatures::Ed25519KeyPair>,
|
||||
dns_resolver: TokioAsyncResolver,
|
||||
jwt_decoding_key: Option<jsonwebtoken::DecodingKey<'static>>,
|
||||
federation_client: reqwest::Client,
|
||||
default_client: reqwest::Client,
|
||||
pub stable_room_versions: Vec<RoomVersionId>,
|
||||
pub unstable_room_versions: Vec<RoomVersionId>,
|
||||
pub(super) server_signingkeys: Arc<dyn Tree>,
|
||||
pub bad_event_ratelimiter: Arc<RwLock<HashMap<Box<EventId>, RateLimitState>>>,
|
||||
pub bad_signature_ratelimiter: Arc<RwLock<HashMap<Vec<String>, RateLimitState>>>,
|
||||
pub servername_ratelimiter: Arc<RwLock<HashMap<Box<ServerName>, Arc<Semaphore>>>>,
|
||||
pub sync_receivers: RwLock<HashMap<(Box<UserId>, Box<DeviceId>), SyncHandle>>,
|
||||
pub roomid_mutex_insert: RwLock<HashMap<Box<RoomId>, Arc<Mutex<()>>>>,
|
||||
pub roomid_mutex_state: RwLock<HashMap<Box<RoomId>, Arc<TokioMutex<()>>>>,
|
||||
pub roomid_mutex_federation: RwLock<HashMap<Box<RoomId>, Arc<TokioMutex<()>>>>, // this lock will be held longer
|
||||
pub roomid_federationhandletime: RwLock<HashMap<Box<RoomId>, (Box<EventId>, Instant)>>,
|
||||
pub stateres_mutex: Arc<Mutex<()>>,
|
||||
pub rotate: RotationHandler,
|
||||
}
|
||||
|
||||
/// Handles "rotation" of long-polling requests. "Rotation" in this context is similar to "rotation" of log files and the like.
|
||||
///
|
||||
/// This is utilized to have sync workers return early and release read locks on the database.
|
||||
pub struct RotationHandler(broadcast::Sender<()>, broadcast::Receiver<()>);
|
||||
|
||||
impl RotationHandler {
|
||||
pub fn new() -> Self {
|
||||
let (s, r) = broadcast::channel(1);
|
||||
Self(s, r)
|
||||
}
|
||||
|
||||
pub fn watch(&self) -> impl Future<Output = ()> {
|
||||
let mut r = self.0.subscribe();
|
||||
|
||||
async move {
|
||||
let _ = r.recv().await;
|
||||
}
|
||||
}
|
||||
|
||||
pub fn fire(&self) {
|
||||
let _ = self.0.send(());
|
||||
}
|
||||
}
|
||||
|
||||
impl Default for RotationHandler {
|
||||
fn default() -> Self {
|
||||
Self::new()
|
||||
}
|
||||
}
|
||||
|
||||
impl Globals {
|
||||
pub fn load(
|
||||
globals: Arc<dyn Tree>,
|
||||
server_signingkeys: Arc<dyn Tree>,
|
||||
config: Config,
|
||||
) -> Result<Self> {
|
||||
let keypair_bytes = globals.get(b"keypair")?.map_or_else(
|
||||
|| {
|
||||
let keypair = utils::generate_keypair();
|
||||
globals.insert(b"keypair", &keypair)?;
|
||||
Ok::<_, Error>(keypair)
|
||||
},
|
||||
|s| Ok(s.to_vec()),
|
||||
)?;
|
||||
|
||||
let mut parts = keypair_bytes.splitn(2, |&b| b == 0xff);
|
||||
|
||||
let keypair = utils::string_from_bytes(
|
||||
// 1. version
|
||||
parts
|
||||
.next()
|
||||
.expect("splitn always returns at least one element"),
|
||||
)
|
||||
.map_err(|_| Error::bad_database("Invalid version bytes in keypair."))
|
||||
.and_then(|version| {
|
||||
// 2. key
|
||||
parts
|
||||
.next()
|
||||
.ok_or_else(|| Error::bad_database("Invalid keypair format in database."))
|
||||
.map(|key| (version, key))
|
||||
})
|
||||
.and_then(|(version, key)| {
|
||||
ruma::signatures::Ed25519KeyPair::from_der(key, version)
|
||||
.map_err(|_| Error::bad_database("Private or public keys are invalid."))
|
||||
});
|
||||
|
||||
let keypair = match keypair {
|
||||
Ok(k) => k,
|
||||
Err(e) => {
|
||||
error!("Keypair invalid. Deleting...");
|
||||
globals.remove(b"keypair")?;
|
||||
return Err(e);
|
||||
}
|
||||
};
|
||||
|
||||
let tls_name_override = Arc::new(RwLock::new(TlsNameMap::new()));
|
||||
|
||||
let jwt_decoding_key = config
|
||||
.jwt_secret
|
||||
.as_ref()
|
||||
.map(|secret| jsonwebtoken::DecodingKey::from_secret(secret.as_bytes()).into_static());
|
||||
|
||||
let default_client = reqwest_client_builder(&config)?.build()?;
|
||||
let name_override = Arc::clone(&tls_name_override);
|
||||
let federation_client = reqwest_client_builder(&config)?
|
||||
.resolve_fn(move |domain| {
|
||||
let read_guard = name_override.read().unwrap();
|
||||
let (override_name, port) = read_guard.get(&domain)?;
|
||||
let first_name = override_name.get(0)?;
|
||||
Some(SocketAddr::new(*first_name, *port))
|
||||
})
|
||||
.build()?;
|
||||
|
||||
// Supported and stable room versions
|
||||
let stable_room_versions = vec![
|
||||
RoomVersionId::V6,
|
||||
RoomVersionId::V7,
|
||||
RoomVersionId::V8,
|
||||
RoomVersionId::V9,
|
||||
];
|
||||
// Experimental, partially supported room versions
|
||||
let unstable_room_versions = vec![RoomVersionId::V3, RoomVersionId::V4, RoomVersionId::V5];
|
||||
|
||||
let mut s = Self {
|
||||
globals,
|
||||
config,
|
||||
keypair: Arc::new(keypair),
|
||||
dns_resolver: TokioAsyncResolver::tokio_from_system_conf().map_err(|e| {
|
||||
error!(
|
||||
"Failed to set up trust dns resolver with system config: {}",
|
||||
e
|
||||
);
|
||||
Error::bad_config("Failed to set up trust dns resolver with system config.")
|
||||
})?,
|
||||
actual_destination_cache: Arc::new(RwLock::new(WellKnownMap::new())),
|
||||
tls_name_override,
|
||||
federation_client,
|
||||
default_client,
|
||||
server_signingkeys,
|
||||
jwt_decoding_key,
|
||||
stable_room_versions,
|
||||
unstable_room_versions,
|
||||
bad_event_ratelimiter: Arc::new(RwLock::new(HashMap::new())),
|
||||
bad_signature_ratelimiter: Arc::new(RwLock::new(HashMap::new())),
|
||||
servername_ratelimiter: Arc::new(RwLock::new(HashMap::new())),
|
||||
roomid_mutex_state: RwLock::new(HashMap::new()),
|
||||
roomid_mutex_insert: RwLock::new(HashMap::new()),
|
||||
roomid_mutex_federation: RwLock::new(HashMap::new()),
|
||||
roomid_federationhandletime: RwLock::new(HashMap::new()),
|
||||
stateres_mutex: Arc::new(Mutex::new(())),
|
||||
sync_receivers: RwLock::new(HashMap::new()),
|
||||
rotate: RotationHandler::new(),
|
||||
};
|
||||
|
||||
fs::create_dir_all(s.get_media_folder())?;
|
||||
|
||||
if !s
|
||||
.supported_room_versions()
|
||||
.contains(&s.config.default_room_version)
|
||||
{
|
||||
error!("Room version in config isn't supported, falling back to Version 6");
|
||||
s.config.default_room_version = RoomVersionId::V6;
|
||||
};
|
||||
|
||||
Ok(s)
|
||||
}
|
||||
|
||||
/// Returns this server's keypair.
|
||||
pub fn keypair(&self) -> &ruma::signatures::Ed25519KeyPair {
|
||||
&self.keypair
|
||||
}
|
||||
|
||||
/// Returns a reqwest client which can be used to send requests
|
||||
pub fn default_client(&self) -> reqwest::Client {
|
||||
// Client is cheap to clone (Arc wrapper) and avoids lifetime issues
|
||||
self.default_client.clone()
|
||||
}
|
||||
|
||||
/// Returns a client used for resolving .well-knowns
|
||||
pub fn federation_client(&self) -> reqwest::Client {
|
||||
// Client is cheap to clone (Arc wrapper) and avoids lifetime issues
|
||||
self.federation_client.clone()
|
||||
}
|
||||
|
||||
#[tracing::instrument(skip(self))]
|
||||
pub fn next_count(&self) -> Result<u64> {
|
||||
utils::u64_from_bytes(&self.globals.increment(COUNTER)?)
|
||||
.map_err(|_| Error::bad_database("Count has invalid bytes."))
|
||||
}
|
||||
|
||||
#[tracing::instrument(skip(self))]
|
||||
pub fn current_count(&self) -> Result<u64> {
|
||||
self.globals.get(COUNTER)?.map_or(Ok(0_u64), |bytes| {
|
||||
utils::u64_from_bytes(&bytes)
|
||||
.map_err(|_| Error::bad_database("Count has invalid bytes."))
|
||||
})
|
||||
}
|
||||
|
||||
pub fn server_name(&self) -> &ServerName {
|
||||
self.config.server_name.as_ref()
|
||||
}
|
||||
|
||||
pub fn max_request_size(&self) -> u32 {
|
||||
self.config.max_request_size
|
||||
}
|
||||
|
||||
pub fn allow_registration(&self) -> bool {
|
||||
self.config.allow_registration
|
||||
}
|
||||
|
||||
pub fn allow_encryption(&self) -> bool {
|
||||
self.config.allow_encryption
|
||||
}
|
||||
|
||||
pub fn allow_federation(&self) -> bool {
|
||||
self.config.allow_federation
|
||||
}
|
||||
|
||||
pub fn allow_room_creation(&self) -> bool {
|
||||
self.config.allow_room_creation
|
||||
}
|
||||
|
||||
pub fn allow_unstable_room_versions(&self) -> bool {
|
||||
self.config.allow_unstable_room_versions
|
||||
}
|
||||
|
||||
pub fn default_room_version(&self) -> RoomVersionId {
|
||||
self.config.default_room_version.clone()
|
||||
}
|
||||
|
||||
pub fn trusted_servers(&self) -> &[Box<ServerName>] {
|
||||
&self.config.trusted_servers
|
||||
}
|
||||
|
||||
pub fn dns_resolver(&self) -> &TokioAsyncResolver {
|
||||
&self.dns_resolver
|
||||
}
|
||||
|
||||
pub fn jwt_decoding_key(&self) -> Option<&jsonwebtoken::DecodingKey<'_>> {
|
||||
self.jwt_decoding_key.as_ref()
|
||||
}
|
||||
|
||||
pub fn turn_password(&self) -> &String {
|
||||
&self.config.turn_password
|
||||
}
|
||||
|
||||
pub fn turn_ttl(&self) -> u64 {
|
||||
self.config.turn_ttl
|
||||
}
|
||||
|
||||
pub fn turn_uris(&self) -> &[String] {
|
||||
&self.config.turn_uris
|
||||
}
|
||||
|
||||
pub fn turn_username(&self) -> &String {
|
||||
&self.config.turn_username
|
||||
}
|
||||
|
||||
pub fn turn_secret(&self) -> &String {
|
||||
&self.config.turn_secret
|
||||
}
|
||||
|
||||
pub fn emergency_password(&self) -> &Option<String> {
|
||||
&self.config.emergency_password
|
||||
}
|
||||
|
||||
pub fn supported_room_versions(&self) -> Vec<RoomVersionId> {
|
||||
let mut room_versions: Vec<RoomVersionId> = vec![];
|
||||
room_versions.extend(self.stable_room_versions.clone());
|
||||
if self.allow_unstable_room_versions() {
|
||||
room_versions.extend(self.unstable_room_versions.clone());
|
||||
};
|
||||
room_versions
|
||||
}
|
||||
|
||||
/// TODO: the key valid until timestamp is only honored in room version > 4
|
||||
/// Remove the outdated keys and insert the new ones.
|
||||
///
|
||||
/// This doesn't actually check that the keys provided are newer than the old set.
|
||||
pub fn add_signing_key(
|
||||
&self,
|
||||
origin: &ServerName,
|
||||
new_keys: ServerSigningKeys,
|
||||
) -> Result<BTreeMap<Box<ServerSigningKeyId>, VerifyKey>> {
|
||||
// Not atomic, but this is not critical
|
||||
let signingkeys = self.server_signingkeys.get(origin.as_bytes())?;
|
||||
|
||||
let mut keys = signingkeys
|
||||
.and_then(|keys| serde_json::from_slice(&keys).ok())
|
||||
.unwrap_or_else(|| {
|
||||
// Just insert "now", it doesn't matter
|
||||
ServerSigningKeys::new(origin.to_owned(), MilliSecondsSinceUnixEpoch::now())
|
||||
});
|
||||
|
||||
let ServerSigningKeys {
|
||||
verify_keys,
|
||||
old_verify_keys,
|
||||
..
|
||||
} = new_keys;
|
||||
|
||||
keys.verify_keys.extend(verify_keys.into_iter());
|
||||
keys.old_verify_keys.extend(old_verify_keys.into_iter());
|
||||
|
||||
self.server_signingkeys.insert(
|
||||
origin.as_bytes(),
|
||||
&serde_json::to_vec(&keys).expect("serversigningkeys can be serialized"),
|
||||
)?;
|
||||
|
||||
let mut tree = keys.verify_keys;
|
||||
tree.extend(
|
||||
keys.old_verify_keys
|
||||
.into_iter()
|
||||
.map(|old| (old.0, VerifyKey::new(old.1.key))),
|
||||
);
|
||||
|
||||
Ok(tree)
|
||||
}
|
||||
|
||||
/// This returns an empty `Ok(BTreeMap<..>)` when there are no keys found for the server.
|
||||
pub fn signing_keys_for(
|
||||
&self,
|
||||
origin: &ServerName,
|
||||
) -> Result<BTreeMap<Box<ServerSigningKeyId>, VerifyKey>> {
|
||||
let signingkeys = self
|
||||
.server_signingkeys
|
||||
.get(origin.as_bytes())?
|
||||
.and_then(|bytes| serde_json::from_slice(&bytes).ok())
|
||||
.map(|keys: ServerSigningKeys| {
|
||||
let mut tree = keys.verify_keys;
|
||||
tree.extend(
|
||||
keys.old_verify_keys
|
||||
.into_iter()
|
||||
.map(|old| (old.0, VerifyKey::new(old.1.key))),
|
||||
);
|
||||
tree
|
||||
})
|
||||
.unwrap_or_else(BTreeMap::new);
|
||||
|
||||
Ok(signingkeys)
|
||||
}
|
||||
|
||||
pub fn database_version(&self) -> Result<u64> {
|
||||
self.globals.get(b"version")?.map_or(Ok(0), |version| {
|
||||
utils::u64_from_bytes(&version)
|
||||
.map_err(|_| Error::bad_database("Database version id is invalid."))
|
||||
})
|
||||
}
|
||||
|
||||
pub fn bump_database_version(&self, new_version: u64) -> Result<()> {
|
||||
self.globals
|
||||
.insert(b"version", &new_version.to_be_bytes())?;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
pub fn get_media_folder(&self) -> PathBuf {
|
||||
let mut r = PathBuf::new();
|
||||
r.push(self.config.database_path.clone());
|
||||
r.push("media");
|
||||
r
|
||||
}
|
||||
|
||||
pub fn get_media_file(&self, key: &[u8]) -> PathBuf {
|
||||
let mut r = PathBuf::new();
|
||||
r.push(self.config.database_path.clone());
|
||||
r.push("media");
|
||||
r.push(base64::encode_config(key, base64::URL_SAFE_NO_PAD));
|
||||
r
|
||||
}
|
||||
}
|
||||
|
||||
fn reqwest_client_builder(config: &Config) -> Result<reqwest::ClientBuilder> {
|
||||
let mut reqwest_client_builder = reqwest::Client::builder()
|
||||
.connect_timeout(Duration::from_secs(30))
|
||||
.timeout(Duration::from_secs(60 * 3));
|
||||
|
||||
if let Some(proxy) = config.proxy.to_proxy()? {
|
||||
reqwest_client_builder = reqwest_client_builder.proxy(proxy);
|
||||
}
|
||||
|
||||
Ok(reqwest_client_builder)
|
||||
}
|
|
@ -1,382 +0,0 @@
|
|||
use crate::{utils, Error, Result};
|
||||
use ruma::{
|
||||
api::client::{
|
||||
backup::{BackupAlgorithm, KeyBackupData, RoomKeyBackup},
|
||||
error::ErrorKind,
|
||||
},
|
||||
serde::Raw,
|
||||
RoomId, UserId,
|
||||
};
|
||||
use std::{collections::BTreeMap, sync::Arc};
|
||||
|
||||
use super::abstraction::Tree;
|
||||
|
||||
pub struct KeyBackups {
|
||||
pub(super) backupid_algorithm: Arc<dyn Tree>, // BackupId = UserId + Version(Count)
|
||||
pub(super) backupid_etag: Arc<dyn Tree>, // BackupId = UserId + Version(Count)
|
||||
pub(super) backupkeyid_backup: Arc<dyn Tree>, // BackupKeyId = UserId + Version + RoomId + SessionId
|
||||
}
|
||||
|
||||
impl KeyBackups {
|
||||
pub fn create_backup(
|
||||
&self,
|
||||
user_id: &UserId,
|
||||
backup_metadata: &Raw<BackupAlgorithm>,
|
||||
globals: &super::globals::Globals,
|
||||
) -> Result<String> {
|
||||
let version = globals.next_count()?.to_string();
|
||||
|
||||
let mut key = user_id.as_bytes().to_vec();
|
||||
key.push(0xff);
|
||||
key.extend_from_slice(version.as_bytes());
|
||||
|
||||
self.backupid_algorithm.insert(
|
||||
&key,
|
||||
&serde_json::to_vec(backup_metadata).expect("BackupAlgorithm::to_vec always works"),
|
||||
)?;
|
||||
self.backupid_etag
|
||||
.insert(&key, &globals.next_count()?.to_be_bytes())?;
|
||||
Ok(version)
|
||||
}
|
||||
|
||||
pub fn delete_backup(&self, user_id: &UserId, version: &str) -> Result<()> {
|
||||
let mut key = user_id.as_bytes().to_vec();
|
||||
key.push(0xff);
|
||||
key.extend_from_slice(version.as_bytes());
|
||||
|
||||
self.backupid_algorithm.remove(&key)?;
|
||||
self.backupid_etag.remove(&key)?;
|
||||
|
||||
key.push(0xff);
|
||||
|
||||
for (outdated_key, _) in self.backupkeyid_backup.scan_prefix(key) {
|
||||
self.backupkeyid_backup.remove(&outdated_key)?;
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
pub fn update_backup(
|
||||
&self,
|
||||
user_id: &UserId,
|
||||
version: &str,
|
||||
backup_metadata: &Raw<BackupAlgorithm>,
|
||||
globals: &super::globals::Globals,
|
||||
) -> Result<String> {
|
||||
let mut key = user_id.as_bytes().to_vec();
|
||||
key.push(0xff);
|
||||
key.extend_from_slice(version.as_bytes());
|
||||
|
||||
if self.backupid_algorithm.get(&key)?.is_none() {
|
||||
return Err(Error::BadRequest(
|
||||
ErrorKind::NotFound,
|
||||
"Tried to update nonexistent backup.",
|
||||
));
|
||||
}
|
||||
|
||||
self.backupid_algorithm
|
||||
.insert(&key, backup_metadata.json().get().as_bytes())?;
|
||||
self.backupid_etag
|
||||
.insert(&key, &globals.next_count()?.to_be_bytes())?;
|
||||
Ok(version.to_owned())
|
||||
}
|
||||
|
||||
pub fn get_latest_backup_version(&self, user_id: &UserId) -> Result<Option<String>> {
|
||||
let mut prefix = user_id.as_bytes().to_vec();
|
||||
prefix.push(0xff);
|
||||
let mut last_possible_key = prefix.clone();
|
||||
last_possible_key.extend_from_slice(&u64::MAX.to_be_bytes());
|
||||
|
||||
self.backupid_algorithm
|
||||
.iter_from(&last_possible_key, true)
|
||||
.take_while(move |(k, _)| k.starts_with(&prefix))
|
||||
.next()
|
||||
.map(|(key, _)| {
|
||||
utils::string_from_bytes(
|
||||
key.rsplit(|&b| b == 0xff)
|
||||
.next()
|
||||
.expect("rsplit always returns an element"),
|
||||
)
|
||||
.map_err(|_| Error::bad_database("backupid_algorithm key is invalid."))
|
||||
})
|
||||
.transpose()
|
||||
}
|
||||
|
||||
pub fn get_latest_backup(
|
||||
&self,
|
||||
user_id: &UserId,
|
||||
) -> Result<Option<(String, Raw<BackupAlgorithm>)>> {
|
||||
let mut prefix = user_id.as_bytes().to_vec();
|
||||
prefix.push(0xff);
|
||||
let mut last_possible_key = prefix.clone();
|
||||
last_possible_key.extend_from_slice(&u64::MAX.to_be_bytes());
|
||||
|
||||
self.backupid_algorithm
|
||||
.iter_from(&last_possible_key, true)
|
||||
.take_while(move |(k, _)| k.starts_with(&prefix))
|
||||
.next()
|
||||
.map(|(key, value)| {
|
||||
let version = utils::string_from_bytes(
|
||||
key.rsplit(|&b| b == 0xff)
|
||||
.next()
|
||||
.expect("rsplit always returns an element"),
|
||||
)
|
||||
.map_err(|_| Error::bad_database("backupid_algorithm key is invalid."))?;
|
||||
|
||||
Ok((
|
||||
version,
|
||||
serde_json::from_slice(&value).map_err(|_| {
|
||||
Error::bad_database("Algorithm in backupid_algorithm is invalid.")
|
||||
})?,
|
||||
))
|
||||
})
|
||||
.transpose()
|
||||
}
|
||||
|
||||
pub fn get_backup(
|
||||
&self,
|
||||
user_id: &UserId,
|
||||
version: &str,
|
||||
) -> Result<Option<Raw<BackupAlgorithm>>> {
|
||||
let mut key = user_id.as_bytes().to_vec();
|
||||
key.push(0xff);
|
||||
key.extend_from_slice(version.as_bytes());
|
||||
|
||||
self.backupid_algorithm
|
||||
.get(&key)?
|
||||
.map_or(Ok(None), |bytes| {
|
||||
serde_json::from_slice(&bytes)
|
||||
.map_err(|_| Error::bad_database("Algorithm in backupid_algorithm is invalid."))
|
||||
})
|
||||
}
|
||||
|
||||
pub fn add_key(
|
||||
&self,
|
||||
user_id: &UserId,
|
||||
version: &str,
|
||||
room_id: &RoomId,
|
||||
session_id: &str,
|
||||
key_data: &Raw<KeyBackupData>,
|
||||
globals: &super::globals::Globals,
|
||||
) -> Result<()> {
|
||||
let mut key = user_id.as_bytes().to_vec();
|
||||
key.push(0xff);
|
||||
key.extend_from_slice(version.as_bytes());
|
||||
|
||||
if self.backupid_algorithm.get(&key)?.is_none() {
|
||||
return Err(Error::BadRequest(
|
||||
ErrorKind::NotFound,
|
||||
"Tried to update nonexistent backup.",
|
||||
));
|
||||
}
|
||||
|
||||
self.backupid_etag
|
||||
.insert(&key, &globals.next_count()?.to_be_bytes())?;
|
||||
|
||||
key.push(0xff);
|
||||
key.extend_from_slice(room_id.as_bytes());
|
||||
key.push(0xff);
|
||||
key.extend_from_slice(session_id.as_bytes());
|
||||
|
||||
self.backupkeyid_backup
|
||||
.insert(&key, key_data.json().get().as_bytes())?;
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
pub fn count_keys(&self, user_id: &UserId, version: &str) -> Result<usize> {
|
||||
let mut prefix = user_id.as_bytes().to_vec();
|
||||
prefix.push(0xff);
|
||||
prefix.extend_from_slice(version.as_bytes());
|
||||
|
||||
Ok(self.backupkeyid_backup.scan_prefix(prefix).count())
|
||||
}
|
||||
|
||||
pub fn get_etag(&self, user_id: &UserId, version: &str) -> Result<String> {
|
||||
let mut key = user_id.as_bytes().to_vec();
|
||||
key.push(0xff);
|
||||
key.extend_from_slice(version.as_bytes());
|
||||
|
||||
Ok(utils::u64_from_bytes(
|
||||
&self
|
||||
.backupid_etag
|
||||
.get(&key)?
|
||||
.ok_or_else(|| Error::bad_database("Backup has no etag."))?,
|
||||
)
|
||||
.map_err(|_| Error::bad_database("etag in backupid_etag invalid."))?
|
||||
.to_string())
|
||||
}
|
||||
|
||||
pub fn get_all(
|
||||
&self,
|
||||
user_id: &UserId,
|
||||
version: &str,
|
||||
) -> Result<BTreeMap<Box<RoomId>, RoomKeyBackup>> {
|
||||
let mut prefix = user_id.as_bytes().to_vec();
|
||||
prefix.push(0xff);
|
||||
prefix.extend_from_slice(version.as_bytes());
|
||||
prefix.push(0xff);
|
||||
|
||||
let mut rooms = BTreeMap::<Box<RoomId>, RoomKeyBackup>::new();
|
||||
|
||||
for result in self
|
||||
.backupkeyid_backup
|
||||
.scan_prefix(prefix)
|
||||
.map(|(key, value)| {
|
||||
let mut parts = key.rsplit(|&b| b == 0xff);
|
||||
|
||||
let session_id =
|
||||
utils::string_from_bytes(parts.next().ok_or_else(|| {
|
||||
Error::bad_database("backupkeyid_backup key is invalid.")
|
||||
})?)
|
||||
.map_err(|_| {
|
||||
Error::bad_database("backupkeyid_backup session_id is invalid.")
|
||||
})?;
|
||||
|
||||
let room_id = RoomId::parse(
|
||||
utils::string_from_bytes(parts.next().ok_or_else(|| {
|
||||
Error::bad_database("backupkeyid_backup key is invalid.")
|
||||
})?)
|
||||
.map_err(|_| Error::bad_database("backupkeyid_backup room_id is invalid."))?,
|
||||
)
|
||||
.map_err(|_| {
|
||||
Error::bad_database("backupkeyid_backup room_id is invalid room id.")
|
||||
})?;
|
||||
|
||||
let key_data = serde_json::from_slice(&value).map_err(|_| {
|
||||
Error::bad_database("KeyBackupData in backupkeyid_backup is invalid.")
|
||||
})?;
|
||||
|
||||
Ok::<_, Error>((room_id, session_id, key_data))
|
||||
})
|
||||
{
|
||||
let (room_id, session_id, key_data) = result?;
|
||||
rooms
|
||||
.entry(room_id)
|
||||
.or_insert_with(|| RoomKeyBackup {
|
||||
sessions: BTreeMap::new(),
|
||||
})
|
||||
.sessions
|
||||
.insert(session_id, key_data);
|
||||
}
|
||||
|
||||
Ok(rooms)
|
||||
}
|
||||
|
||||
pub fn get_room(
|
||||
&self,
|
||||
user_id: &UserId,
|
||||
version: &str,
|
||||
room_id: &RoomId,
|
||||
) -> Result<BTreeMap<String, Raw<KeyBackupData>>> {
|
||||
let mut prefix = user_id.as_bytes().to_vec();
|
||||
prefix.push(0xff);
|
||||
prefix.extend_from_slice(version.as_bytes());
|
||||
prefix.push(0xff);
|
||||
prefix.extend_from_slice(room_id.as_bytes());
|
||||
prefix.push(0xff);
|
||||
|
||||
Ok(self
|
||||
.backupkeyid_backup
|
||||
.scan_prefix(prefix)
|
||||
.map(|(key, value)| {
|
||||
let mut parts = key.rsplit(|&b| b == 0xff);
|
||||
|
||||
let session_id =
|
||||
utils::string_from_bytes(parts.next().ok_or_else(|| {
|
||||
Error::bad_database("backupkeyid_backup key is invalid.")
|
||||
})?)
|
||||
.map_err(|_| {
|
||||
Error::bad_database("backupkeyid_backup session_id is invalid.")
|
||||
})?;
|
||||
|
||||
let key_data = serde_json::from_slice(&value).map_err(|_| {
|
||||
Error::bad_database("KeyBackupData in backupkeyid_backup is invalid.")
|
||||
})?;
|
||||
|
||||
Ok::<_, Error>((session_id, key_data))
|
||||
})
|
||||
.filter_map(|r| r.ok())
|
||||
.collect())
|
||||
}
|
||||
|
||||
pub fn get_session(
|
||||
&self,
|
||||
user_id: &UserId,
|
||||
version: &str,
|
||||
room_id: &RoomId,
|
||||
session_id: &str,
|
||||
) -> Result<Option<Raw<KeyBackupData>>> {
|
||||
let mut key = user_id.as_bytes().to_vec();
|
||||
key.push(0xff);
|
||||
key.extend_from_slice(version.as_bytes());
|
||||
key.push(0xff);
|
||||
key.extend_from_slice(room_id.as_bytes());
|
||||
key.push(0xff);
|
||||
key.extend_from_slice(session_id.as_bytes());
|
||||
|
||||
self.backupkeyid_backup
|
||||
.get(&key)?
|
||||
.map(|value| {
|
||||
serde_json::from_slice(&value).map_err(|_| {
|
||||
Error::bad_database("KeyBackupData in backupkeyid_backup is invalid.")
|
||||
})
|
||||
})
|
||||
.transpose()
|
||||
}
|
||||
|
||||
pub fn delete_all_keys(&self, user_id: &UserId, version: &str) -> Result<()> {
|
||||
let mut key = user_id.as_bytes().to_vec();
|
||||
key.push(0xff);
|
||||
key.extend_from_slice(version.as_bytes());
|
||||
key.push(0xff);
|
||||
|
||||
for (outdated_key, _) in self.backupkeyid_backup.scan_prefix(key) {
|
||||
self.backupkeyid_backup.remove(&outdated_key)?;
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
pub fn delete_room_keys(
|
||||
&self,
|
||||
user_id: &UserId,
|
||||
version: &str,
|
||||
room_id: &RoomId,
|
||||
) -> Result<()> {
|
||||
let mut key = user_id.as_bytes().to_vec();
|
||||
key.push(0xff);
|
||||
key.extend_from_slice(version.as_bytes());
|
||||
key.push(0xff);
|
||||
key.extend_from_slice(room_id.as_bytes());
|
||||
key.push(0xff);
|
||||
|
||||
for (outdated_key, _) in self.backupkeyid_backup.scan_prefix(key) {
|
||||
self.backupkeyid_backup.remove(&outdated_key)?;
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
pub fn delete_room_key(
|
||||
&self,
|
||||
user_id: &UserId,
|
||||
version: &str,
|
||||
room_id: &RoomId,
|
||||
session_id: &str,
|
||||
) -> Result<()> {
|
||||
let mut key = user_id.as_bytes().to_vec();
|
||||
key.push(0xff);
|
||||
key.extend_from_slice(version.as_bytes());
|
||||
key.push(0xff);
|
||||
key.extend_from_slice(room_id.as_bytes());
|
||||
key.push(0xff);
|
||||
key.extend_from_slice(session_id.as_bytes());
|
||||
|
||||
for (outdated_key, _) in self.backupkeyid_backup.scan_prefix(key) {
|
||||
self.backupkeyid_backup.remove(&outdated_key)?;
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
}
|
|
@ -1,358 +0,0 @@
|
|||
use crate::database::globals::Globals;
|
||||
use image::{imageops::FilterType, GenericImageView};
|
||||
|
||||
use super::abstraction::Tree;
|
||||
use crate::{utils, Error, Result};
|
||||
use std::{mem, sync::Arc};
|
||||
use tokio::{
|
||||
fs::File,
|
||||
io::{AsyncReadExt, AsyncWriteExt},
|
||||
};
|
||||
|
||||
pub struct FileMeta {
|
||||
pub content_disposition: Option<String>,
|
||||
pub content_type: Option<String>,
|
||||
pub file: Vec<u8>,
|
||||
}
|
||||
|
||||
pub struct Media {
|
||||
pub(super) mediaid_file: Arc<dyn Tree>, // MediaId = MXC + WidthHeight + ContentDisposition + ContentType
|
||||
}
|
||||
|
||||
impl Media {
|
||||
/// Uploads a file.
|
||||
pub async fn create(
|
||||
&self,
|
||||
mxc: String,
|
||||
globals: &Globals,
|
||||
content_disposition: &Option<&str>,
|
||||
content_type: &Option<&str>,
|
||||
file: &[u8],
|
||||
) -> Result<()> {
|
||||
let mut key = mxc.as_bytes().to_vec();
|
||||
key.push(0xff);
|
||||
key.extend_from_slice(&0_u32.to_be_bytes()); // Width = 0 if it's not a thumbnail
|
||||
key.extend_from_slice(&0_u32.to_be_bytes()); // Height = 0 if it's not a thumbnail
|
||||
key.push(0xff);
|
||||
key.extend_from_slice(
|
||||
content_disposition
|
||||
.as_ref()
|
||||
.map(|f| f.as_bytes())
|
||||
.unwrap_or_default(),
|
||||
);
|
||||
key.push(0xff);
|
||||
key.extend_from_slice(
|
||||
content_type
|
||||
.as_ref()
|
||||
.map(|c| c.as_bytes())
|
||||
.unwrap_or_default(),
|
||||
);
|
||||
|
||||
let path = globals.get_media_file(&key);
|
||||
let mut f = File::create(path).await?;
|
||||
f.write_all(file).await?;
|
||||
|
||||
self.mediaid_file.insert(&key, &[])?;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Uploads or replaces a file thumbnail.
|
||||
#[allow(clippy::too_many_arguments)]
|
||||
pub async fn upload_thumbnail(
|
||||
&self,
|
||||
mxc: String,
|
||||
globals: &Globals,
|
||||
content_disposition: &Option<String>,
|
||||
content_type: &Option<String>,
|
||||
width: u32,
|
||||
height: u32,
|
||||
file: &[u8],
|
||||
) -> Result<()> {
|
||||
let mut key = mxc.as_bytes().to_vec();
|
||||
key.push(0xff);
|
||||
key.extend_from_slice(&width.to_be_bytes());
|
||||
key.extend_from_slice(&height.to_be_bytes());
|
||||
key.push(0xff);
|
||||
key.extend_from_slice(
|
||||
content_disposition
|
||||
.as_ref()
|
||||
.map(|f| f.as_bytes())
|
||||
.unwrap_or_default(),
|
||||
);
|
||||
key.push(0xff);
|
||||
key.extend_from_slice(
|
||||
content_type
|
||||
.as_ref()
|
||||
.map(|c| c.as_bytes())
|
||||
.unwrap_or_default(),
|
||||
);
|
||||
|
||||
let path = globals.get_media_file(&key);
|
||||
let mut f = File::create(path).await?;
|
||||
f.write_all(file).await?;
|
||||
|
||||
self.mediaid_file.insert(&key, &[])?;
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Downloads a file.
|
||||
pub async fn get(&self, globals: &Globals, mxc: &str) -> Result<Option<FileMeta>> {
|
||||
let mut prefix = mxc.as_bytes().to_vec();
|
||||
prefix.push(0xff);
|
||||
prefix.extend_from_slice(&0_u32.to_be_bytes()); // Width = 0 if it's not a thumbnail
|
||||
prefix.extend_from_slice(&0_u32.to_be_bytes()); // Height = 0 if it's not a thumbnail
|
||||
prefix.push(0xff);
|
||||
|
||||
let first = self.mediaid_file.scan_prefix(prefix).next();
|
||||
if let Some((key, _)) = first {
|
||||
let path = globals.get_media_file(&key);
|
||||
let mut file = Vec::new();
|
||||
File::open(path).await?.read_to_end(&mut file).await?;
|
||||
let mut parts = key.rsplit(|&b| b == 0xff);
|
||||
|
||||
let content_type = parts
|
||||
.next()
|
||||
.map(|bytes| {
|
||||
utils::string_from_bytes(bytes).map_err(|_| {
|
||||
Error::bad_database("Content type in mediaid_file is invalid unicode.")
|
||||
})
|
||||
})
|
||||
.transpose()?;
|
||||
|
||||
let content_disposition_bytes = parts
|
||||
.next()
|
||||
.ok_or_else(|| Error::bad_database("Media ID in db is invalid."))?;
|
||||
|
||||
let content_disposition = if content_disposition_bytes.is_empty() {
|
||||
None
|
||||
} else {
|
||||
Some(
|
||||
utils::string_from_bytes(content_disposition_bytes).map_err(|_| {
|
||||
Error::bad_database(
|
||||
"Content Disposition in mediaid_file is invalid unicode.",
|
||||
)
|
||||
})?,
|
||||
)
|
||||
};
|
||||
|
||||
Ok(Some(FileMeta {
|
||||
content_disposition,
|
||||
content_type,
|
||||
file,
|
||||
}))
|
||||
} else {
|
||||
Ok(None)
|
||||
}
|
||||
}
|
||||
|
||||
/// Returns width, height of the thumbnail and whether it should be cropped. Returns None when
|
||||
/// the server should send the original file.
|
||||
pub fn thumbnail_properties(&self, width: u32, height: u32) -> Option<(u32, u32, bool)> {
|
||||
match (width, height) {
|
||||
(0..=32, 0..=32) => Some((32, 32, true)),
|
||||
(0..=96, 0..=96) => Some((96, 96, true)),
|
||||
(0..=320, 0..=240) => Some((320, 240, false)),
|
||||
(0..=640, 0..=480) => Some((640, 480, false)),
|
||||
(0..=800, 0..=600) => Some((800, 600, false)),
|
||||
_ => None,
|
||||
}
|
||||
}
|
||||
|
||||
/// Downloads a file's thumbnail.
|
||||
///
|
||||
/// Here's an example on how it works:
|
||||
///
|
||||
/// - Client requests an image with width=567, height=567
|
||||
/// - Server rounds that up to (800, 600), so it doesn't have to save too many thumbnails
|
||||
/// - Server rounds that up again to (958, 600) to fix the aspect ratio (only for width,height>96)
|
||||
/// - Server creates the thumbnail and sends it to the user
|
||||
///
|
||||
/// For width,height <= 96 the server uses another thumbnailing algorithm which crops the image afterwards.
|
||||
pub async fn get_thumbnail(
|
||||
&self,
|
||||
mxc: &str,
|
||||
globals: &Globals,
|
||||
width: u32,
|
||||
height: u32,
|
||||
) -> Result<Option<FileMeta>> {
|
||||
let (width, height, crop) = self
|
||||
.thumbnail_properties(width, height)
|
||||
.unwrap_or((0, 0, false)); // 0, 0 because that's the original file
|
||||
|
||||
let mut main_prefix = mxc.as_bytes().to_vec();
|
||||
main_prefix.push(0xff);
|
||||
|
||||
let mut thumbnail_prefix = main_prefix.clone();
|
||||
thumbnail_prefix.extend_from_slice(&width.to_be_bytes());
|
||||
thumbnail_prefix.extend_from_slice(&height.to_be_bytes());
|
||||
thumbnail_prefix.push(0xff);
|
||||
|
||||
let mut original_prefix = main_prefix;
|
||||
original_prefix.extend_from_slice(&0_u32.to_be_bytes()); // Width = 0 if it's not a thumbnail
|
||||
original_prefix.extend_from_slice(&0_u32.to_be_bytes()); // Height = 0 if it's not a thumbnail
|
||||
original_prefix.push(0xff);
|
||||
|
||||
let first_thumbnailprefix = self.mediaid_file.scan_prefix(thumbnail_prefix).next();
|
||||
let first_originalprefix = self.mediaid_file.scan_prefix(original_prefix).next();
|
||||
if let Some((key, _)) = first_thumbnailprefix {
|
||||
// Using saved thumbnail
|
||||
let path = globals.get_media_file(&key);
|
||||
let mut file = Vec::new();
|
||||
File::open(path).await?.read_to_end(&mut file).await?;
|
||||
let mut parts = key.rsplit(|&b| b == 0xff);
|
||||
|
||||
let content_type = parts
|
||||
.next()
|
||||
.map(|bytes| {
|
||||
utils::string_from_bytes(bytes).map_err(|_| {
|
||||
Error::bad_database("Content type in mediaid_file is invalid unicode.")
|
||||
})
|
||||
})
|
||||
.transpose()?;
|
||||
|
||||
let content_disposition_bytes = parts
|
||||
.next()
|
||||
.ok_or_else(|| Error::bad_database("Media ID in db is invalid."))?;
|
||||
|
||||
let content_disposition = if content_disposition_bytes.is_empty() {
|
||||
None
|
||||
} else {
|
||||
Some(
|
||||
utils::string_from_bytes(content_disposition_bytes).map_err(|_| {
|
||||
Error::bad_database("Content Disposition in db is invalid.")
|
||||
})?,
|
||||
)
|
||||
};
|
||||
|
||||
Ok(Some(FileMeta {
|
||||
content_disposition,
|
||||
content_type,
|
||||
file: file.to_vec(),
|
||||
}))
|
||||
} else if let Some((key, _)) = first_originalprefix {
|
||||
// Generate a thumbnail
|
||||
let path = globals.get_media_file(&key);
|
||||
let mut file = Vec::new();
|
||||
File::open(path).await?.read_to_end(&mut file).await?;
|
||||
|
||||
let mut parts = key.rsplit(|&b| b == 0xff);
|
||||
|
||||
let content_type = parts
|
||||
.next()
|
||||
.map(|bytes| {
|
||||
utils::string_from_bytes(bytes).map_err(|_| {
|
||||
Error::bad_database("Content type in mediaid_file is invalid unicode.")
|
||||
})
|
||||
})
|
||||
.transpose()?;
|
||||
|
||||
let content_disposition_bytes = parts
|
||||
.next()
|
||||
.ok_or_else(|| Error::bad_database("Media ID in db is invalid."))?;
|
||||
|
||||
let content_disposition = if content_disposition_bytes.is_empty() {
|
||||
None
|
||||
} else {
|
||||
Some(
|
||||
utils::string_from_bytes(content_disposition_bytes).map_err(|_| {
|
||||
Error::bad_database(
|
||||
"Content Disposition in mediaid_file is invalid unicode.",
|
||||
)
|
||||
})?,
|
||||
)
|
||||
};
|
||||
|
||||
if let Ok(image) = image::load_from_memory(&file) {
|
||||
let original_width = image.width();
|
||||
let original_height = image.height();
|
||||
if width > original_width || height > original_height {
|
||||
return Ok(Some(FileMeta {
|
||||
content_disposition,
|
||||
content_type,
|
||||
file: file.to_vec(),
|
||||
}));
|
||||
}
|
||||
|
||||
let thumbnail = if crop {
|
||||
image.resize_to_fill(width, height, FilterType::CatmullRom)
|
||||
} else {
|
||||
let (exact_width, exact_height) = {
|
||||
// Copied from image::dynimage::resize_dimensions
|
||||
let ratio = u64::from(original_width) * u64::from(height);
|
||||
let nratio = u64::from(width) * u64::from(original_height);
|
||||
|
||||
let use_width = nratio <= ratio;
|
||||
let intermediate = if use_width {
|
||||
u64::from(original_height) * u64::from(width)
|
||||
/ u64::from(original_width)
|
||||
} else {
|
||||
u64::from(original_width) * u64::from(height)
|
||||
/ u64::from(original_height)
|
||||
};
|
||||
if use_width {
|
||||
if intermediate <= u64::from(::std::u32::MAX) {
|
||||
(width, intermediate as u32)
|
||||
} else {
|
||||
(
|
||||
(u64::from(width) * u64::from(::std::u32::MAX) / intermediate)
|
||||
as u32,
|
||||
::std::u32::MAX,
|
||||
)
|
||||
}
|
||||
} else if intermediate <= u64::from(::std::u32::MAX) {
|
||||
(intermediate as u32, height)
|
||||
} else {
|
||||
(
|
||||
::std::u32::MAX,
|
||||
(u64::from(height) * u64::from(::std::u32::MAX) / intermediate)
|
||||
as u32,
|
||||
)
|
||||
}
|
||||
};
|
||||
|
||||
image.thumbnail_exact(exact_width, exact_height)
|
||||
};
|
||||
|
||||
let mut thumbnail_bytes = Vec::new();
|
||||
thumbnail.write_to(&mut thumbnail_bytes, image::ImageOutputFormat::Png)?;
|
||||
|
||||
// Save thumbnail in database so we don't have to generate it again next time
|
||||
let mut thumbnail_key = key.to_vec();
|
||||
let width_index = thumbnail_key
|
||||
.iter()
|
||||
.position(|&b| b == 0xff)
|
||||
.ok_or_else(|| Error::bad_database("Media in db is invalid."))?
|
||||
+ 1;
|
||||
let mut widthheight = width.to_be_bytes().to_vec();
|
||||
widthheight.extend_from_slice(&height.to_be_bytes());
|
||||
|
||||
thumbnail_key.splice(
|
||||
width_index..width_index + 2 * mem::size_of::<u32>(),
|
||||
widthheight,
|
||||
);
|
||||
|
||||
let path = globals.get_media_file(&thumbnail_key);
|
||||
let mut f = File::create(path).await?;
|
||||
f.write_all(&thumbnail_bytes).await?;
|
||||
|
||||
self.mediaid_file.insert(&thumbnail_key, &[])?;
|
||||
|
||||
Ok(Some(FileMeta {
|
||||
content_disposition,
|
||||
content_type,
|
||||
file: thumbnail_bytes.to_vec(),
|
||||
}))
|
||||
} else {
|
||||
// Couldn't parse file to generate thumbnail, send original
|
||||
Ok(Some(FileMeta {
|
||||
content_disposition,
|
||||
content_type,
|
||||
file: file.to_vec(),
|
||||
}))
|
||||
}
|
||||
} else {
|
||||
Ok(None)
|
||||
}
|
||||
}
|
||||
}
|
1017
src/database/mod.rs
Normal file
1017
src/database/mod.rs
Normal file
File diff suppressed because it is too large
Load diff
|
@ -1,348 +0,0 @@
|
|||
use crate::{Database, Error, PduEvent, Result};
|
||||
use bytes::BytesMut;
|
||||
use ruma::{
|
||||
api::{
|
||||
client::push::{get_pushers, set_pusher, PusherKind},
|
||||
push_gateway::send_event_notification::{
|
||||
self,
|
||||
v1::{Device, Notification, NotificationCounts, NotificationPriority},
|
||||
},
|
||||
IncomingResponse, MatrixVersion, OutgoingRequest, SendAccessToken,
|
||||
},
|
||||
events::{
|
||||
room::{name::RoomNameEventContent, power_levels::RoomPowerLevelsEventContent},
|
||||
AnySyncRoomEvent, RoomEventType, StateEventType,
|
||||
},
|
||||
push::{Action, PushConditionRoomCtx, PushFormat, Ruleset, Tweak},
|
||||
serde::Raw,
|
||||
uint, RoomId, UInt, UserId,
|
||||
};
|
||||
use tracing::{error, info, warn};
|
||||
|
||||
use std::{fmt::Debug, mem, sync::Arc};
|
||||
|
||||
use super::abstraction::Tree;
|
||||
|
||||
pub struct PushData {
|
||||
/// UserId + pushkey -> Pusher
|
||||
pub(super) senderkey_pusher: Arc<dyn Tree>,
|
||||
}
|
||||
|
||||
impl PushData {
|
||||
#[tracing::instrument(skip(self, sender, pusher))]
|
||||
pub fn set_pusher(&self, sender: &UserId, pusher: set_pusher::v3::Pusher) -> Result<()> {
|
||||
let mut key = sender.as_bytes().to_vec();
|
||||
key.push(0xff);
|
||||
key.extend_from_slice(pusher.pushkey.as_bytes());
|
||||
|
||||
// There are 2 kinds of pushers but the spec says: null deletes the pusher.
|
||||
if pusher.kind.is_none() {
|
||||
return self
|
||||
.senderkey_pusher
|
||||
.remove(&key)
|
||||
.map(|_| ())
|
||||
.map_err(Into::into);
|
||||
}
|
||||
|
||||
self.senderkey_pusher.insert(
|
||||
&key,
|
||||
&serde_json::to_vec(&pusher).expect("Pusher is valid JSON value"),
|
||||
)?;
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[tracing::instrument(skip(self, senderkey))]
|
||||
pub fn get_pusher(&self, senderkey: &[u8]) -> Result<Option<get_pushers::v3::Pusher>> {
|
||||
self.senderkey_pusher
|
||||
.get(senderkey)?
|
||||
.map(|push| {
|
||||
serde_json::from_slice(&*push)
|
||||
.map_err(|_| Error::bad_database("Invalid Pusher in db."))
|
||||
})
|
||||
.transpose()
|
||||
}
|
||||
|
||||
#[tracing::instrument(skip(self, sender))]
|
||||
pub fn get_pushers(&self, sender: &UserId) -> Result<Vec<get_pushers::v3::Pusher>> {
|
||||
let mut prefix = sender.as_bytes().to_vec();
|
||||
prefix.push(0xff);
|
||||
|
||||
self.senderkey_pusher
|
||||
.scan_prefix(prefix)
|
||||
.map(|(_, push)| {
|
||||
serde_json::from_slice(&*push)
|
||||
.map_err(|_| Error::bad_database("Invalid Pusher in db."))
|
||||
})
|
||||
.collect()
|
||||
}
|
||||
|
||||
#[tracing::instrument(skip(self, sender))]
|
||||
pub fn get_pusher_senderkeys<'a>(
|
||||
&'a self,
|
||||
sender: &UserId,
|
||||
) -> impl Iterator<Item = Vec<u8>> + 'a {
|
||||
let mut prefix = sender.as_bytes().to_vec();
|
||||
prefix.push(0xff);
|
||||
|
||||
self.senderkey_pusher.scan_prefix(prefix).map(|(k, _)| k)
|
||||
}
|
||||
}
|
||||
|
||||
#[tracing::instrument(skip(globals, destination, request))]
|
||||
pub async fn send_request<T: OutgoingRequest>(
|
||||
globals: &crate::database::globals::Globals,
|
||||
destination: &str,
|
||||
request: T,
|
||||
) -> Result<T::IncomingResponse>
|
||||
where
|
||||
T: Debug,
|
||||
{
|
||||
let destination = destination.replace("/_matrix/push/v1/notify", "");
|
||||
|
||||
let http_request = request
|
||||
.try_into_http_request::<BytesMut>(
|
||||
&destination,
|
||||
SendAccessToken::IfRequired(""),
|
||||
&[MatrixVersion::V1_0],
|
||||
)
|
||||
.map_err(|e| {
|
||||
warn!("Failed to find destination {}: {}", destination, e);
|
||||
Error::BadServerResponse("Invalid destination")
|
||||
})?
|
||||
.map(|body| body.freeze());
|
||||
|
||||
let reqwest_request = reqwest::Request::try_from(http_request)
|
||||
.expect("all http requests are valid reqwest requests");
|
||||
|
||||
// TODO: we could keep this very short and let expo backoff do it's thing...
|
||||
//*reqwest_request.timeout_mut() = Some(Duration::from_secs(5));
|
||||
|
||||
let url = reqwest_request.url().clone();
|
||||
let response = globals.default_client().execute(reqwest_request).await;
|
||||
|
||||
match response {
|
||||
Ok(mut response) => {
|
||||
// reqwest::Response -> http::Response conversion
|
||||
let status = response.status();
|
||||
let mut http_response_builder = http::Response::builder()
|
||||
.status(status)
|
||||
.version(response.version());
|
||||
mem::swap(
|
||||
response.headers_mut(),
|
||||
http_response_builder
|
||||
.headers_mut()
|
||||
.expect("http::response::Builder is usable"),
|
||||
);
|
||||
|
||||
let body = response.bytes().await.unwrap_or_else(|e| {
|
||||
warn!("server error {}", e);
|
||||
Vec::new().into()
|
||||
}); // TODO: handle timeout
|
||||
|
||||
if status != 200 {
|
||||
info!(
|
||||
"Push gateway returned bad response {} {}\n{}\n{:?}",
|
||||
destination,
|
||||
status,
|
||||
url,
|
||||
crate::utils::string_from_bytes(&body)
|
||||
);
|
||||
}
|
||||
|
||||
let response = T::IncomingResponse::try_from_http_response(
|
||||
http_response_builder
|
||||
.body(body)
|
||||
.expect("reqwest body is valid http body"),
|
||||
);
|
||||
response.map_err(|_| {
|
||||
info!(
|
||||
"Push gateway returned invalid response bytes {}\n{}",
|
||||
destination, url
|
||||
);
|
||||
Error::BadServerResponse("Push gateway returned bad response.")
|
||||
})
|
||||
}
|
||||
Err(e) => Err(e.into()),
|
||||
}
|
||||
}
|
||||
|
||||
#[tracing::instrument(skip(user, unread, pusher, ruleset, pdu, db))]
|
||||
pub async fn send_push_notice(
|
||||
user: &UserId,
|
||||
unread: UInt,
|
||||
pusher: &get_pushers::v3::Pusher,
|
||||
ruleset: Ruleset,
|
||||
pdu: &PduEvent,
|
||||
db: &Database,
|
||||
) -> Result<()> {
|
||||
let mut notify = None;
|
||||
let mut tweaks = Vec::new();
|
||||
|
||||
let power_levels: RoomPowerLevelsEventContent = db
|
||||
.rooms
|
||||
.room_state_get(&pdu.room_id, &StateEventType::RoomPowerLevels, "")?
|
||||
.map(|ev| {
|
||||
serde_json::from_str(ev.content.get())
|
||||
.map_err(|_| Error::bad_database("invalid m.room.power_levels event"))
|
||||
})
|
||||
.transpose()?
|
||||
.unwrap_or_default();
|
||||
|
||||
for action in get_actions(
|
||||
user,
|
||||
&ruleset,
|
||||
&power_levels,
|
||||
&pdu.to_sync_room_event(),
|
||||
&pdu.room_id,
|
||||
db,
|
||||
)? {
|
||||
let n = match action {
|
||||
Action::DontNotify => false,
|
||||
// TODO: Implement proper support for coalesce
|
||||
Action::Notify | Action::Coalesce => true,
|
||||
Action::SetTweak(tweak) => {
|
||||
tweaks.push(tweak.clone());
|
||||
continue;
|
||||
}
|
||||
};
|
||||
|
||||
if notify.is_some() {
|
||||
return Err(Error::bad_database(
|
||||
r#"Malformed pushrule contains more than one of these actions: ["dont_notify", "notify", "coalesce"]"#,
|
||||
));
|
||||
}
|
||||
|
||||
notify = Some(n);
|
||||
}
|
||||
|
||||
if notify == Some(true) {
|
||||
send_notice(unread, pusher, tweaks, pdu, db).await?;
|
||||
}
|
||||
// Else the event triggered no actions
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[tracing::instrument(skip(user, ruleset, pdu, db))]
|
||||
pub fn get_actions<'a>(
|
||||
user: &UserId,
|
||||
ruleset: &'a Ruleset,
|
||||
power_levels: &RoomPowerLevelsEventContent,
|
||||
pdu: &Raw<AnySyncRoomEvent>,
|
||||
room_id: &RoomId,
|
||||
db: &Database,
|
||||
) -> Result<&'a [Action]> {
|
||||
let ctx = PushConditionRoomCtx {
|
||||
room_id: room_id.to_owned(),
|
||||
member_count: 10_u32.into(), // TODO: get member count efficiently
|
||||
user_display_name: db
|
||||
.users
|
||||
.displayname(user)?
|
||||
.unwrap_or_else(|| user.localpart().to_owned()),
|
||||
users_power_levels: power_levels.users.clone(),
|
||||
default_power_level: power_levels.users_default,
|
||||
notification_power_levels: power_levels.notifications.clone(),
|
||||
};
|
||||
|
||||
Ok(ruleset.get_actions(pdu, &ctx))
|
||||
}
|
||||
|
||||
#[tracing::instrument(skip(unread, pusher, tweaks, event, db))]
|
||||
async fn send_notice(
|
||||
unread: UInt,
|
||||
pusher: &get_pushers::v3::Pusher,
|
||||
tweaks: Vec<Tweak>,
|
||||
event: &PduEvent,
|
||||
db: &Database,
|
||||
) -> Result<()> {
|
||||
// TODO: email
|
||||
if pusher.kind == PusherKind::Email {
|
||||
return Ok(());
|
||||
}
|
||||
|
||||
// TODO:
|
||||
// Two problems with this
|
||||
// 1. if "event_id_only" is the only format kind it seems we should never add more info
|
||||
// 2. can pusher/devices have conflicting formats
|
||||
let event_id_only = pusher.data.format == Some(PushFormat::EventIdOnly);
|
||||
let url = if let Some(url) = &pusher.data.url {
|
||||
url
|
||||
} else {
|
||||
error!("Http Pusher must have URL specified.");
|
||||
return Ok(());
|
||||
};
|
||||
|
||||
let mut device = Device::new(pusher.app_id.clone(), pusher.pushkey.clone());
|
||||
let mut data_minus_url = pusher.data.clone();
|
||||
// The url must be stripped off according to spec
|
||||
data_minus_url.url = None;
|
||||
device.data = data_minus_url;
|
||||
|
||||
// Tweaks are only added if the format is NOT event_id_only
|
||||
if !event_id_only {
|
||||
device.tweaks = tweaks.clone();
|
||||
}
|
||||
|
||||
let d = &[device];
|
||||
let mut notifi = Notification::new(d);
|
||||
|
||||
notifi.prio = NotificationPriority::Low;
|
||||
notifi.event_id = Some(&event.event_id);
|
||||
notifi.room_id = Some(&event.room_id);
|
||||
// TODO: missed calls
|
||||
notifi.counts = NotificationCounts::new(unread, uint!(0));
|
||||
|
||||
if event.kind == RoomEventType::RoomEncrypted
|
||||
|| tweaks
|
||||
.iter()
|
||||
.any(|t| matches!(t, Tweak::Highlight(true) | Tweak::Sound(_)))
|
||||
{
|
||||
notifi.prio = NotificationPriority::High
|
||||
}
|
||||
|
||||
if event_id_only {
|
||||
send_request(
|
||||
&db.globals,
|
||||
url,
|
||||
send_event_notification::v1::Request::new(notifi),
|
||||
)
|
||||
.await?;
|
||||
} else {
|
||||
notifi.sender = Some(&event.sender);
|
||||
notifi.event_type = Some(&event.kind);
|
||||
let content = serde_json::value::to_raw_value(&event.content).ok();
|
||||
notifi.content = content.as_deref();
|
||||
|
||||
if event.kind == RoomEventType::RoomMember {
|
||||
notifi.user_is_target = event.state_key.as_deref() == Some(event.sender.as_str());
|
||||
}
|
||||
|
||||
let user_name = db.users.displayname(&event.sender)?;
|
||||
notifi.sender_display_name = user_name.as_deref();
|
||||
|
||||
let room_name = if let Some(room_name_pdu) =
|
||||
db.rooms
|
||||
.room_state_get(&event.room_id, &StateEventType::RoomName, "")?
|
||||
{
|
||||
serde_json::from_str::<RoomNameEventContent>(room_name_pdu.content.get())
|
||||
.map_err(|_| Error::bad_database("Invalid room name event in database."))?
|
||||
.name
|
||||
} else {
|
||||
None
|
||||
};
|
||||
|
||||
notifi.room_name = room_name.as_deref();
|
||||
|
||||
send_request(
|
||||
&db.globals,
|
||||
url,
|
||||
send_event_notification::v1::Request::new(notifi),
|
||||
)
|
||||
.await?;
|
||||
}
|
||||
|
||||
// TODO: email
|
||||
|
||||
Ok(())
|
||||
}
|
|
@ -1,550 +0,0 @@
|
|||
use crate::{database::abstraction::Tree, utils, Error, Result};
|
||||
use ruma::{
|
||||
events::{
|
||||
presence::{PresenceEvent, PresenceEventContent},
|
||||
receipt::ReceiptEvent,
|
||||
SyncEphemeralRoomEvent,
|
||||
},
|
||||
presence::PresenceState,
|
||||
serde::Raw,
|
||||
signatures::CanonicalJsonObject,
|
||||
RoomId, UInt, UserId,
|
||||
};
|
||||
use std::{
|
||||
collections::{HashMap, HashSet},
|
||||
mem,
|
||||
sync::Arc,
|
||||
};
|
||||
|
||||
pub struct RoomEdus {
|
||||
pub(in super::super) readreceiptid_readreceipt: Arc<dyn Tree>, // ReadReceiptId = RoomId + Count + UserId
|
||||
pub(in super::super) roomuserid_privateread: Arc<dyn Tree>, // RoomUserId = Room + User, PrivateRead = Count
|
||||
pub(in super::super) roomuserid_lastprivatereadupdate: Arc<dyn Tree>, // LastPrivateReadUpdate = Count
|
||||
pub(in super::super) typingid_userid: Arc<dyn Tree>, // TypingId = RoomId + TimeoutTime + Count
|
||||
pub(in super::super) roomid_lasttypingupdate: Arc<dyn Tree>, // LastRoomTypingUpdate = Count
|
||||
pub(in super::super) presenceid_presence: Arc<dyn Tree>, // PresenceId = RoomId + Count + UserId
|
||||
pub(in super::super) userid_lastpresenceupdate: Arc<dyn Tree>, // LastPresenceUpdate = Count
|
||||
}
|
||||
|
||||
impl RoomEdus {
|
||||
/// Adds an event which will be saved until a new event replaces it (e.g. read receipt).
|
||||
pub fn readreceipt_update(
|
||||
&self,
|
||||
user_id: &UserId,
|
||||
room_id: &RoomId,
|
||||
event: ReceiptEvent,
|
||||
globals: &super::super::globals::Globals,
|
||||
) -> Result<()> {
|
||||
let mut prefix = room_id.as_bytes().to_vec();
|
||||
prefix.push(0xff);
|
||||
|
||||
let mut last_possible_key = prefix.clone();
|
||||
last_possible_key.extend_from_slice(&u64::MAX.to_be_bytes());
|
||||
|
||||
// Remove old entry
|
||||
if let Some((old, _)) = self
|
||||
.readreceiptid_readreceipt
|
||||
.iter_from(&last_possible_key, true)
|
||||
.take_while(|(key, _)| key.starts_with(&prefix))
|
||||
.find(|(key, _)| {
|
||||
key.rsplit(|&b| b == 0xff)
|
||||
.next()
|
||||
.expect("rsplit always returns an element")
|
||||
== user_id.as_bytes()
|
||||
})
|
||||
{
|
||||
// This is the old room_latest
|
||||
self.readreceiptid_readreceipt.remove(&old)?;
|
||||
}
|
||||
|
||||
let mut room_latest_id = prefix;
|
||||
room_latest_id.extend_from_slice(&globals.next_count()?.to_be_bytes());
|
||||
room_latest_id.push(0xff);
|
||||
room_latest_id.extend_from_slice(user_id.as_bytes());
|
||||
|
||||
self.readreceiptid_readreceipt.insert(
|
||||
&room_latest_id,
|
||||
&serde_json::to_vec(&event).expect("EduEvent::to_string always works"),
|
||||
)?;
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Returns an iterator over the most recent read_receipts in a room that happened after the event with id `since`.
|
||||
#[tracing::instrument(skip(self))]
|
||||
pub fn readreceipts_since<'a>(
|
||||
&'a self,
|
||||
room_id: &RoomId,
|
||||
since: u64,
|
||||
) -> impl Iterator<
|
||||
Item = Result<(
|
||||
Box<UserId>,
|
||||
u64,
|
||||
Raw<ruma::events::AnySyncEphemeralRoomEvent>,
|
||||
)>,
|
||||
> + 'a {
|
||||
let mut prefix = room_id.as_bytes().to_vec();
|
||||
prefix.push(0xff);
|
||||
let prefix2 = prefix.clone();
|
||||
|
||||
let mut first_possible_edu = prefix.clone();
|
||||
first_possible_edu.extend_from_slice(&(since + 1).to_be_bytes()); // +1 so we don't send the event at since
|
||||
|
||||
self.readreceiptid_readreceipt
|
||||
.iter_from(&first_possible_edu, false)
|
||||
.take_while(move |(k, _)| k.starts_with(&prefix2))
|
||||
.map(move |(k, v)| {
|
||||
let count =
|
||||
utils::u64_from_bytes(&k[prefix.len()..prefix.len() + mem::size_of::<u64>()])
|
||||
.map_err(|_| Error::bad_database("Invalid readreceiptid count in db."))?;
|
||||
let user_id = UserId::parse(
|
||||
utils::string_from_bytes(&k[prefix.len() + mem::size_of::<u64>() + 1..])
|
||||
.map_err(|_| {
|
||||
Error::bad_database("Invalid readreceiptid userid bytes in db.")
|
||||
})?,
|
||||
)
|
||||
.map_err(|_| Error::bad_database("Invalid readreceiptid userid in db."))?;
|
||||
|
||||
let mut json = serde_json::from_slice::<CanonicalJsonObject>(&v).map_err(|_| {
|
||||
Error::bad_database("Read receipt in roomlatestid_roomlatest is invalid json.")
|
||||
})?;
|
||||
json.remove("room_id");
|
||||
|
||||
Ok((
|
||||
user_id,
|
||||
count,
|
||||
Raw::from_json(
|
||||
serde_json::value::to_raw_value(&json).expect("json is valid raw value"),
|
||||
),
|
||||
))
|
||||
})
|
||||
}
|
||||
|
||||
/// Sets a private read marker at `count`.
|
||||
#[tracing::instrument(skip(self, globals))]
|
||||
pub fn private_read_set(
|
||||
&self,
|
||||
room_id: &RoomId,
|
||||
user_id: &UserId,
|
||||
count: u64,
|
||||
globals: &super::super::globals::Globals,
|
||||
) -> Result<()> {
|
||||
let mut key = room_id.as_bytes().to_vec();
|
||||
key.push(0xff);
|
||||
key.extend_from_slice(user_id.as_bytes());
|
||||
|
||||
self.roomuserid_privateread
|
||||
.insert(&key, &count.to_be_bytes())?;
|
||||
|
||||
self.roomuserid_lastprivatereadupdate
|
||||
.insert(&key, &globals.next_count()?.to_be_bytes())?;
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Returns the private read marker.
|
||||
#[tracing::instrument(skip(self))]
|
||||
pub fn private_read_get(&self, room_id: &RoomId, user_id: &UserId) -> Result<Option<u64>> {
|
||||
let mut key = room_id.as_bytes().to_vec();
|
||||
key.push(0xff);
|
||||
key.extend_from_slice(user_id.as_bytes());
|
||||
|
||||
self.roomuserid_privateread
|
||||
.get(&key)?
|
||||
.map_or(Ok(None), |v| {
|
||||
Ok(Some(utils::u64_from_bytes(&v).map_err(|_| {
|
||||
Error::bad_database("Invalid private read marker bytes")
|
||||
})?))
|
||||
})
|
||||
}
|
||||
|
||||
/// Returns the count of the last typing update in this room.
|
||||
pub fn last_privateread_update(&self, user_id: &UserId, room_id: &RoomId) -> Result<u64> {
|
||||
let mut key = room_id.as_bytes().to_vec();
|
||||
key.push(0xff);
|
||||
key.extend_from_slice(user_id.as_bytes());
|
||||
|
||||
Ok(self
|
||||
.roomuserid_lastprivatereadupdate
|
||||
.get(&key)?
|
||||
.map(|bytes| {
|
||||
utils::u64_from_bytes(&bytes).map_err(|_| {
|
||||
Error::bad_database("Count in roomuserid_lastprivatereadupdate is invalid.")
|
||||
})
|
||||
})
|
||||
.transpose()?
|
||||
.unwrap_or(0))
|
||||
}
|
||||
|
||||
/// Sets a user as typing until the timeout timestamp is reached or roomtyping_remove is
|
||||
/// called.
|
||||
pub fn typing_add(
|
||||
&self,
|
||||
user_id: &UserId,
|
||||
room_id: &RoomId,
|
||||
timeout: u64,
|
||||
globals: &super::super::globals::Globals,
|
||||
) -> Result<()> {
|
||||
let mut prefix = room_id.as_bytes().to_vec();
|
||||
prefix.push(0xff);
|
||||
|
||||
let count = globals.next_count()?.to_be_bytes();
|
||||
|
||||
let mut room_typing_id = prefix;
|
||||
room_typing_id.extend_from_slice(&timeout.to_be_bytes());
|
||||
room_typing_id.push(0xff);
|
||||
room_typing_id.extend_from_slice(&count);
|
||||
|
||||
self.typingid_userid
|
||||
.insert(&room_typing_id, &*user_id.as_bytes())?;
|
||||
|
||||
self.roomid_lasttypingupdate
|
||||
.insert(room_id.as_bytes(), &count)?;
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Removes a user from typing before the timeout is reached.
|
||||
pub fn typing_remove(
|
||||
&self,
|
||||
user_id: &UserId,
|
||||
room_id: &RoomId,
|
||||
globals: &super::super::globals::Globals,
|
||||
) -> Result<()> {
|
||||
let mut prefix = room_id.as_bytes().to_vec();
|
||||
prefix.push(0xff);
|
||||
|
||||
let user_id = user_id.to_string();
|
||||
|
||||
let mut found_outdated = false;
|
||||
|
||||
// Maybe there are multiple ones from calling roomtyping_add multiple times
|
||||
for outdated_edu in self
|
||||
.typingid_userid
|
||||
.scan_prefix(prefix)
|
||||
.filter(|(_, v)| &**v == user_id.as_bytes())
|
||||
{
|
||||
self.typingid_userid.remove(&outdated_edu.0)?;
|
||||
found_outdated = true;
|
||||
}
|
||||
|
||||
if found_outdated {
|
||||
self.roomid_lasttypingupdate
|
||||
.insert(room_id.as_bytes(), &globals.next_count()?.to_be_bytes())?;
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Makes sure that typing events with old timestamps get removed.
|
||||
fn typings_maintain(
|
||||
&self,
|
||||
room_id: &RoomId,
|
||||
globals: &super::super::globals::Globals,
|
||||
) -> Result<()> {
|
||||
let mut prefix = room_id.as_bytes().to_vec();
|
||||
prefix.push(0xff);
|
||||
|
||||
let current_timestamp = utils::millis_since_unix_epoch();
|
||||
|
||||
let mut found_outdated = false;
|
||||
|
||||
// Find all outdated edus before inserting a new one
|
||||
for outdated_edu in self
|
||||
.typingid_userid
|
||||
.scan_prefix(prefix)
|
||||
.map(|(key, _)| {
|
||||
Ok::<_, Error>((
|
||||
key.clone(),
|
||||
utils::u64_from_bytes(
|
||||
&key.splitn(2, |&b| b == 0xff).nth(1).ok_or_else(|| {
|
||||
Error::bad_database("RoomTyping has invalid timestamp or delimiters.")
|
||||
})?[0..mem::size_of::<u64>()],
|
||||
)
|
||||
.map_err(|_| Error::bad_database("RoomTyping has invalid timestamp bytes."))?,
|
||||
))
|
||||
})
|
||||
.filter_map(|r| r.ok())
|
||||
.take_while(|&(_, timestamp)| timestamp < current_timestamp)
|
||||
{
|
||||
// This is an outdated edu (time > timestamp)
|
||||
self.typingid_userid.remove(&outdated_edu.0)?;
|
||||
found_outdated = true;
|
||||
}
|
||||
|
||||
if found_outdated {
|
||||
self.roomid_lasttypingupdate
|
||||
.insert(room_id.as_bytes(), &globals.next_count()?.to_be_bytes())?;
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Returns the count of the last typing update in this room.
|
||||
#[tracing::instrument(skip(self, globals))]
|
||||
pub fn last_typing_update(
|
||||
&self,
|
||||
room_id: &RoomId,
|
||||
globals: &super::super::globals::Globals,
|
||||
) -> Result<u64> {
|
||||
self.typings_maintain(room_id, globals)?;
|
||||
|
||||
Ok(self
|
||||
.roomid_lasttypingupdate
|
||||
.get(room_id.as_bytes())?
|
||||
.map(|bytes| {
|
||||
utils::u64_from_bytes(&bytes).map_err(|_| {
|
||||
Error::bad_database("Count in roomid_lastroomactiveupdate is invalid.")
|
||||
})
|
||||
})
|
||||
.transpose()?
|
||||
.unwrap_or(0))
|
||||
}
|
||||
|
||||
pub fn typings_all(
|
||||
&self,
|
||||
room_id: &RoomId,
|
||||
) -> Result<SyncEphemeralRoomEvent<ruma::events::typing::TypingEventContent>> {
|
||||
let mut prefix = room_id.as_bytes().to_vec();
|
||||
prefix.push(0xff);
|
||||
|
||||
let mut user_ids = HashSet::new();
|
||||
|
||||
for (_, user_id) in self.typingid_userid.scan_prefix(prefix) {
|
||||
let user_id = UserId::parse(utils::string_from_bytes(&user_id).map_err(|_| {
|
||||
Error::bad_database("User ID in typingid_userid is invalid unicode.")
|
||||
})?)
|
||||
.map_err(|_| Error::bad_database("User ID in typingid_userid is invalid."))?;
|
||||
|
||||
user_ids.insert(user_id);
|
||||
}
|
||||
|
||||
Ok(SyncEphemeralRoomEvent {
|
||||
content: ruma::events::typing::TypingEventContent {
|
||||
user_ids: user_ids.into_iter().collect(),
|
||||
},
|
||||
})
|
||||
}
|
||||
|
||||
/// Adds a presence event which will be saved until a new event replaces it.
|
||||
///
|
||||
/// Note: This method takes a RoomId because presence updates are always bound to rooms to
|
||||
/// make sure users outside these rooms can't see them.
|
||||
pub fn update_presence(
|
||||
&self,
|
||||
user_id: &UserId,
|
||||
room_id: &RoomId,
|
||||
presence: PresenceEvent,
|
||||
globals: &super::super::globals::Globals,
|
||||
) -> Result<()> {
|
||||
// TODO: Remove old entry? Or maybe just wipe completely from time to time?
|
||||
|
||||
let count = globals.next_count()?.to_be_bytes();
|
||||
|
||||
let mut presence_id = room_id.as_bytes().to_vec();
|
||||
presence_id.push(0xff);
|
||||
presence_id.extend_from_slice(&count);
|
||||
presence_id.push(0xff);
|
||||
presence_id.extend_from_slice(presence.sender.as_bytes());
|
||||
|
||||
self.presenceid_presence.insert(
|
||||
&presence_id,
|
||||
&serde_json::to_vec(&presence).expect("PresenceEvent can be serialized"),
|
||||
)?;
|
||||
|
||||
self.userid_lastpresenceupdate.insert(
|
||||
user_id.as_bytes(),
|
||||
&utils::millis_since_unix_epoch().to_be_bytes(),
|
||||
)?;
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Resets the presence timeout, so the user will stay in their current presence state.
|
||||
#[tracing::instrument(skip(self))]
|
||||
pub fn ping_presence(&self, user_id: &UserId) -> Result<()> {
|
||||
self.userid_lastpresenceupdate.insert(
|
||||
user_id.as_bytes(),
|
||||
&utils::millis_since_unix_epoch().to_be_bytes(),
|
||||
)?;
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Returns the timestamp of the last presence update of this user in millis since the unix epoch.
|
||||
pub fn last_presence_update(&self, user_id: &UserId) -> Result<Option<u64>> {
|
||||
self.userid_lastpresenceupdate
|
||||
.get(user_id.as_bytes())?
|
||||
.map(|bytes| {
|
||||
utils::u64_from_bytes(&bytes).map_err(|_| {
|
||||
Error::bad_database("Invalid timestamp in userid_lastpresenceupdate.")
|
||||
})
|
||||
})
|
||||
.transpose()
|
||||
}
|
||||
|
||||
pub fn get_last_presence_event(
|
||||
&self,
|
||||
user_id: &UserId,
|
||||
room_id: &RoomId,
|
||||
) -> Result<Option<PresenceEvent>> {
|
||||
let last_update = match self.last_presence_update(user_id)? {
|
||||
Some(last) => last,
|
||||
None => return Ok(None),
|
||||
};
|
||||
|
||||
let mut presence_id = room_id.as_bytes().to_vec();
|
||||
presence_id.push(0xff);
|
||||
presence_id.extend_from_slice(&last_update.to_be_bytes());
|
||||
presence_id.push(0xff);
|
||||
presence_id.extend_from_slice(user_id.as_bytes());
|
||||
|
||||
self.presenceid_presence
|
||||
.get(&presence_id)?
|
||||
.map(|value| {
|
||||
let mut presence: PresenceEvent = serde_json::from_slice(&value)
|
||||
.map_err(|_| Error::bad_database("Invalid presence event in db."))?;
|
||||
let current_timestamp: UInt = utils::millis_since_unix_epoch()
|
||||
.try_into()
|
||||
.expect("time is valid");
|
||||
|
||||
if presence.content.presence == PresenceState::Online {
|
||||
// Don't set last_active_ago when the user is online
|
||||
presence.content.last_active_ago = None;
|
||||
} else {
|
||||
// Convert from timestamp to duration
|
||||
presence.content.last_active_ago = presence
|
||||
.content
|
||||
.last_active_ago
|
||||
.map(|timestamp| current_timestamp - timestamp);
|
||||
}
|
||||
|
||||
Ok(presence)
|
||||
})
|
||||
.transpose()
|
||||
}
|
||||
|
||||
/// Sets all users to offline who have been quiet for too long.
|
||||
fn _presence_maintain(
|
||||
&self,
|
||||
rooms: &super::Rooms,
|
||||
globals: &super::super::globals::Globals,
|
||||
) -> Result<()> {
|
||||
let current_timestamp = utils::millis_since_unix_epoch();
|
||||
|
||||
for (user_id_bytes, last_timestamp) in self
|
||||
.userid_lastpresenceupdate
|
||||
.iter()
|
||||
.filter_map(|(k, bytes)| {
|
||||
Some((
|
||||
k,
|
||||
utils::u64_from_bytes(&bytes)
|
||||
.map_err(|_| {
|
||||
Error::bad_database("Invalid timestamp in userid_lastpresenceupdate.")
|
||||
})
|
||||
.ok()?,
|
||||
))
|
||||
})
|
||||
.take_while(|(_, timestamp)| current_timestamp.saturating_sub(*timestamp) > 5 * 60_000)
|
||||
// 5 Minutes
|
||||
{
|
||||
// Send new presence events to set the user offline
|
||||
let count = globals.next_count()?.to_be_bytes();
|
||||
let user_id: Box<_> = utils::string_from_bytes(&user_id_bytes)
|
||||
.map_err(|_| {
|
||||
Error::bad_database("Invalid UserId bytes in userid_lastpresenceupdate.")
|
||||
})?
|
||||
.try_into()
|
||||
.map_err(|_| Error::bad_database("Invalid UserId in userid_lastpresenceupdate."))?;
|
||||
for room_id in rooms.rooms_joined(&user_id).filter_map(|r| r.ok()) {
|
||||
let mut presence_id = room_id.as_bytes().to_vec();
|
||||
presence_id.push(0xff);
|
||||
presence_id.extend_from_slice(&count);
|
||||
presence_id.push(0xff);
|
||||
presence_id.extend_from_slice(&user_id_bytes);
|
||||
|
||||
self.presenceid_presence.insert(
|
||||
&presence_id,
|
||||
&serde_json::to_vec(&PresenceEvent {
|
||||
content: PresenceEventContent {
|
||||
avatar_url: None,
|
||||
currently_active: None,
|
||||
displayname: None,
|
||||
last_active_ago: Some(
|
||||
last_timestamp.try_into().expect("time is valid"),
|
||||
),
|
||||
presence: PresenceState::Offline,
|
||||
status_msg: None,
|
||||
},
|
||||
sender: user_id.to_owned(),
|
||||
})
|
||||
.expect("PresenceEvent can be serialized"),
|
||||
)?;
|
||||
}
|
||||
|
||||
self.userid_lastpresenceupdate.insert(
|
||||
user_id.as_bytes(),
|
||||
&utils::millis_since_unix_epoch().to_be_bytes(),
|
||||
)?;
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Returns an iterator over the most recent presence updates that happened after the event with id `since`.
|
||||
#[tracing::instrument(skip(self, since, _rooms, _globals))]
|
||||
pub fn presence_since(
|
||||
&self,
|
||||
room_id: &RoomId,
|
||||
since: u64,
|
||||
_rooms: &super::Rooms,
|
||||
_globals: &super::super::globals::Globals,
|
||||
) -> Result<HashMap<Box<UserId>, PresenceEvent>> {
|
||||
//self.presence_maintain(rooms, globals)?;
|
||||
|
||||
let mut prefix = room_id.as_bytes().to_vec();
|
||||
prefix.push(0xff);
|
||||
|
||||
let mut first_possible_edu = prefix.clone();
|
||||
first_possible_edu.extend_from_slice(&(since + 1).to_be_bytes()); // +1 so we don't send the event at since
|
||||
let mut hashmap = HashMap::new();
|
||||
|
||||
for (key, value) in self
|
||||
.presenceid_presence
|
||||
.iter_from(&*first_possible_edu, false)
|
||||
.take_while(|(key, _)| key.starts_with(&prefix))
|
||||
{
|
||||
let user_id = UserId::parse(
|
||||
utils::string_from_bytes(
|
||||
key.rsplit(|&b| b == 0xff)
|
||||
.next()
|
||||
.expect("rsplit always returns an element"),
|
||||
)
|
||||
.map_err(|_| Error::bad_database("Invalid UserId bytes in presenceid_presence."))?,
|
||||
)
|
||||
.map_err(|_| Error::bad_database("Invalid UserId in presenceid_presence."))?;
|
||||
|
||||
let mut presence: PresenceEvent = serde_json::from_slice(&value)
|
||||
.map_err(|_| Error::bad_database("Invalid presence event in db."))?;
|
||||
|
||||
let current_timestamp: UInt = utils::millis_since_unix_epoch()
|
||||
.try_into()
|
||||
.expect("time is valid");
|
||||
|
||||
if presence.content.presence == PresenceState::Online {
|
||||
// Don't set last_active_ago when the user is online
|
||||
presence.content.last_active_ago = None;
|
||||
} else {
|
||||
// Convert from timestamp to duration
|
||||
presence.content.last_active_ago = presence
|
||||
.content
|
||||
.last_active_ago
|
||||
.map(|timestamp| current_timestamp - timestamp);
|
||||
}
|
||||
|
||||
hashmap.insert(user_id, presence);
|
||||
}
|
||||
|
||||
Ok(hashmap)
|
||||
}
|
||||
}
|
|
@ -1,845 +0,0 @@
|
|||
use std::{
|
||||
collections::{BTreeMap, HashMap, HashSet},
|
||||
fmt::Debug,
|
||||
sync::Arc,
|
||||
time::{Duration, Instant},
|
||||
};
|
||||
|
||||
use crate::{
|
||||
appservice_server, database::pusher, server_server, utils, Database, Error, PduEvent, Result,
|
||||
};
|
||||
use federation::transactions::send_transaction_message;
|
||||
use futures_util::{stream::FuturesUnordered, StreamExt};
|
||||
use ring::digest;
|
||||
use ruma::{
|
||||
api::{
|
||||
appservice,
|
||||
federation::{
|
||||
self,
|
||||
transactions::edu::{
|
||||
DeviceListUpdateContent, Edu, ReceiptContent, ReceiptData, ReceiptMap,
|
||||
},
|
||||
},
|
||||
OutgoingRequest,
|
||||
},
|
||||
device_id,
|
||||
events::{push_rules::PushRulesEvent, AnySyncEphemeralRoomEvent, GlobalAccountDataEventType},
|
||||
push,
|
||||
receipt::ReceiptType,
|
||||
uint, MilliSecondsSinceUnixEpoch, ServerName, UInt, UserId,
|
||||
};
|
||||
use tokio::{
|
||||
select,
|
||||
sync::{mpsc, RwLock, Semaphore},
|
||||
};
|
||||
use tracing::{error, warn};
|
||||
|
||||
use super::abstraction::Tree;
|
||||
|
||||
#[derive(Clone, Debug, PartialEq, Eq, Hash)]
|
||||
pub enum OutgoingKind {
|
||||
Appservice(String),
|
||||
Push(Vec<u8>, Vec<u8>), // user and pushkey
|
||||
Normal(Box<ServerName>),
|
||||
}
|
||||
|
||||
impl OutgoingKind {
|
||||
#[tracing::instrument(skip(self))]
|
||||
pub fn get_prefix(&self) -> Vec<u8> {
|
||||
let mut prefix = match self {
|
||||
OutgoingKind::Appservice(server) => {
|
||||
let mut p = b"+".to_vec();
|
||||
p.extend_from_slice(server.as_bytes());
|
||||
p
|
||||
}
|
||||
OutgoingKind::Push(user, pushkey) => {
|
||||
let mut p = b"$".to_vec();
|
||||
p.extend_from_slice(user);
|
||||
p.push(0xff);
|
||||
p.extend_from_slice(pushkey);
|
||||
p
|
||||
}
|
||||
OutgoingKind::Normal(server) => {
|
||||
let mut p = Vec::new();
|
||||
p.extend_from_slice(server.as_bytes());
|
||||
p
|
||||
}
|
||||
};
|
||||
prefix.push(0xff);
|
||||
|
||||
prefix
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Clone, Debug, PartialEq, Eq, Hash)]
|
||||
pub enum SendingEventType {
|
||||
Pdu(Vec<u8>),
|
||||
Edu(Vec<u8>),
|
||||
}
|
||||
|
||||
pub struct Sending {
|
||||
/// The state for a given state hash.
|
||||
pub(super) servername_educount: Arc<dyn Tree>, // EduCount: Count of last EDU sync
|
||||
pub(super) servernameevent_data: Arc<dyn Tree>, // ServernameEvent = (+ / $)SenderKey / ServerName / UserId + PduId / Id (for edus), Data = EDU content
|
||||
pub(super) servercurrentevent_data: Arc<dyn Tree>, // ServerCurrentEvents = (+ / $)ServerName / UserId + PduId / Id (for edus), Data = EDU content
|
||||
pub(super) maximum_requests: Arc<Semaphore>,
|
||||
pub sender: mpsc::UnboundedSender<(Vec<u8>, Vec<u8>)>,
|
||||
}
|
||||
|
||||
enum TransactionStatus {
|
||||
Running,
|
||||
Failed(u32, Instant), // number of times failed, time of last failure
|
||||
Retrying(u32), // number of times failed
|
||||
}
|
||||
|
||||
impl Sending {
|
||||
pub fn start_handler(
|
||||
&self,
|
||||
db: Arc<RwLock<Database>>,
|
||||
mut receiver: mpsc::UnboundedReceiver<(Vec<u8>, Vec<u8>)>,
|
||||
) {
|
||||
tokio::spawn(async move {
|
||||
let mut futures = FuturesUnordered::new();
|
||||
|
||||
let mut current_transaction_status = HashMap::<Vec<u8>, TransactionStatus>::new();
|
||||
|
||||
// Retry requests we could not finish yet
|
||||
let mut initial_transactions = HashMap::<OutgoingKind, Vec<SendingEventType>>::new();
|
||||
|
||||
let guard = db.read().await;
|
||||
|
||||
for (key, outgoing_kind, event) in guard
|
||||
.sending
|
||||
.servercurrentevent_data
|
||||
.iter()
|
||||
.filter_map(|(key, v)| {
|
||||
Self::parse_servercurrentevent(&key, v)
|
||||
.ok()
|
||||
.map(|(k, e)| (key, k, e))
|
||||
})
|
||||
{
|
||||
let entry = initial_transactions
|
||||
.entry(outgoing_kind.clone())
|
||||
.or_insert_with(Vec::new);
|
||||
|
||||
if entry.len() > 30 {
|
||||
warn!(
|
||||
"Dropping some current events: {:?} {:?} {:?}",
|
||||
key, outgoing_kind, event
|
||||
);
|
||||
guard.sending.servercurrentevent_data.remove(&key).unwrap();
|
||||
continue;
|
||||
}
|
||||
|
||||
entry.push(event);
|
||||
}
|
||||
|
||||
drop(guard);
|
||||
|
||||
for (outgoing_kind, events) in initial_transactions {
|
||||
current_transaction_status
|
||||
.insert(outgoing_kind.get_prefix(), TransactionStatus::Running);
|
||||
futures.push(Self::handle_events(
|
||||
outgoing_kind.clone(),
|
||||
events,
|
||||
Arc::clone(&db),
|
||||
));
|
||||
}
|
||||
|
||||
loop {
|
||||
select! {
|
||||
Some(response) = futures.next() => {
|
||||
match response {
|
||||
Ok(outgoing_kind) => {
|
||||
let guard = db.read().await;
|
||||
|
||||
let prefix = outgoing_kind.get_prefix();
|
||||
for (key, _) in guard.sending.servercurrentevent_data
|
||||
.scan_prefix(prefix.clone())
|
||||
{
|
||||
guard.sending.servercurrentevent_data.remove(&key).unwrap();
|
||||
}
|
||||
|
||||
// Find events that have been added since starting the last request
|
||||
let new_events: Vec<_> = guard.sending.servernameevent_data
|
||||
.scan_prefix(prefix.clone())
|
||||
.filter_map(|(k, v)| {
|
||||
Self::parse_servercurrentevent(&k, v).ok().map(|ev| (ev, k))
|
||||
})
|
||||
.take(30)
|
||||
.collect();
|
||||
|
||||
// TODO: find edus
|
||||
|
||||
if !new_events.is_empty() {
|
||||
// Insert pdus we found
|
||||
for (e, key) in &new_events {
|
||||
let value = if let SendingEventType::Edu(value) = &e.1 { &**value } else { &[] };
|
||||
guard.sending.servercurrentevent_data.insert(key, value).unwrap();
|
||||
guard.sending.servernameevent_data.remove(key).unwrap();
|
||||
}
|
||||
|
||||
drop(guard);
|
||||
|
||||
futures.push(
|
||||
Self::handle_events(
|
||||
outgoing_kind.clone(),
|
||||
new_events.into_iter().map(|(event, _)| event.1).collect(),
|
||||
Arc::clone(&db),
|
||||
)
|
||||
);
|
||||
} else {
|
||||
current_transaction_status.remove(&prefix);
|
||||
}
|
||||
}
|
||||
Err((outgoing_kind, _)) => {
|
||||
current_transaction_status.entry(outgoing_kind.get_prefix()).and_modify(|e| *e = match e {
|
||||
TransactionStatus::Running => TransactionStatus::Failed(1, Instant::now()),
|
||||
TransactionStatus::Retrying(n) => TransactionStatus::Failed(*n+1, Instant::now()),
|
||||
TransactionStatus::Failed(_, _) => {
|
||||
error!("Request that was not even running failed?!");
|
||||
return
|
||||
},
|
||||
});
|
||||
}
|
||||
};
|
||||
},
|
||||
Some((key, value)) = receiver.recv() => {
|
||||
if let Ok((outgoing_kind, event)) = Self::parse_servercurrentevent(&key, value) {
|
||||
let guard = db.read().await;
|
||||
|
||||
if let Ok(Some(events)) = Self::select_events(
|
||||
&outgoing_kind,
|
||||
vec![(event, key)],
|
||||
&mut current_transaction_status,
|
||||
&guard
|
||||
) {
|
||||
futures.push(Self::handle_events(outgoing_kind, events, Arc::clone(&db)));
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
});
|
||||
}
|
||||
|
||||
#[tracing::instrument(skip(outgoing_kind, new_events, current_transaction_status, db))]
|
||||
fn select_events(
|
||||
outgoing_kind: &OutgoingKind,
|
||||
new_events: Vec<(SendingEventType, Vec<u8>)>, // Events we want to send: event and full key
|
||||
current_transaction_status: &mut HashMap<Vec<u8>, TransactionStatus>,
|
||||
db: &Database,
|
||||
) -> Result<Option<Vec<SendingEventType>>> {
|
||||
let mut retry = false;
|
||||
let mut allow = true;
|
||||
|
||||
let prefix = outgoing_kind.get_prefix();
|
||||
let entry = current_transaction_status.entry(prefix.clone());
|
||||
|
||||
entry
|
||||
.and_modify(|e| match e {
|
||||
TransactionStatus::Running | TransactionStatus::Retrying(_) => {
|
||||
allow = false; // already running
|
||||
}
|
||||
TransactionStatus::Failed(tries, time) => {
|
||||
// Fail if a request has failed recently (exponential backoff)
|
||||
let mut min_elapsed_duration = Duration::from_secs(30) * (*tries) * (*tries);
|
||||
if min_elapsed_duration > Duration::from_secs(60 * 60 * 24) {
|
||||
min_elapsed_duration = Duration::from_secs(60 * 60 * 24);
|
||||
}
|
||||
|
||||
if time.elapsed() < min_elapsed_duration {
|
||||
allow = false;
|
||||
} else {
|
||||
retry = true;
|
||||
*e = TransactionStatus::Retrying(*tries);
|
||||
}
|
||||
}
|
||||
})
|
||||
.or_insert(TransactionStatus::Running);
|
||||
|
||||
if !allow {
|
||||
return Ok(None);
|
||||
}
|
||||
|
||||
let mut events = Vec::new();
|
||||
|
||||
if retry {
|
||||
// We retry the previous transaction
|
||||
for (key, value) in db.sending.servercurrentevent_data.scan_prefix(prefix) {
|
||||
if let Ok((_, e)) = Self::parse_servercurrentevent(&key, value) {
|
||||
events.push(e);
|
||||
}
|
||||
}
|
||||
} else {
|
||||
for (e, full_key) in new_events {
|
||||
let value = if let SendingEventType::Edu(value) = &e {
|
||||
&**value
|
||||
} else {
|
||||
&[][..]
|
||||
};
|
||||
db.sending
|
||||
.servercurrentevent_data
|
||||
.insert(&full_key, value)?;
|
||||
|
||||
// If it was a PDU we have to unqueue it
|
||||
// TODO: don't try to unqueue EDUs
|
||||
db.sending.servernameevent_data.remove(&full_key)?;
|
||||
|
||||
events.push(e);
|
||||
}
|
||||
|
||||
if let OutgoingKind::Normal(server_name) = outgoing_kind {
|
||||
if let Ok((select_edus, last_count)) = Self::select_edus(db, server_name) {
|
||||
events.extend(select_edus.into_iter().map(SendingEventType::Edu));
|
||||
|
||||
db.sending
|
||||
.servername_educount
|
||||
.insert(server_name.as_bytes(), &last_count.to_be_bytes())?;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
Ok(Some(events))
|
||||
}
|
||||
|
||||
#[tracing::instrument(skip(db, server))]
|
||||
pub fn select_edus(db: &Database, server: &ServerName) -> Result<(Vec<Vec<u8>>, u64)> {
|
||||
// u64: count of last edu
|
||||
let since = db
|
||||
.sending
|
||||
.servername_educount
|
||||
.get(server.as_bytes())?
|
||||
.map_or(Ok(0), |bytes| {
|
||||
utils::u64_from_bytes(&bytes)
|
||||
.map_err(|_| Error::bad_database("Invalid u64 in servername_educount."))
|
||||
})?;
|
||||
let mut events = Vec::new();
|
||||
let mut max_edu_count = since;
|
||||
let mut device_list_changes = HashSet::new();
|
||||
|
||||
'outer: for room_id in db.rooms.server_rooms(server) {
|
||||
let room_id = room_id?;
|
||||
// Look for device list updates in this room
|
||||
device_list_changes.extend(
|
||||
db.users
|
||||
.keys_changed(&room_id.to_string(), since, None)
|
||||
.filter_map(|r| r.ok())
|
||||
.filter(|user_id| user_id.server_name() == db.globals.server_name()),
|
||||
);
|
||||
|
||||
// Look for read receipts in this room
|
||||
for r in db.rooms.edus.readreceipts_since(&room_id, since) {
|
||||
let (user_id, count, read_receipt) = r?;
|
||||
|
||||
if count > max_edu_count {
|
||||
max_edu_count = count;
|
||||
}
|
||||
|
||||
if user_id.server_name() != db.globals.server_name() {
|
||||
continue;
|
||||
}
|
||||
|
||||
let event: AnySyncEphemeralRoomEvent =
|
||||
serde_json::from_str(read_receipt.json().get())
|
||||
.map_err(|_| Error::bad_database("Invalid edu event in read_receipts."))?;
|
||||
let federation_event = match event {
|
||||
AnySyncEphemeralRoomEvent::Receipt(r) => {
|
||||
let mut read = BTreeMap::new();
|
||||
|
||||
let (event_id, mut receipt) = r
|
||||
.content
|
||||
.0
|
||||
.into_iter()
|
||||
.next()
|
||||
.expect("we only use one event per read receipt");
|
||||
let receipt = receipt
|
||||
.remove(&ReceiptType::Read)
|
||||
.expect("our read receipts always set this")
|
||||
.remove(&user_id)
|
||||
.expect("our read receipts always have the user here");
|
||||
|
||||
read.insert(
|
||||
user_id,
|
||||
ReceiptData {
|
||||
data: receipt.clone(),
|
||||
event_ids: vec![event_id.clone()],
|
||||
},
|
||||
);
|
||||
|
||||
let receipt_map = ReceiptMap { read };
|
||||
|
||||
let mut receipts = BTreeMap::new();
|
||||
receipts.insert(room_id.clone(), receipt_map);
|
||||
|
||||
Edu::Receipt(ReceiptContent { receipts })
|
||||
}
|
||||
_ => {
|
||||
Error::bad_database("Invalid event type in read_receipts");
|
||||
continue;
|
||||
}
|
||||
};
|
||||
|
||||
events.push(serde_json::to_vec(&federation_event).expect("json can be serialized"));
|
||||
|
||||
if events.len() >= 20 {
|
||||
break 'outer;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
for user_id in device_list_changes {
|
||||
// Empty prev id forces synapse to resync: https://github.com/matrix-org/synapse/blob/98aec1cc9da2bd6b8e34ffb282c85abf9b8b42ca/synapse/handlers/device.py#L767
|
||||
// Because synapse resyncs, we can just insert dummy data
|
||||
let edu = Edu::DeviceListUpdate(DeviceListUpdateContent {
|
||||
user_id,
|
||||
device_id: device_id!("dummy").to_owned(),
|
||||
device_display_name: Some("Dummy".to_owned()),
|
||||
stream_id: uint!(1),
|
||||
prev_id: Vec::new(),
|
||||
deleted: None,
|
||||
keys: None,
|
||||
});
|
||||
|
||||
events.push(serde_json::to_vec(&edu).expect("json can be serialized"));
|
||||
}
|
||||
|
||||
Ok((events, max_edu_count))
|
||||
}
|
||||
|
||||
#[tracing::instrument(skip(self, pdu_id, senderkey))]
|
||||
pub fn send_push_pdu(&self, pdu_id: &[u8], senderkey: Vec<u8>) -> Result<()> {
|
||||
let mut key = b"$".to_vec();
|
||||
key.extend_from_slice(&senderkey);
|
||||
key.push(0xff);
|
||||
key.extend_from_slice(pdu_id);
|
||||
self.servernameevent_data.insert(&key, &[])?;
|
||||
self.sender.send((key, vec![])).unwrap();
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[tracing::instrument(skip(self, servers, pdu_id))]
|
||||
pub fn send_pdu<I: Iterator<Item = Box<ServerName>>>(
|
||||
&self,
|
||||
servers: I,
|
||||
pdu_id: &[u8],
|
||||
) -> Result<()> {
|
||||
let mut batch = servers.map(|server| {
|
||||
let mut key = server.as_bytes().to_vec();
|
||||
key.push(0xff);
|
||||
key.extend_from_slice(pdu_id);
|
||||
|
||||
self.sender.send((key.clone(), vec![])).unwrap();
|
||||
|
||||
(key, Vec::new())
|
||||
});
|
||||
|
||||
self.servernameevent_data.insert_batch(&mut batch)?;
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[tracing::instrument(skip(self, server, serialized))]
|
||||
pub fn send_reliable_edu(
|
||||
&self,
|
||||
server: &ServerName,
|
||||
serialized: Vec<u8>,
|
||||
id: u64,
|
||||
) -> Result<()> {
|
||||
let mut key = server.as_bytes().to_vec();
|
||||
key.push(0xff);
|
||||
key.extend_from_slice(&id.to_be_bytes());
|
||||
self.servernameevent_data.insert(&key, &serialized)?;
|
||||
self.sender.send((key, serialized)).unwrap();
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[tracing::instrument(skip(self))]
|
||||
pub fn send_pdu_appservice(&self, appservice_id: &str, pdu_id: &[u8]) -> Result<()> {
|
||||
let mut key = b"+".to_vec();
|
||||
key.extend_from_slice(appservice_id.as_bytes());
|
||||
key.push(0xff);
|
||||
key.extend_from_slice(pdu_id);
|
||||
self.servernameevent_data.insert(&key, &[])?;
|
||||
self.sender.send((key, vec![])).unwrap();
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[tracing::instrument(skip(keys))]
|
||||
fn calculate_hash(keys: &[&[u8]]) -> Vec<u8> {
|
||||
// We only hash the pdu's event ids, not the whole pdu
|
||||
let bytes = keys.join(&0xff);
|
||||
let hash = digest::digest(&digest::SHA256, &bytes);
|
||||
hash.as_ref().to_owned()
|
||||
}
|
||||
|
||||
/// Cleanup event data
|
||||
/// Used for instance after we remove an appservice registration
|
||||
///
|
||||
#[tracing::instrument(skip(self))]
|
||||
pub fn cleanup_events(&self, key_id: &str) -> Result<()> {
|
||||
let mut prefix = b"+".to_vec();
|
||||
prefix.extend_from_slice(key_id.as_bytes());
|
||||
prefix.push(0xff);
|
||||
|
||||
for (key, _) in self.servercurrentevent_data.scan_prefix(prefix.clone()) {
|
||||
self.servercurrentevent_data.remove(&key).unwrap();
|
||||
}
|
||||
|
||||
for (key, _) in self.servernameevent_data.scan_prefix(prefix.clone()) {
|
||||
self.servernameevent_data.remove(&key).unwrap();
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[tracing::instrument(skip(db, events, kind))]
|
||||
async fn handle_events(
|
||||
kind: OutgoingKind,
|
||||
events: Vec<SendingEventType>,
|
||||
db: Arc<RwLock<Database>>,
|
||||
) -> Result<OutgoingKind, (OutgoingKind, Error)> {
|
||||
let db = db.read().await;
|
||||
|
||||
match &kind {
|
||||
OutgoingKind::Appservice(id) => {
|
||||
let mut pdu_jsons = Vec::new();
|
||||
|
||||
for event in &events {
|
||||
match event {
|
||||
SendingEventType::Pdu(pdu_id) => {
|
||||
pdu_jsons.push(db.rooms
|
||||
.get_pdu_from_id(pdu_id)
|
||||
.map_err(|e| (kind.clone(), e))?
|
||||
.ok_or_else(|| {
|
||||
(
|
||||
kind.clone(),
|
||||
Error::bad_database(
|
||||
"[Appservice] Event in servernameevent_data not found in db.",
|
||||
),
|
||||
)
|
||||
})?
|
||||
.to_room_event())
|
||||
}
|
||||
SendingEventType::Edu(_) => {
|
||||
// Appservices don't need EDUs (?)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
let permit = db.sending.maximum_requests.acquire().await;
|
||||
|
||||
let response = appservice_server::send_request(
|
||||
&db.globals,
|
||||
db.appservice
|
||||
.get_registration(&id)
|
||||
.map_err(|e| (kind.clone(), e))?
|
||||
.ok_or_else(|| {
|
||||
(
|
||||
kind.clone(),
|
||||
Error::bad_database(
|
||||
"[Appservice] Could not load registration from db.",
|
||||
),
|
||||
)
|
||||
})?,
|
||||
appservice::event::push_events::v1::Request {
|
||||
events: &pdu_jsons,
|
||||
txn_id: (&*base64::encode_config(
|
||||
Self::calculate_hash(
|
||||
&events
|
||||
.iter()
|
||||
.map(|e| match e {
|
||||
SendingEventType::Edu(b) | SendingEventType::Pdu(b) => &**b,
|
||||
})
|
||||
.collect::<Vec<_>>(),
|
||||
),
|
||||
base64::URL_SAFE_NO_PAD,
|
||||
))
|
||||
.into(),
|
||||
},
|
||||
)
|
||||
.await
|
||||
.map(|_response| kind.clone())
|
||||
.map_err(|e| (kind, e));
|
||||
|
||||
drop(permit);
|
||||
|
||||
response
|
||||
}
|
||||
OutgoingKind::Push(user, pushkey) => {
|
||||
let mut pdus = Vec::new();
|
||||
|
||||
for event in &events {
|
||||
match event {
|
||||
SendingEventType::Pdu(pdu_id) => {
|
||||
pdus.push(
|
||||
db.rooms
|
||||
.get_pdu_from_id(pdu_id)
|
||||
.map_err(|e| (kind.clone(), e))?
|
||||
.ok_or_else(|| {
|
||||
(
|
||||
kind.clone(),
|
||||
Error::bad_database(
|
||||
"[Push] Event in servernamevent_datas not found in db.",
|
||||
),
|
||||
)
|
||||
})?,
|
||||
);
|
||||
}
|
||||
SendingEventType::Edu(_) => {
|
||||
// Push gateways don't need EDUs (?)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
for pdu in pdus {
|
||||
// Redacted events are not notification targets (we don't send push for them)
|
||||
if let Some(unsigned) = &pdu.unsigned {
|
||||
if let Ok(unsigned) =
|
||||
serde_json::from_str::<serde_json::Value>(unsigned.get())
|
||||
{
|
||||
if unsigned.get("redacted_because").is_some() {
|
||||
continue;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
let userid = UserId::parse(utils::string_from_bytes(user).map_err(|_| {
|
||||
(
|
||||
kind.clone(),
|
||||
Error::bad_database("Invalid push user string in db."),
|
||||
)
|
||||
})?)
|
||||
.map_err(|_| {
|
||||
(
|
||||
kind.clone(),
|
||||
Error::bad_database("Invalid push user id in db."),
|
||||
)
|
||||
})?;
|
||||
|
||||
let mut senderkey = user.clone();
|
||||
senderkey.push(0xff);
|
||||
senderkey.extend_from_slice(pushkey);
|
||||
|
||||
let pusher = match db
|
||||
.pusher
|
||||
.get_pusher(&senderkey)
|
||||
.map_err(|e| (OutgoingKind::Push(user.clone(), pushkey.clone()), e))?
|
||||
{
|
||||
Some(pusher) => pusher,
|
||||
None => continue,
|
||||
};
|
||||
|
||||
let rules_for_user = db
|
||||
.account_data
|
||||
.get(
|
||||
None,
|
||||
&userid,
|
||||
GlobalAccountDataEventType::PushRules.to_string().into(),
|
||||
)
|
||||
.unwrap_or_default()
|
||||
.map(|ev: PushRulesEvent| ev.content.global)
|
||||
.unwrap_or_else(|| push::Ruleset::server_default(&userid));
|
||||
|
||||
let unread: UInt = db
|
||||
.rooms
|
||||
.notification_count(&userid, &pdu.room_id)
|
||||
.map_err(|e| (kind.clone(), e))?
|
||||
.try_into()
|
||||
.expect("notifiation count can't go that high");
|
||||
|
||||
let permit = db.sending.maximum_requests.acquire().await;
|
||||
|
||||
let _response = pusher::send_push_notice(
|
||||
&userid,
|
||||
unread,
|
||||
&pusher,
|
||||
rules_for_user,
|
||||
&pdu,
|
||||
&db,
|
||||
)
|
||||
.await
|
||||
.map(|_response| kind.clone())
|
||||
.map_err(|e| (kind.clone(), e));
|
||||
|
||||
drop(permit);
|
||||
}
|
||||
Ok(OutgoingKind::Push(user.clone(), pushkey.clone()))
|
||||
}
|
||||
OutgoingKind::Normal(server) => {
|
||||
let mut edu_jsons = Vec::new();
|
||||
let mut pdu_jsons = Vec::new();
|
||||
|
||||
for event in &events {
|
||||
match event {
|
||||
SendingEventType::Pdu(pdu_id) => {
|
||||
// TODO: check room version and remove event_id if needed
|
||||
let raw = PduEvent::convert_to_outgoing_federation_event(
|
||||
db.rooms
|
||||
.get_pdu_json_from_id(pdu_id)
|
||||
.map_err(|e| (OutgoingKind::Normal(server.clone()), e))?
|
||||
.ok_or_else(|| {
|
||||
(
|
||||
OutgoingKind::Normal(server.clone()),
|
||||
Error::bad_database(
|
||||
"[Normal] Event in servernamevent_datas not found in db.",
|
||||
),
|
||||
)
|
||||
})?,
|
||||
);
|
||||
pdu_jsons.push(raw);
|
||||
}
|
||||
SendingEventType::Edu(edu) => {
|
||||
if let Ok(raw) = serde_json::from_slice(edu) {
|
||||
edu_jsons.push(raw);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
let permit = db.sending.maximum_requests.acquire().await;
|
||||
|
||||
let response = server_server::send_request(
|
||||
&db.globals,
|
||||
&*server,
|
||||
send_transaction_message::v1::Request {
|
||||
origin: db.globals.server_name(),
|
||||
pdus: &pdu_jsons,
|
||||
edus: &edu_jsons,
|
||||
origin_server_ts: MilliSecondsSinceUnixEpoch::now(),
|
||||
transaction_id: (&*base64::encode_config(
|
||||
Self::calculate_hash(
|
||||
&events
|
||||
.iter()
|
||||
.map(|e| match e {
|
||||
SendingEventType::Edu(b) | SendingEventType::Pdu(b) => &**b,
|
||||
})
|
||||
.collect::<Vec<_>>(),
|
||||
),
|
||||
base64::URL_SAFE_NO_PAD,
|
||||
))
|
||||
.into(),
|
||||
},
|
||||
)
|
||||
.await
|
||||
.map(|response| {
|
||||
for pdu in response.pdus {
|
||||
if pdu.1.is_err() {
|
||||
warn!("Failed to send to {}: {:?}", server, pdu);
|
||||
}
|
||||
}
|
||||
kind.clone()
|
||||
})
|
||||
.map_err(|e| (kind, e));
|
||||
|
||||
drop(permit);
|
||||
|
||||
response
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[tracing::instrument(skip(key))]
|
||||
fn parse_servercurrentevent(
|
||||
key: &[u8],
|
||||
value: Vec<u8>,
|
||||
) -> Result<(OutgoingKind, SendingEventType)> {
|
||||
// Appservices start with a plus
|
||||
Ok::<_, Error>(if key.starts_with(b"+") {
|
||||
let mut parts = key[1..].splitn(2, |&b| b == 0xff);
|
||||
|
||||
let server = parts.next().expect("splitn always returns one element");
|
||||
let event = parts
|
||||
.next()
|
||||
.ok_or_else(|| Error::bad_database("Invalid bytes in servercurrentpdus."))?;
|
||||
let server = utils::string_from_bytes(server).map_err(|_| {
|
||||
Error::bad_database("Invalid server bytes in server_currenttransaction")
|
||||
})?;
|
||||
|
||||
(
|
||||
OutgoingKind::Appservice(server),
|
||||
if value.is_empty() {
|
||||
SendingEventType::Pdu(event.to_vec())
|
||||
} else {
|
||||
SendingEventType::Edu(value)
|
||||
},
|
||||
)
|
||||
} else if key.starts_with(b"$") {
|
||||
let mut parts = key[1..].splitn(3, |&b| b == 0xff);
|
||||
|
||||
let user = parts.next().expect("splitn always returns one element");
|
||||
let pushkey = parts
|
||||
.next()
|
||||
.ok_or_else(|| Error::bad_database("Invalid bytes in servercurrentpdus."))?;
|
||||
let event = parts
|
||||
.next()
|
||||
.ok_or_else(|| Error::bad_database("Invalid bytes in servercurrentpdus."))?;
|
||||
(
|
||||
OutgoingKind::Push(user.to_vec(), pushkey.to_vec()),
|
||||
if value.is_empty() {
|
||||
SendingEventType::Pdu(event.to_vec())
|
||||
} else {
|
||||
SendingEventType::Edu(value)
|
||||
},
|
||||
)
|
||||
} else {
|
||||
let mut parts = key.splitn(2, |&b| b == 0xff);
|
||||
|
||||
let server = parts.next().expect("splitn always returns one element");
|
||||
let event = parts
|
||||
.next()
|
||||
.ok_or_else(|| Error::bad_database("Invalid bytes in servercurrentpdus."))?;
|
||||
let server = utils::string_from_bytes(server).map_err(|_| {
|
||||
Error::bad_database("Invalid server bytes in server_currenttransaction")
|
||||
})?;
|
||||
|
||||
(
|
||||
OutgoingKind::Normal(ServerName::parse(server).map_err(|_| {
|
||||
Error::bad_database("Invalid server string in server_currenttransaction")
|
||||
})?),
|
||||
if value.is_empty() {
|
||||
SendingEventType::Pdu(event.to_vec())
|
||||
} else {
|
||||
SendingEventType::Edu(value)
|
||||
},
|
||||
)
|
||||
})
|
||||
}
|
||||
|
||||
#[tracing::instrument(skip(self, globals, destination, request))]
|
||||
pub async fn send_federation_request<T: OutgoingRequest>(
|
||||
&self,
|
||||
globals: &crate::database::globals::Globals,
|
||||
destination: &ServerName,
|
||||
request: T,
|
||||
) -> Result<T::IncomingResponse>
|
||||
where
|
||||
T: Debug,
|
||||
{
|
||||
let permit = self.maximum_requests.acquire().await;
|
||||
let response = server_server::send_request(globals, destination, request).await;
|
||||
drop(permit);
|
||||
|
||||
response
|
||||
}
|
||||
|
||||
#[tracing::instrument(skip(self, globals, registration, request))]
|
||||
pub async fn send_appservice_request<T: OutgoingRequest>(
|
||||
&self,
|
||||
globals: &crate::database::globals::Globals,
|
||||
registration: serde_yaml::Value,
|
||||
request: T,
|
||||
) -> Result<T::IncomingResponse>
|
||||
where
|
||||
T: Debug,
|
||||
{
|
||||
let permit = self.maximum_requests.acquire().await;
|
||||
let response = appservice_server::send_request(globals, registration, request).await;
|
||||
drop(permit);
|
||||
|
||||
response
|
||||
}
|
||||
}
|
|
@ -1,46 +0,0 @@
|
|||
use std::sync::Arc;
|
||||
|
||||
use crate::Result;
|
||||
use ruma::{DeviceId, TransactionId, UserId};
|
||||
|
||||
use super::abstraction::Tree;
|
||||
|
||||
pub struct TransactionIds {
|
||||
pub(super) userdevicetxnid_response: Arc<dyn Tree>, // Response can be empty (/sendToDevice) or the event id (/send)
|
||||
}
|
||||
|
||||
impl TransactionIds {
|
||||
pub fn add_txnid(
|
||||
&self,
|
||||
user_id: &UserId,
|
||||
device_id: Option<&DeviceId>,
|
||||
txn_id: &TransactionId,
|
||||
data: &[u8],
|
||||
) -> Result<()> {
|
||||
let mut key = user_id.as_bytes().to_vec();
|
||||
key.push(0xff);
|
||||
key.extend_from_slice(device_id.map(|d| d.as_bytes()).unwrap_or_default());
|
||||
key.push(0xff);
|
||||
key.extend_from_slice(txn_id.as_bytes());
|
||||
|
||||
self.userdevicetxnid_response.insert(&key, data)?;
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
pub fn existing_txnid(
|
||||
&self,
|
||||
user_id: &UserId,
|
||||
device_id: Option<&DeviceId>,
|
||||
txn_id: &TransactionId,
|
||||
) -> Result<Option<Vec<u8>>> {
|
||||
let mut key = user_id.as_bytes().to_vec();
|
||||
key.push(0xff);
|
||||
key.extend_from_slice(device_id.map(|d| d.as_bytes()).unwrap_or_default());
|
||||
key.push(0xff);
|
||||
key.extend_from_slice(txn_id.as_bytes());
|
||||
|
||||
// If there's no entry, this is a new transaction
|
||||
self.userdevicetxnid_response.get(&key)
|
||||
}
|
||||
}
|
|
@ -1,227 +0,0 @@
|
|||
use std::{
|
||||
collections::BTreeMap,
|
||||
sync::{Arc, RwLock},
|
||||
};
|
||||
|
||||
use crate::{client_server::SESSION_ID_LENGTH, utils, Error, Result};
|
||||
use ruma::{
|
||||
api::client::{
|
||||
error::ErrorKind,
|
||||
uiaa::{
|
||||
AuthType, IncomingAuthData, IncomingPassword,
|
||||
IncomingUserIdentifier::UserIdOrLocalpart, UiaaInfo,
|
||||
},
|
||||
},
|
||||
signatures::CanonicalJsonValue,
|
||||
DeviceId, UserId,
|
||||
};
|
||||
use tracing::error;
|
||||
|
||||
use super::abstraction::Tree;
|
||||
|
||||
pub struct Uiaa {
|
||||
pub(super) userdevicesessionid_uiaainfo: Arc<dyn Tree>, // User-interactive authentication
|
||||
pub(super) userdevicesessionid_uiaarequest:
|
||||
RwLock<BTreeMap<(Box<UserId>, Box<DeviceId>, String), CanonicalJsonValue>>,
|
||||
}
|
||||
|
||||
impl Uiaa {
|
||||
/// Creates a new Uiaa session. Make sure the session token is unique.
|
||||
pub fn create(
|
||||
&self,
|
||||
user_id: &UserId,
|
||||
device_id: &DeviceId,
|
||||
uiaainfo: &UiaaInfo,
|
||||
json_body: &CanonicalJsonValue,
|
||||
) -> Result<()> {
|
||||
self.set_uiaa_request(
|
||||
user_id,
|
||||
device_id,
|
||||
uiaainfo.session.as_ref().expect("session should be set"), // TODO: better session error handling (why is it optional in ruma?)
|
||||
json_body,
|
||||
)?;
|
||||
self.update_uiaa_session(
|
||||
user_id,
|
||||
device_id,
|
||||
uiaainfo.session.as_ref().expect("session should be set"),
|
||||
Some(uiaainfo),
|
||||
)
|
||||
}
|
||||
|
||||
pub fn try_auth(
|
||||
&self,
|
||||
user_id: &UserId,
|
||||
device_id: &DeviceId,
|
||||
auth: &IncomingAuthData,
|
||||
uiaainfo: &UiaaInfo,
|
||||
users: &super::users::Users,
|
||||
globals: &super::globals::Globals,
|
||||
) -> Result<(bool, UiaaInfo)> {
|
||||
let mut uiaainfo = auth
|
||||
.session()
|
||||
.map(|session| self.get_uiaa_session(user_id, device_id, session))
|
||||
.unwrap_or_else(|| Ok(uiaainfo.clone()))?;
|
||||
|
||||
if uiaainfo.session.is_none() {
|
||||
uiaainfo.session = Some(utils::random_string(SESSION_ID_LENGTH));
|
||||
}
|
||||
|
||||
match auth {
|
||||
// Find out what the user completed
|
||||
IncomingAuthData::Password(IncomingPassword {
|
||||
identifier,
|
||||
password,
|
||||
..
|
||||
}) => {
|
||||
let username = match identifier {
|
||||
UserIdOrLocalpart(username) => username,
|
||||
_ => {
|
||||
return Err(Error::BadRequest(
|
||||
ErrorKind::Unrecognized,
|
||||
"Identifier type not recognized.",
|
||||
))
|
||||
}
|
||||
};
|
||||
|
||||
let user_id =
|
||||
UserId::parse_with_server_name(username.clone(), globals.server_name())
|
||||
.map_err(|_| {
|
||||
Error::BadRequest(ErrorKind::InvalidParam, "User ID is invalid.")
|
||||
})?;
|
||||
|
||||
// Check if password is correct
|
||||
if let Some(hash) = users.password_hash(&user_id)? {
|
||||
let hash_matches =
|
||||
argon2::verify_encoded(&hash, password.as_bytes()).unwrap_or(false);
|
||||
|
||||
if !hash_matches {
|
||||
uiaainfo.auth_error = Some(ruma::api::client::error::ErrorBody {
|
||||
kind: ErrorKind::Forbidden,
|
||||
message: "Invalid username or password.".to_owned(),
|
||||
});
|
||||
return Ok((false, uiaainfo));
|
||||
}
|
||||
}
|
||||
|
||||
// Password was correct! Let's add it to `completed`
|
||||
uiaainfo.completed.push(AuthType::Password);
|
||||
}
|
||||
IncomingAuthData::Dummy(_) => {
|
||||
uiaainfo.completed.push(AuthType::Dummy);
|
||||
}
|
||||
k => error!("type not supported: {:?}", k),
|
||||
}
|
||||
|
||||
// Check if a flow now succeeds
|
||||
let mut completed = false;
|
||||
'flows: for flow in &mut uiaainfo.flows {
|
||||
for stage in &flow.stages {
|
||||
if !uiaainfo.completed.contains(stage) {
|
||||
continue 'flows;
|
||||
}
|
||||
}
|
||||
// We didn't break, so this flow succeeded!
|
||||
completed = true;
|
||||
}
|
||||
|
||||
if !completed {
|
||||
self.update_uiaa_session(
|
||||
user_id,
|
||||
device_id,
|
||||
uiaainfo.session.as_ref().expect("session is always set"),
|
||||
Some(&uiaainfo),
|
||||
)?;
|
||||
return Ok((false, uiaainfo));
|
||||
}
|
||||
|
||||
// UIAA was successful! Remove this session and return true
|
||||
self.update_uiaa_session(
|
||||
user_id,
|
||||
device_id,
|
||||
uiaainfo.session.as_ref().expect("session is always set"),
|
||||
None,
|
||||
)?;
|
||||
Ok((true, uiaainfo))
|
||||
}
|
||||
|
||||
fn set_uiaa_request(
|
||||
&self,
|
||||
user_id: &UserId,
|
||||
device_id: &DeviceId,
|
||||
session: &str,
|
||||
request: &CanonicalJsonValue,
|
||||
) -> Result<()> {
|
||||
self.userdevicesessionid_uiaarequest
|
||||
.write()
|
||||
.unwrap()
|
||||
.insert(
|
||||
(user_id.to_owned(), device_id.to_owned(), session.to_owned()),
|
||||
request.to_owned(),
|
||||
);
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
pub fn get_uiaa_request(
|
||||
&self,
|
||||
user_id: &UserId,
|
||||
device_id: &DeviceId,
|
||||
session: &str,
|
||||
) -> Option<CanonicalJsonValue> {
|
||||
self.userdevicesessionid_uiaarequest
|
||||
.read()
|
||||
.unwrap()
|
||||
.get(&(user_id.to_owned(), device_id.to_owned(), session.to_owned()))
|
||||
.map(|j| j.to_owned())
|
||||
}
|
||||
|
||||
fn update_uiaa_session(
|
||||
&self,
|
||||
user_id: &UserId,
|
||||
device_id: &DeviceId,
|
||||
session: &str,
|
||||
uiaainfo: Option<&UiaaInfo>,
|
||||
) -> Result<()> {
|
||||
let mut userdevicesessionid = user_id.as_bytes().to_vec();
|
||||
userdevicesessionid.push(0xff);
|
||||
userdevicesessionid.extend_from_slice(device_id.as_bytes());
|
||||
userdevicesessionid.push(0xff);
|
||||
userdevicesessionid.extend_from_slice(session.as_bytes());
|
||||
|
||||
if let Some(uiaainfo) = uiaainfo {
|
||||
self.userdevicesessionid_uiaainfo.insert(
|
||||
&userdevicesessionid,
|
||||
&serde_json::to_vec(&uiaainfo).expect("UiaaInfo::to_vec always works"),
|
||||
)?;
|
||||
} else {
|
||||
self.userdevicesessionid_uiaainfo
|
||||
.remove(&userdevicesessionid)?;
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn get_uiaa_session(
|
||||
&self,
|
||||
user_id: &UserId,
|
||||
device_id: &DeviceId,
|
||||
session: &str,
|
||||
) -> Result<UiaaInfo> {
|
||||
let mut userdevicesessionid = user_id.as_bytes().to_vec();
|
||||
userdevicesessionid.push(0xff);
|
||||
userdevicesessionid.extend_from_slice(device_id.as_bytes());
|
||||
userdevicesessionid.push(0xff);
|
||||
userdevicesessionid.extend_from_slice(session.as_bytes());
|
||||
|
||||
serde_json::from_slice(
|
||||
&self
|
||||
.userdevicesessionid_uiaainfo
|
||||
.get(&userdevicesessionid)?
|
||||
.ok_or(Error::BadRequest(
|
||||
ErrorKind::Forbidden,
|
||||
"UIAA session does not exist.",
|
||||
))?,
|
||||
)
|
||||
.map_err(|_| Error::bad_database("UiaaInfo in userdeviceid_uiaainfo is invalid."))
|
||||
}
|
||||
}
|
File diff suppressed because it is too large
Load diff
Loading…
Add table
Add a link
Reference in a new issue