Refactor appservices, pusher, timeline, transactionids, users

This commit is contained in:
Timo Kösters 2022-08-07 19:42:22 +02:00 committed by Nyaaori
parent 01bf348811
commit f56424bc8d
No known key found for this signature in database
GPG key ID: E7819C3ED4D1F82E
18 changed files with 546 additions and 4950 deletions

View file

@ -1,49 +1,10 @@
use crate::{utils, Error, Result};
use ruma::{
api::client::{device::Device, error::ErrorKind, filter::IncomingFilterDefinition},
encryption::{CrossSigningKey, DeviceKeys, OneTimeKey},
events::{AnyToDeviceEvent, StateEventType},
serde::Raw,
DeviceId, DeviceKeyAlgorithm, DeviceKeyId, MilliSecondsSinceUnixEpoch, MxcUri, RoomAliasId,
UInt, UserId,
};
use std::{collections::BTreeMap, mem, sync::Arc};
use tracing::warn;
use super::abstraction::Tree;
pub struct Users {
pub(super) userid_password: Arc<dyn Tree>,
pub(super) userid_displayname: Arc<dyn Tree>,
pub(super) userid_avatarurl: Arc<dyn Tree>,
pub(super) userid_blurhash: Arc<dyn Tree>,
pub(super) userdeviceid_token: Arc<dyn Tree>,
pub(super) userdeviceid_metadata: Arc<dyn Tree>, // This is also used to check if a device exists
pub(super) userid_devicelistversion: Arc<dyn Tree>, // DevicelistVersion = u64
pub(super) token_userdeviceid: Arc<dyn Tree>,
pub(super) onetimekeyid_onetimekeys: Arc<dyn Tree>, // OneTimeKeyId = UserId + DeviceKeyId
pub(super) userid_lastonetimekeyupdate: Arc<dyn Tree>, // LastOneTimeKeyUpdate = Count
pub(super) keychangeid_userid: Arc<dyn Tree>, // KeyChangeId = UserId/RoomId + Count
pub(super) keyid_key: Arc<dyn Tree>, // KeyId = UserId + KeyId (depends on key type)
pub(super) userid_masterkeyid: Arc<dyn Tree>,
pub(super) userid_selfsigningkeyid: Arc<dyn Tree>,
pub(super) userid_usersigningkeyid: Arc<dyn Tree>,
pub(super) userfilterid_filter: Arc<dyn Tree>, // UserFilterId = UserId + FilterId
pub(super) todeviceid_events: Arc<dyn Tree>, // ToDeviceId = UserId + DeviceId + Count
}
impl Users {
impl service::users::Data for KeyValueDatabase {
/// Check if a user has an account on this homeserver.
#[tracing::instrument(skip(self, user_id))]
pub fn exists(&self, user_id: &UserId) -> Result<bool> {
Ok(self.userid_password.get(user_id.as_bytes())?.is_some())
}
/// Check if account is deactivated
#[tracing::instrument(skip(self, user_id))]
pub fn is_deactivated(&self, user_id: &UserId) -> Result<bool> {
Ok(self
.userid_password
@ -56,7 +17,6 @@ impl Users {
}
/// Check if a user is an admin
#[tracing::instrument(skip(self, user_id, rooms, globals))]
pub fn is_admin(
&self,
user_id: &UserId,
@ -71,20 +31,17 @@ impl Users {
}
/// Create a new user account on this homeserver.
#[tracing::instrument(skip(self, user_id, password))]
pub fn create(&self, user_id: &UserId, password: Option<&str>) -> Result<()> {
self.set_password(user_id, password)?;
Ok(())
}
/// Returns the number of users registered on this server.
#[tracing::instrument(skip(self))]
pub fn count(&self) -> Result<usize> {
Ok(self.userid_password.iter().count())
}
/// Find out which user an access token belongs to.
#[tracing::instrument(skip(self, token))]
pub fn find_from_token(&self, token: &str) -> Result<Option<(Box<UserId>, String)>> {
self.token_userdeviceid
.get(token.as_bytes())?
@ -112,7 +69,6 @@ impl Users {
}
/// Returns an iterator over all users on this homeserver.
#[tracing::instrument(skip(self))]
pub fn iter(&self) -> impl Iterator<Item = Result<Box<UserId>>> + '_ {
self.userid_password.iter().map(|(bytes, _)| {
UserId::parse(utils::string_from_bytes(&bytes).map_err(|_| {
@ -125,7 +81,6 @@ impl Users {
/// Returns a list of local users as list of usernames.
///
/// A user account is considered `local` if the length of it's password is greater then zero.
#[tracing::instrument(skip(self))]
pub fn list_local_users(&self) -> Result<Vec<String>> {
let users: Vec<String> = self
.userid_password
@ -139,7 +94,6 @@ impl Users {
/// username could be successfully parsed.
/// If utils::string_from_bytes(...) returns an error that username will be skipped
/// and the error will be logged.
#[tracing::instrument(skip(self))]
fn get_username_with_valid_password(&self, username: &[u8], password: &[u8]) -> Option<String> {
// A valid password is not empty
if password.is_empty() {
@ -159,7 +113,6 @@ impl Users {
}
/// Returns the password hash for the given user.
#[tracing::instrument(skip(self, user_id))]
pub fn password_hash(&self, user_id: &UserId) -> Result<Option<String>> {
self.userid_password
.get(user_id.as_bytes())?
@ -171,7 +124,6 @@ impl Users {
}
/// Hash and set the user's password to the Argon2 hash
#[tracing::instrument(skip(self, user_id, password))]
pub fn set_password(&self, user_id: &UserId, password: Option<&str>) -> Result<()> {
if let Some(password) = password {
if let Ok(hash) = utils::calculate_hash(password) {
@ -191,7 +143,6 @@ impl Users {
}
/// Returns the displayname of a user on this homeserver.
#[tracing::instrument(skip(self, user_id))]
pub fn displayname(&self, user_id: &UserId) -> Result<Option<String>> {
self.userid_displayname
.get(user_id.as_bytes())?
@ -203,7 +154,6 @@ impl Users {
}
/// Sets a new displayname or removes it if displayname is None. You still need to nofify all rooms of this change.
#[tracing::instrument(skip(self, user_id, displayname))]
pub fn set_displayname(&self, user_id: &UserId, displayname: Option<String>) -> Result<()> {
if let Some(displayname) = displayname {
self.userid_displayname
@ -216,7 +166,6 @@ impl Users {
}
/// Get the avatar_url of a user.
#[tracing::instrument(skip(self, user_id))]
pub fn avatar_url(&self, user_id: &UserId) -> Result<Option<Box<MxcUri>>> {
self.userid_avatarurl
.get(user_id.as_bytes())?
@ -230,7 +179,6 @@ impl Users {
}
/// Sets a new avatar_url or removes it if avatar_url is None.
#[tracing::instrument(skip(self, user_id, avatar_url))]
pub fn set_avatar_url(&self, user_id: &UserId, avatar_url: Option<Box<MxcUri>>) -> Result<()> {
if let Some(avatar_url) = avatar_url {
self.userid_avatarurl
@ -243,7 +191,6 @@ impl Users {
}
/// Get the blurhash of a user.
#[tracing::instrument(skip(self, user_id))]
pub fn blurhash(&self, user_id: &UserId) -> Result<Option<String>> {
self.userid_blurhash
.get(user_id.as_bytes())?
@ -257,7 +204,6 @@ impl Users {
}
/// Sets a new avatar_url or removes it if avatar_url is None.
#[tracing::instrument(skip(self, user_id, blurhash))]
pub fn set_blurhash(&self, user_id: &UserId, blurhash: Option<String>) -> Result<()> {
if let Some(blurhash) = blurhash {
self.userid_blurhash
@ -270,7 +216,6 @@ impl Users {
}
/// Adds a new device to a user.
#[tracing::instrument(skip(self, user_id, device_id, token, initial_device_display_name))]
pub fn create_device(
&self,
user_id: &UserId,
@ -305,7 +250,6 @@ impl Users {
}
/// Removes a device from a user.
#[tracing::instrument(skip(self, user_id, device_id))]
pub fn remove_device(&self, user_id: &UserId, device_id: &DeviceId) -> Result<()> {
let mut userdeviceid = user_id.as_bytes().to_vec();
userdeviceid.push(0xff);
@ -336,7 +280,6 @@ impl Users {
}
/// Returns an iterator over all device ids of this user.
#[tracing::instrument(skip(self, user_id))]
pub fn all_device_ids<'a>(
&'a self,
user_id: &UserId,
@ -359,7 +302,6 @@ impl Users {
}
/// Replaces the access token of one device.
#[tracing::instrument(skip(self, user_id, device_id, token))]
pub fn set_token(&self, user_id: &UserId, device_id: &DeviceId, token: &str) -> Result<()> {
let mut userdeviceid = user_id.as_bytes().to_vec();
userdeviceid.push(0xff);
@ -383,14 +325,6 @@ impl Users {
Ok(())
}
#[tracing::instrument(skip(
self,
user_id,
device_id,
one_time_key_key,
one_time_key_value,
globals
))]
pub fn add_one_time_key(
&self,
user_id: &UserId,
@ -427,7 +361,6 @@ impl Users {
Ok(())
}
#[tracing::instrument(skip(self, user_id))]
pub fn last_one_time_keys_update(&self, user_id: &UserId) -> Result<u64> {
self.userid_lastonetimekeyupdate
.get(user_id.as_bytes())?
@ -439,7 +372,6 @@ impl Users {
.unwrap_or(Ok(0))
}
#[tracing::instrument(skip(self, user_id, device_id, key_algorithm, globals))]
pub fn take_one_time_key(
&self,
user_id: &UserId,
@ -479,7 +411,6 @@ impl Users {
.transpose()
}
#[tracing::instrument(skip(self, user_id, device_id))]
pub fn count_one_time_keys(
&self,
user_id: &UserId,
@ -512,7 +443,6 @@ impl Users {
Ok(counts)
}
#[tracing::instrument(skip(self, user_id, device_id, device_keys, rooms, globals))]
pub fn add_device_keys(
&self,
user_id: &UserId,
@ -535,14 +465,6 @@ impl Users {
Ok(())
}
#[tracing::instrument(skip(
self,
master_key,
self_signing_key,
user_signing_key,
rooms,
globals
))]
pub fn add_cross_signing_keys(
&self,
user_id: &UserId,
@ -658,7 +580,6 @@ impl Users {
Ok(())
}
#[tracing::instrument(skip(self, target_id, key_id, signature, sender_id, rooms, globals))]
pub fn sign_key(
&self,
target_id: &UserId,
@ -703,7 +624,6 @@ impl Users {
Ok(())
}
#[tracing::instrument(skip(self, user_or_room_id, from, to))]
pub fn keys_changed<'a>(
&'a self,
user_or_room_id: &str,
@ -742,7 +662,6 @@ impl Users {
})
}
#[tracing::instrument(skip(self, user_id, rooms, globals))]
pub fn mark_device_key_update(
&self,
user_id: &UserId,
@ -774,7 +693,6 @@ impl Users {
Ok(())
}
#[tracing::instrument(skip(self, user_id, device_id))]
pub fn get_device_keys(
&self,
user_id: &UserId,
@ -791,7 +709,6 @@ impl Users {
})
}
#[tracing::instrument(skip(self, user_id, allowed_signatures))]
pub fn get_master_key<F: Fn(&UserId) -> bool>(
&self,
user_id: &UserId,
@ -813,7 +730,6 @@ impl Users {
})
}
#[tracing::instrument(skip(self, user_id, allowed_signatures))]
pub fn get_self_signing_key<F: Fn(&UserId) -> bool>(
&self,
user_id: &UserId,
@ -835,7 +751,6 @@ impl Users {
})
}
#[tracing::instrument(skip(self, user_id))]
pub fn get_user_signing_key(&self, user_id: &UserId) -> Result<Option<Raw<CrossSigningKey>>> {
self.userid_usersigningkeyid
.get(user_id.as_bytes())?
@ -848,15 +763,6 @@ impl Users {
})
}
#[tracing::instrument(skip(
self,
sender,
target_user_id,
target_device_id,
event_type,
content,
globals
))]
pub fn add_to_device_event(
&self,
sender: &UserId,
@ -884,7 +790,6 @@ impl Users {
Ok(())
}
#[tracing::instrument(skip(self, user_id, device_id))]
pub fn get_to_device_events(
&self,
user_id: &UserId,
@ -907,7 +812,6 @@ impl Users {
Ok(events)
}
#[tracing::instrument(skip(self, user_id, device_id, until))]
pub fn remove_to_device_events(
&self,
user_id: &UserId,
@ -942,7 +846,6 @@ impl Users {
Ok(())
}
#[tracing::instrument(skip(self, user_id, device_id, device))]
pub fn update_device_metadata(
&self,
user_id: &UserId,
@ -968,7 +871,6 @@ impl Users {
}
/// Get device metadata.
#[tracing::instrument(skip(self, user_id, device_id))]
pub fn get_device_metadata(
&self,
user_id: &UserId,
@ -987,7 +889,6 @@ impl Users {
})
}
#[tracing::instrument(skip(self, user_id))]
pub fn get_devicelist_version(&self, user_id: &UserId) -> Result<Option<u64>> {
self.userid_devicelistversion
.get(user_id.as_bytes())?
@ -998,7 +899,6 @@ impl Users {
})
}
#[tracing::instrument(skip(self, user_id))]
pub fn all_devices_metadata<'a>(
&'a self,
user_id: &UserId,
@ -1014,25 +914,7 @@ impl Users {
})
}
/// Deactivate account
#[tracing::instrument(skip(self, user_id))]
pub fn deactivate_account(&self, user_id: &UserId) -> Result<()> {
// Remove all associated devices
for device_id in self.all_device_ids(user_id) {
self.remove_device(user_id, &device_id?)?;
}
// Set the password to "" to indicate a deactivated account. Hashes will never result in an
// empty string, so the user will not be able to log in again. Systems like changing the
// password without logging in should check if the account is deactivated.
self.userid_password.insert(user_id.as_bytes(), &[])?;
// TODO: Unhook 3PID
Ok(())
}
/// Creates a new sync filter. Returns the filter id.
#[tracing::instrument(skip(self))]
pub fn create_filter(
&self,
user_id: &UserId,
@ -1052,7 +934,6 @@ impl Users {
Ok(filter_id)
}
#[tracing::instrument(skip(self))]
pub fn get_filter(
&self,
user_id: &UserId,
@ -1072,30 +953,3 @@ impl Users {
}
}
}
/// Ensure that a user only sees signatures from themselves and the target user
fn clean_signatures<F: Fn(&UserId) -> bool>(
cross_signing_key: &mut serde_json::Value,
user_id: &UserId,
allowed_signatures: F,
) -> Result<(), Error> {
if let Some(signatures) = cross_signing_key
.get_mut("signatures")
.and_then(|v| v.as_object_mut())
{
// Don't allocate for the full size of the current signatures, but require
// at most one resize if nothing is dropped
let new_capacity = signatures.len() / 2;
for (user, signature) in
mem::replace(signatures, serde_json::Map::with_capacity(new_capacity))
{
let id = <&UserId>::try_from(user.as_str())
.map_err(|_| Error::bad_database("Invalid user ID in database."))?;
if id == user_id || allowed_signatures(id) {
signatures.insert(user, signature);
}
}
}
Ok(())
}