feat: swappable database backend
This commit is contained in:
parent
81715bd84d
commit
d0ee823254
47 changed files with 1434 additions and 981 deletions
309
src/database/abstraction.rs
Normal file
309
src/database/abstraction.rs
Normal file
|
@ -0,0 +1,309 @@
|
|||
use std::{
|
||||
collections::BTreeMap,
|
||||
future::Future,
|
||||
pin::Pin,
|
||||
sync::{Arc, RwLock},
|
||||
};
|
||||
|
||||
use log::warn;
|
||||
use rocksdb::{
|
||||
BoundColumnFamily, ColumnFamilyDescriptor, DBWithThreadMode, Direction, MultiThreaded, Options,
|
||||
};
|
||||
|
||||
use super::Config;
|
||||
use crate::{utils, Result};
|
||||
|
||||
pub struct SledEngine(sled::Db);
|
||||
pub struct SledEngineTree(sled::Tree);
|
||||
pub struct RocksDbEngine(rocksdb::DBWithThreadMode<MultiThreaded>);
|
||||
pub struct RocksDbEngineTree<'a> {
|
||||
db: Arc<RocksDbEngine>,
|
||||
name: &'a str,
|
||||
watchers: RwLock<BTreeMap<Vec<u8>, Vec<tokio::sync::oneshot::Sender<()>>>>,
|
||||
}
|
||||
|
||||
pub trait DatabaseEngine: Sized {
|
||||
fn open(config: &Config) -> Result<Arc<Self>>;
|
||||
fn open_tree(self: &Arc<Self>, name: &'static str) -> Result<Arc<dyn Tree>>;
|
||||
}
|
||||
|
||||
pub trait Tree: Send + Sync {
|
||||
fn get(&self, key: &[u8]) -> Result<Option<Vec<u8>>>;
|
||||
|
||||
fn insert(&self, key: &[u8], value: &[u8]) -> Result<()>;
|
||||
|
||||
fn remove(&self, key: &[u8]) -> Result<()>;
|
||||
|
||||
fn iter<'a>(&'a self) -> Box<dyn Iterator<Item = (Box<[u8]>, Box<[u8]>)> + Send + Sync + 'a>;
|
||||
|
||||
fn iter_from<'a>(
|
||||
&'a self,
|
||||
from: &[u8],
|
||||
backwards: bool,
|
||||
) -> Box<dyn Iterator<Item = (Box<[u8]>, Box<[u8]>)> + 'a>;
|
||||
|
||||
fn increment(&self, key: &[u8]) -> Result<Vec<u8>>;
|
||||
|
||||
fn scan_prefix<'a>(
|
||||
&'a self,
|
||||
prefix: Vec<u8>,
|
||||
) -> Box<dyn Iterator<Item = (Box<[u8]>, Box<[u8]>)> + 'a>;
|
||||
|
||||
fn watch_prefix<'a>(&'a self, prefix: &[u8]) -> Pin<Box<dyn Future<Output = ()> + Send + 'a>>;
|
||||
|
||||
fn clear(&self) -> Result<()> {
|
||||
for (key, _) in self.iter() {
|
||||
self.remove(&key)?;
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
impl DatabaseEngine for SledEngine {
|
||||
fn open(config: &Config) -> Result<Arc<Self>> {
|
||||
Ok(Arc::new(SledEngine(
|
||||
sled::Config::default()
|
||||
.path(&config.database_path)
|
||||
.cache_capacity(config.cache_capacity as u64)
|
||||
.use_compression(true)
|
||||
.open()?,
|
||||
)))
|
||||
}
|
||||
|
||||
fn open_tree(self: &Arc<Self>, name: &'static str) -> Result<Arc<dyn Tree>> {
|
||||
Ok(Arc::new(SledEngineTree(self.0.open_tree(name)?)))
|
||||
}
|
||||
}
|
||||
|
||||
impl Tree for SledEngineTree {
|
||||
fn get(&self, key: &[u8]) -> Result<Option<Vec<u8>>> {
|
||||
Ok(self.0.get(key)?.map(|v| v.to_vec()))
|
||||
}
|
||||
|
||||
fn insert(&self, key: &[u8], value: &[u8]) -> Result<()> {
|
||||
self.0.insert(key, value)?;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn remove(&self, key: &[u8]) -> Result<()> {
|
||||
self.0.remove(key)?;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn iter<'a>(&'a self) -> Box<dyn Iterator<Item = (Box<[u8]>, Box<[u8]>)> + Send + Sync + 'a> {
|
||||
Box::new(
|
||||
self.0
|
||||
.iter()
|
||||
.filter_map(|r| {
|
||||
if let Err(e) = &r {
|
||||
warn!("Error: {}", e);
|
||||
}
|
||||
r.ok()
|
||||
})
|
||||
.map(|(k, v)| (k.to_vec().into(), v.to_vec().into())),
|
||||
)
|
||||
}
|
||||
|
||||
fn iter_from(
|
||||
&self,
|
||||
from: &[u8],
|
||||
backwards: bool,
|
||||
) -> Box<dyn Iterator<Item = (Box<[u8]>, Box<[u8]>)>> {
|
||||
let iter = if backwards {
|
||||
self.0.range(..from)
|
||||
} else {
|
||||
self.0.range(from..)
|
||||
};
|
||||
|
||||
let iter = iter
|
||||
.filter_map(|r| {
|
||||
if let Err(e) = &r {
|
||||
warn!("Error: {}", e);
|
||||
}
|
||||
r.ok()
|
||||
})
|
||||
.map(|(k, v)| (k.to_vec().into(), v.to_vec().into()));
|
||||
|
||||
if backwards {
|
||||
Box::new(iter.rev())
|
||||
} else {
|
||||
Box::new(iter)
|
||||
}
|
||||
}
|
||||
|
||||
fn increment(&self, key: &[u8]) -> Result<Vec<u8>> {
|
||||
Ok(self
|
||||
.0
|
||||
.update_and_fetch(key, utils::increment)
|
||||
.map(|o| o.expect("increment always sets a value").to_vec())?)
|
||||
}
|
||||
|
||||
fn scan_prefix<'a>(
|
||||
&'a self,
|
||||
prefix: Vec<u8>,
|
||||
) -> Box<dyn Iterator<Item = (Box<[u8]>, Box<[u8]>)> + 'a> {
|
||||
let iter = self
|
||||
.0
|
||||
.scan_prefix(prefix)
|
||||
.filter_map(|r| {
|
||||
if let Err(e) = &r {
|
||||
warn!("Error: {}", e);
|
||||
}
|
||||
r.ok()
|
||||
})
|
||||
.map(|(k, v)| (k.to_vec().into(), v.to_vec().into()));
|
||||
|
||||
Box::new(iter)
|
||||
}
|
||||
|
||||
fn watch_prefix<'a>(&'a self, prefix: &[u8]) -> Pin<Box<dyn Future<Output = ()> + Send + 'a>> {
|
||||
let prefix = prefix.to_vec();
|
||||
Box::pin(async move {
|
||||
self.0.watch_prefix(prefix).await;
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
impl DatabaseEngine for RocksDbEngine {
|
||||
fn open(config: &Config) -> Result<Arc<Self>> {
|
||||
let mut db_opts = Options::default();
|
||||
db_opts.create_if_missing(true);
|
||||
|
||||
let cfs = DBWithThreadMode::<MultiThreaded>::list_cf(&db_opts, &config.database_path)
|
||||
.unwrap_or_default();
|
||||
|
||||
let mut options = Options::default();
|
||||
options.set_merge_operator_associative("increment", utils::increment_rocksdb);
|
||||
|
||||
let db = DBWithThreadMode::<MultiThreaded>::open_cf_descriptors(
|
||||
&db_opts,
|
||||
&config.database_path,
|
||||
cfs.iter()
|
||||
.map(|name| ColumnFamilyDescriptor::new(name, options.clone())),
|
||||
)?;
|
||||
|
||||
Ok(Arc::new(RocksDbEngine(db)))
|
||||
}
|
||||
|
||||
fn open_tree(self: &Arc<Self>, name: &'static str) -> Result<Arc<dyn Tree>> {
|
||||
let mut options = Options::default();
|
||||
options.set_merge_operator_associative("increment", utils::increment_rocksdb);
|
||||
|
||||
// Create if it doesn't exist
|
||||
let _ = self.0.create_cf(name, &options);
|
||||
|
||||
Ok(Arc::new(RocksDbEngineTree {
|
||||
name,
|
||||
db: Arc::clone(self),
|
||||
watchers: RwLock::new(BTreeMap::new()),
|
||||
}))
|
||||
}
|
||||
}
|
||||
|
||||
impl RocksDbEngineTree<'_> {
|
||||
fn cf(&self) -> BoundColumnFamily<'_> {
|
||||
self.db.0.cf_handle(self.name).unwrap()
|
||||
}
|
||||
}
|
||||
|
||||
impl Tree for RocksDbEngineTree<'_> {
|
||||
fn get(&self, key: &[u8]) -> Result<Option<Vec<u8>>> {
|
||||
Ok(self.db.0.get_cf(self.cf(), key)?)
|
||||
}
|
||||
|
||||
fn insert(&self, key: &[u8], value: &[u8]) -> Result<()> {
|
||||
let watchers = self.watchers.read().unwrap();
|
||||
let mut triggered = Vec::new();
|
||||
|
||||
for length in 0..=key.len() {
|
||||
if watchers.contains_key(&key[..length]) {
|
||||
triggered.push(&key[..length]);
|
||||
}
|
||||
}
|
||||
|
||||
drop(watchers);
|
||||
|
||||
if !triggered.is_empty() {
|
||||
let mut watchers = self.watchers.write().unwrap();
|
||||
for prefix in triggered {
|
||||
if let Some(txs) = watchers.remove(prefix) {
|
||||
for tx in txs {
|
||||
let _ = tx.send(());
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
Ok(self.db.0.put_cf(self.cf(), key, value)?)
|
||||
}
|
||||
|
||||
fn remove(&self, key: &[u8]) -> Result<()> {
|
||||
Ok(self.db.0.delete_cf(self.cf(), key)?)
|
||||
}
|
||||
|
||||
fn iter<'a>(&'a self) -> Box<dyn Iterator<Item = (Box<[u8]>, Box<[u8]>)> + Send + Sync + 'a> {
|
||||
Box::new(
|
||||
self.db
|
||||
.0
|
||||
.iterator_cf(self.cf(), rocksdb::IteratorMode::Start),
|
||||
)
|
||||
}
|
||||
|
||||
fn iter_from<'a>(
|
||||
&'a self,
|
||||
from: &[u8],
|
||||
backwards: bool,
|
||||
) -> Box<dyn Iterator<Item = (Box<[u8]>, Box<[u8]>)> + 'a> {
|
||||
Box::new(self.db.0.iterator_cf(
|
||||
self.cf(),
|
||||
rocksdb::IteratorMode::From(
|
||||
from,
|
||||
if backwards {
|
||||
Direction::Reverse
|
||||
} else {
|
||||
Direction::Forward
|
||||
},
|
||||
),
|
||||
))
|
||||
}
|
||||
|
||||
fn increment(&self, key: &[u8]) -> Result<Vec<u8>> {
|
||||
// TODO: atomic?
|
||||
let old = self.get(key)?;
|
||||
let new = utils::increment(old.as_deref()).unwrap();
|
||||
self.insert(key, &new)?;
|
||||
Ok(new)
|
||||
}
|
||||
|
||||
fn scan_prefix<'a>(
|
||||
&'a self,
|
||||
prefix: Vec<u8>,
|
||||
) -> Box<dyn Iterator<Item = (Box<[u8]>, Box<[u8]>)> + 'a> {
|
||||
Box::new(
|
||||
self.db
|
||||
.0
|
||||
.iterator_cf(
|
||||
self.cf(),
|
||||
rocksdb::IteratorMode::From(&prefix, Direction::Forward),
|
||||
)
|
||||
.take_while(move |(k, _)| k.starts_with(&prefix)),
|
||||
)
|
||||
}
|
||||
|
||||
fn watch_prefix<'a>(&'a self, prefix: &[u8]) -> Pin<Box<dyn Future<Output = ()> + Send + 'a>> {
|
||||
let (tx, rx) = tokio::sync::oneshot::channel();
|
||||
|
||||
self.watchers
|
||||
.write()
|
||||
.unwrap()
|
||||
.entry(prefix.to_vec())
|
||||
.or_default()
|
||||
.push(tx);
|
||||
|
||||
Box::pin(async move {
|
||||
// Tx is never destroyed
|
||||
rx.await.unwrap();
|
||||
})
|
||||
}
|
||||
}
|
|
@ -6,12 +6,12 @@ use ruma::{
|
|||
RoomId, UserId,
|
||||
};
|
||||
use serde::{de::DeserializeOwned, Serialize};
|
||||
use sled::IVec;
|
||||
use std::{collections::HashMap, convert::TryFrom};
|
||||
use std::{collections::HashMap, convert::TryFrom, sync::Arc};
|
||||
|
||||
use super::abstraction::Tree;
|
||||
|
||||
#[derive(Clone)]
|
||||
pub struct AccountData {
|
||||
pub(super) roomuserdataid_accountdata: sled::Tree, // RoomUserDataId = Room + User + Count + Type
|
||||
pub(super) roomuserdataid_accountdata: Arc<dyn Tree>, // RoomUserDataId = Room + User + Count + Type
|
||||
}
|
||||
|
||||
impl AccountData {
|
||||
|
@ -34,9 +34,8 @@ impl AccountData {
|
|||
prefix.push(0xff);
|
||||
|
||||
// Remove old entry
|
||||
if let Some(previous) = self.find_event(room_id, user_id, &event_type) {
|
||||
let (old_key, _) = previous?;
|
||||
self.roomuserdataid_accountdata.remove(old_key)?;
|
||||
if let Some((old_key, _)) = self.find_event(room_id, user_id, &event_type)? {
|
||||
self.roomuserdataid_accountdata.remove(&old_key)?;
|
||||
}
|
||||
|
||||
let mut key = prefix;
|
||||
|
@ -52,8 +51,10 @@ impl AccountData {
|
|||
));
|
||||
}
|
||||
|
||||
self.roomuserdataid_accountdata
|
||||
.insert(key, &*json.to_string())?;
|
||||
self.roomuserdataid_accountdata.insert(
|
||||
&key,
|
||||
&serde_json::to_vec(&json).expect("to_vec always works on json values"),
|
||||
)?;
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
@ -65,9 +66,8 @@ impl AccountData {
|
|||
user_id: &UserId,
|
||||
kind: EventType,
|
||||
) -> Result<Option<T>> {
|
||||
self.find_event(room_id, user_id, &kind)
|
||||
.map(|r| {
|
||||
let (_, v) = r?;
|
||||
self.find_event(room_id, user_id, &kind)?
|
||||
.map(|(_, v)| {
|
||||
serde_json::from_slice(&v).map_err(|_| Error::bad_database("could not deserialize"))
|
||||
})
|
||||
.transpose()
|
||||
|
@ -98,8 +98,7 @@ impl AccountData {
|
|||
|
||||
for r in self
|
||||
.roomuserdataid_accountdata
|
||||
.range(&*first_possible..)
|
||||
.filter_map(|r| r.ok())
|
||||
.iter_from(&first_possible, false)
|
||||
.take_while(move |(k, _)| k.starts_with(&prefix))
|
||||
.map(|(k, v)| {
|
||||
Ok::<_, Error>((
|
||||
|
@ -128,7 +127,7 @@ impl AccountData {
|
|||
room_id: Option<&RoomId>,
|
||||
user_id: &UserId,
|
||||
kind: &EventType,
|
||||
) -> Option<Result<(IVec, IVec)>> {
|
||||
) -> Result<Option<(Box<[u8]>, Box<[u8]>)>> {
|
||||
let mut prefix = room_id
|
||||
.map(|r| r.to_string())
|
||||
.unwrap_or_default()
|
||||
|
@ -137,23 +136,21 @@ impl AccountData {
|
|||
prefix.push(0xff);
|
||||
prefix.extend_from_slice(&user_id.as_bytes());
|
||||
prefix.push(0xff);
|
||||
|
||||
let mut last_possible_key = prefix.clone();
|
||||
last_possible_key.extend_from_slice(&u64::MAX.to_be_bytes());
|
||||
|
||||
let kind = kind.clone();
|
||||
|
||||
self.roomuserdataid_accountdata
|
||||
.scan_prefix(prefix)
|
||||
.rev()
|
||||
.find(move |r| {
|
||||
r.as_ref()
|
||||
.map(|(k, _)| {
|
||||
k.rsplit(|&b| b == 0xff)
|
||||
.next()
|
||||
.map(|current_event_type| {
|
||||
current_event_type == kind.as_ref().as_bytes()
|
||||
})
|
||||
.unwrap_or(false)
|
||||
})
|
||||
Ok(self
|
||||
.roomuserdataid_accountdata
|
||||
.iter_from(&last_possible_key, true)
|
||||
.take_while(move |(k, _)| k.starts_with(&prefix))
|
||||
.find(move |(k, _)| {
|
||||
k.rsplit(|&b| b == 0xff)
|
||||
.next()
|
||||
.map(|current_event_type| current_event_type == kind.as_ref().as_bytes())
|
||||
.unwrap_or(false)
|
||||
})
|
||||
.map(|r| Ok(r?))
|
||||
}))
|
||||
}
|
||||
}
|
||||
|
|
|
@ -1,6 +1,9 @@
|
|||
use std::convert::{TryFrom, TryInto};
|
||||
use std::{
|
||||
convert::{TryFrom, TryInto},
|
||||
sync::Arc,
|
||||
};
|
||||
|
||||
use crate::pdu::PduBuilder;
|
||||
use crate::{pdu::PduBuilder, Database};
|
||||
use log::warn;
|
||||
use rocket::futures::{channel::mpsc, stream::StreamExt};
|
||||
use ruma::{
|
||||
|
@ -22,7 +25,7 @@ pub struct Admin {
|
|||
impl Admin {
|
||||
pub fn start_handler(
|
||||
&self,
|
||||
db: super::Database,
|
||||
db: Arc<Database>,
|
||||
mut receiver: mpsc::UnboundedReceiver<AdminCommand>,
|
||||
) {
|
||||
tokio::spawn(async move {
|
||||
|
@ -73,14 +76,17 @@ impl Admin {
|
|||
db.appservice.register_appservice(yaml).unwrap(); // TODO handle error
|
||||
}
|
||||
AdminCommand::ListAppservices => {
|
||||
let appservices = db.appservice.iter_ids().collect::<Vec<_>>();
|
||||
let count = appservices.len();
|
||||
let output = format!(
|
||||
"Appservices ({}): {}",
|
||||
count,
|
||||
appservices.into_iter().filter_map(|r| r.ok()).collect::<Vec<_>>().join(", ")
|
||||
);
|
||||
send_message(message::MessageEventContent::text_plain(output));
|
||||
if let Ok(appservices) = db.appservice.iter_ids().map(|ids| ids.collect::<Vec<_>>()) {
|
||||
let count = appservices.len();
|
||||
let output = format!(
|
||||
"Appservices ({}): {}",
|
||||
count,
|
||||
appservices.into_iter().filter_map(|r| r.ok()).collect::<Vec<_>>().join(", ")
|
||||
);
|
||||
send_message(message::MessageEventContent::text_plain(output));
|
||||
} else {
|
||||
send_message(message::MessageEventContent::text_plain("Failed to get appservices."));
|
||||
}
|
||||
}
|
||||
AdminCommand::SendMessage(message) => {
|
||||
send_message(message);
|
||||
|
@ -93,6 +99,6 @@ impl Admin {
|
|||
}
|
||||
|
||||
pub fn send(&self, command: AdminCommand) {
|
||||
self.sender.unbounded_send(command).unwrap()
|
||||
self.sender.unbounded_send(command).unwrap();
|
||||
}
|
||||
}
|
||||
|
|
|
@ -4,18 +4,21 @@ use std::{
|
|||
sync::{Arc, RwLock},
|
||||
};
|
||||
|
||||
#[derive(Clone)]
|
||||
use super::abstraction::Tree;
|
||||
|
||||
pub struct Appservice {
|
||||
pub(super) cached_registrations: Arc<RwLock<HashMap<String, serde_yaml::Value>>>,
|
||||
pub(super) id_appserviceregistrations: sled::Tree,
|
||||
pub(super) id_appserviceregistrations: Arc<dyn Tree>,
|
||||
}
|
||||
|
||||
impl Appservice {
|
||||
pub fn register_appservice(&self, yaml: serde_yaml::Value) -> Result<()> {
|
||||
// TODO: Rumaify
|
||||
let id = yaml.get("id").unwrap().as_str().unwrap();
|
||||
self.id_appserviceregistrations
|
||||
.insert(id, serde_yaml::to_string(&yaml).unwrap().as_bytes())?;
|
||||
self.id_appserviceregistrations.insert(
|
||||
id.as_bytes(),
|
||||
serde_yaml::to_string(&yaml).unwrap().as_bytes(),
|
||||
)?;
|
||||
self.cached_registrations
|
||||
.write()
|
||||
.unwrap()
|
||||
|
@ -33,7 +36,7 @@ impl Appservice {
|
|||
|| {
|
||||
Ok(self
|
||||
.id_appserviceregistrations
|
||||
.get(id)?
|
||||
.get(id.as_bytes())?
|
||||
.map(|bytes| {
|
||||
Ok::<_, Error>(serde_yaml::from_slice(&bytes).map_err(|_| {
|
||||
Error::bad_database(
|
||||
|
@ -47,21 +50,25 @@ impl Appservice {
|
|||
)
|
||||
}
|
||||
|
||||
pub fn iter_ids(&self) -> impl Iterator<Item = Result<String>> {
|
||||
self.id_appserviceregistrations.iter().keys().map(|id| {
|
||||
Ok(utils::string_from_bytes(&id?).map_err(|_| {
|
||||
pub fn iter_ids<'a>(
|
||||
&'a self,
|
||||
) -> Result<impl Iterator<Item = Result<String>> + Send + Sync + 'a> {
|
||||
Ok(self.id_appserviceregistrations.iter().map(|(id, _)| {
|
||||
Ok(utils::string_from_bytes(&id).map_err(|_| {
|
||||
Error::bad_database("Invalid id bytes in id_appserviceregistrations.")
|
||||
})?)
|
||||
})
|
||||
}))
|
||||
}
|
||||
|
||||
pub fn iter_all(&self) -> impl Iterator<Item = Result<(String, serde_yaml::Value)>> + '_ {
|
||||
self.iter_ids().filter_map(|id| id.ok()).map(move |id| {
|
||||
pub fn iter_all(
|
||||
&self,
|
||||
) -> Result<impl Iterator<Item = Result<(String, serde_yaml::Value)>> + '_ + Send + Sync> {
|
||||
Ok(self.iter_ids()?.filter_map(|id| id.ok()).map(move |id| {
|
||||
Ok((
|
||||
id.clone(),
|
||||
self.get_registration(&id)?
|
||||
.expect("iter_ids only returns appservices that exist"),
|
||||
))
|
||||
})
|
||||
}))
|
||||
}
|
||||
}
|
||||
|
|
|
@ -13,22 +13,23 @@ use std::{
|
|||
use tokio::sync::Semaphore;
|
||||
use trust_dns_resolver::TokioAsyncResolver;
|
||||
|
||||
pub const COUNTER: &str = "c";
|
||||
use super::abstraction::Tree;
|
||||
|
||||
pub const COUNTER: &[u8] = b"c";
|
||||
|
||||
type WellKnownMap = HashMap<Box<ServerName>, (String, String)>;
|
||||
type TlsNameMap = HashMap<String, webpki::DNSName>;
|
||||
type RateLimitState = (Instant, u32); // Time if last failed try, number of failed tries
|
||||
#[derive(Clone)]
|
||||
pub struct Globals {
|
||||
pub actual_destination_cache: Arc<RwLock<WellKnownMap>>, // actual_destination, host
|
||||
pub tls_name_override: Arc<RwLock<TlsNameMap>>,
|
||||
pub(super) globals: sled::Tree,
|
||||
pub(super) globals: Arc<dyn Tree>,
|
||||
config: Config,
|
||||
keypair: Arc<ruma::signatures::Ed25519KeyPair>,
|
||||
reqwest_client: reqwest::Client,
|
||||
dns_resolver: TokioAsyncResolver,
|
||||
jwt_decoding_key: Option<jsonwebtoken::DecodingKey<'static>>,
|
||||
pub(super) server_signingkeys: sled::Tree,
|
||||
pub(super) server_signingkeys: Arc<dyn Tree>,
|
||||
pub bad_event_ratelimiter: Arc<RwLock<BTreeMap<EventId, RateLimitState>>>,
|
||||
pub bad_signature_ratelimiter: Arc<RwLock<BTreeMap<Vec<String>, RateLimitState>>>,
|
||||
pub servername_ratelimiter: Arc<RwLock<BTreeMap<Box<ServerName>, Arc<Semaphore>>>>,
|
||||
|
@ -69,15 +70,20 @@ impl ServerCertVerifier for MatrixServerVerifier {
|
|||
|
||||
impl Globals {
|
||||
pub fn load(
|
||||
globals: sled::Tree,
|
||||
server_signingkeys: sled::Tree,
|
||||
globals: Arc<dyn Tree>,
|
||||
server_signingkeys: Arc<dyn Tree>,
|
||||
config: Config,
|
||||
) -> Result<Self> {
|
||||
let bytes = &*globals
|
||||
.update_and_fetch("keypair", utils::generate_keypair)?
|
||||
.expect("utils::generate_keypair always returns Some");
|
||||
let keypair_bytes = globals.get(b"keypair")?.map_or_else(
|
||||
|| {
|
||||
let keypair = utils::generate_keypair();
|
||||
globals.insert(b"keypair", &keypair)?;
|
||||
Ok::<_, Error>(keypair)
|
||||
},
|
||||
|s| Ok(s.to_vec()),
|
||||
)?;
|
||||
|
||||
let mut parts = bytes.splitn(2, |&b| b == 0xff);
|
||||
let mut parts = keypair_bytes.splitn(2, |&b| b == 0xff);
|
||||
|
||||
let keypair = utils::string_from_bytes(
|
||||
// 1. version
|
||||
|
@ -102,7 +108,7 @@ impl Globals {
|
|||
Ok(k) => k,
|
||||
Err(e) => {
|
||||
error!("Keypair invalid. Deleting...");
|
||||
globals.remove("keypair")?;
|
||||
globals.remove(b"keypair")?;
|
||||
return Err(e);
|
||||
}
|
||||
};
|
||||
|
@ -159,13 +165,8 @@ impl Globals {
|
|||
}
|
||||
|
||||
pub fn next_count(&self) -> Result<u64> {
|
||||
Ok(utils::u64_from_bytes(
|
||||
&self
|
||||
.globals
|
||||
.update_and_fetch(COUNTER, utils::increment)?
|
||||
.expect("utils::increment will always put in a value"),
|
||||
)
|
||||
.map_err(|_| Error::bad_database("Count has invalid bytes."))?)
|
||||
Ok(utils::u64_from_bytes(&self.globals.increment(COUNTER)?)
|
||||
.map_err(|_| Error::bad_database("Count has invalid bytes."))?)
|
||||
}
|
||||
|
||||
pub fn current_count(&self) -> Result<u64> {
|
||||
|
@ -211,21 +212,30 @@ impl Globals {
|
|||
/// Remove the outdated keys and insert the new ones.
|
||||
///
|
||||
/// This doesn't actually check that the keys provided are newer than the old set.
|
||||
pub fn add_signing_key(&self, origin: &ServerName, new_keys: &ServerSigningKeys) -> Result<()> {
|
||||
self.server_signingkeys
|
||||
.update_and_fetch(origin.as_bytes(), |signingkeys| {
|
||||
let mut keys = signingkeys
|
||||
.and_then(|keys| serde_json::from_slice(keys).ok())
|
||||
.unwrap_or_else(|| {
|
||||
// Just insert "now", it doesn't matter
|
||||
ServerSigningKeys::new(origin.to_owned(), MilliSecondsSinceUnixEpoch::now())
|
||||
});
|
||||
keys.verify_keys
|
||||
.extend(new_keys.verify_keys.clone().into_iter());
|
||||
keys.old_verify_keys
|
||||
.extend(new_keys.old_verify_keys.clone().into_iter());
|
||||
Some(serde_json::to_vec(&keys).expect("serversigningkeys can be serialized"))
|
||||
})?;
|
||||
pub fn add_signing_key(&self, origin: &ServerName, new_keys: ServerSigningKeys) -> Result<()> {
|
||||
// Not atomic, but this is not critical
|
||||
let signingkeys = self.server_signingkeys.get(origin.as_bytes())?;
|
||||
|
||||
let mut keys = signingkeys
|
||||
.and_then(|keys| serde_json::from_slice(&keys).ok())
|
||||
.unwrap_or_else(|| {
|
||||
// Just insert "now", it doesn't matter
|
||||
ServerSigningKeys::new(origin.to_owned(), MilliSecondsSinceUnixEpoch::now())
|
||||
});
|
||||
|
||||
let ServerSigningKeys {
|
||||
verify_keys,
|
||||
old_verify_keys,
|
||||
..
|
||||
} = new_keys;
|
||||
|
||||
keys.verify_keys.extend(verify_keys.into_iter());
|
||||
keys.old_verify_keys.extend(old_verify_keys.into_iter());
|
||||
|
||||
self.server_signingkeys.insert(
|
||||
origin.as_bytes(),
|
||||
&serde_json::to_vec(&keys).expect("serversigningkeys can be serialized"),
|
||||
)?;
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
@ -254,14 +264,15 @@ impl Globals {
|
|||
}
|
||||
|
||||
pub fn database_version(&self) -> Result<u64> {
|
||||
self.globals.get("version")?.map_or(Ok(0), |version| {
|
||||
self.globals.get(b"version")?.map_or(Ok(0), |version| {
|
||||
utils::u64_from_bytes(&version)
|
||||
.map_err(|_| Error::bad_database("Database version id is invalid."))
|
||||
})
|
||||
}
|
||||
|
||||
pub fn bump_database_version(&self, new_version: u64) -> Result<()> {
|
||||
self.globals.insert("version", &new_version.to_be_bytes())?;
|
||||
self.globals
|
||||
.insert(b"version", &new_version.to_be_bytes())?;
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
|
|
@ -6,13 +6,14 @@ use ruma::{
|
|||
},
|
||||
RoomId, UserId,
|
||||
};
|
||||
use std::{collections::BTreeMap, convert::TryFrom};
|
||||
use std::{collections::BTreeMap, convert::TryFrom, sync::Arc};
|
||||
|
||||
use super::abstraction::Tree;
|
||||
|
||||
#[derive(Clone)]
|
||||
pub struct KeyBackups {
|
||||
pub(super) backupid_algorithm: sled::Tree, // BackupId = UserId + Version(Count)
|
||||
pub(super) backupid_etag: sled::Tree, // BackupId = UserId + Version(Count)
|
||||
pub(super) backupkeyid_backup: sled::Tree, // BackupKeyId = UserId + Version + RoomId + SessionId
|
||||
pub(super) backupid_algorithm: Arc<dyn Tree>, // BackupId = UserId + Version(Count)
|
||||
pub(super) backupid_etag: Arc<dyn Tree>, // BackupId = UserId + Version(Count)
|
||||
pub(super) backupkeyid_backup: Arc<dyn Tree>, // BackupKeyId = UserId + Version + RoomId + SessionId
|
||||
}
|
||||
|
||||
impl KeyBackups {
|
||||
|
@ -30,8 +31,7 @@ impl KeyBackups {
|
|||
|
||||
self.backupid_algorithm.insert(
|
||||
&key,
|
||||
&*serde_json::to_string(backup_metadata)
|
||||
.expect("BackupAlgorithm::to_string always works"),
|
||||
&serde_json::to_vec(backup_metadata).expect("BackupAlgorithm::to_vec always works"),
|
||||
)?;
|
||||
self.backupid_etag
|
||||
.insert(&key, &globals.next_count()?.to_be_bytes())?;
|
||||
|
@ -48,13 +48,8 @@ impl KeyBackups {
|
|||
|
||||
key.push(0xff);
|
||||
|
||||
for outdated_key in self
|
||||
.backupkeyid_backup
|
||||
.scan_prefix(&key)
|
||||
.keys()
|
||||
.filter_map(|r| r.ok())
|
||||
{
|
||||
self.backupkeyid_backup.remove(outdated_key)?;
|
||||
for (outdated_key, _) in self.backupkeyid_backup.scan_prefix(key) {
|
||||
self.backupkeyid_backup.remove(&outdated_key)?;
|
||||
}
|
||||
|
||||
Ok(())
|
||||
|
@ -80,8 +75,9 @@ impl KeyBackups {
|
|||
|
||||
self.backupid_algorithm.insert(
|
||||
&key,
|
||||
&*serde_json::to_string(backup_metadata)
|
||||
.expect("BackupAlgorithm::to_string always works"),
|
||||
&serde_json::to_string(backup_metadata)
|
||||
.expect("BackupAlgorithm::to_string always works")
|
||||
.as_bytes(),
|
||||
)?;
|
||||
self.backupid_etag
|
||||
.insert(&key, &globals.next_count()?.to_be_bytes())?;
|
||||
|
@ -91,11 +87,14 @@ impl KeyBackups {
|
|||
pub fn get_latest_backup(&self, user_id: &UserId) -> Result<Option<(String, BackupAlgorithm)>> {
|
||||
let mut prefix = user_id.as_bytes().to_vec();
|
||||
prefix.push(0xff);
|
||||
let mut last_possible_key = prefix.clone();
|
||||
last_possible_key.extend_from_slice(&u64::MAX.to_be_bytes());
|
||||
|
||||
self.backupid_algorithm
|
||||
.scan_prefix(&prefix)
|
||||
.last()
|
||||
.map_or(Ok(None), |r| {
|
||||
let (key, value) = r?;
|
||||
.iter_from(&last_possible_key, true)
|
||||
.take_while(move |(k, _)| k.starts_with(&prefix))
|
||||
.next()
|
||||
.map_or(Ok(None), |(key, value)| {
|
||||
let version = utils::string_from_bytes(
|
||||
key.rsplit(|&b| b == 0xff)
|
||||
.next()
|
||||
|
@ -117,10 +116,13 @@ impl KeyBackups {
|
|||
key.push(0xff);
|
||||
key.extend_from_slice(version.as_bytes());
|
||||
|
||||
self.backupid_algorithm.get(key)?.map_or(Ok(None), |bytes| {
|
||||
Ok(serde_json::from_slice(&bytes)
|
||||
.map_err(|_| Error::bad_database("Algorithm in backupid_algorithm is invalid."))?)
|
||||
})
|
||||
self.backupid_algorithm
|
||||
.get(&key)?
|
||||
.map_or(Ok(None), |bytes| {
|
||||
Ok(serde_json::from_slice(&bytes).map_err(|_| {
|
||||
Error::bad_database("Algorithm in backupid_algorithm is invalid.")
|
||||
})?)
|
||||
})
|
||||
}
|
||||
|
||||
pub fn add_key(
|
||||
|
@ -153,7 +155,7 @@ impl KeyBackups {
|
|||
|
||||
self.backupkeyid_backup.insert(
|
||||
&key,
|
||||
&*serde_json::to_string(&key_data).expect("KeyBackupData::to_string always works"),
|
||||
&serde_json::to_vec(&key_data).expect("KeyBackupData::to_vec always works"),
|
||||
)?;
|
||||
|
||||
Ok(())
|
||||
|
@ -164,7 +166,7 @@ impl KeyBackups {
|
|||
prefix.push(0xff);
|
||||
prefix.extend_from_slice(version.as_bytes());
|
||||
|
||||
Ok(self.backupkeyid_backup.scan_prefix(&prefix).count())
|
||||
Ok(self.backupkeyid_backup.scan_prefix(prefix).count())
|
||||
}
|
||||
|
||||
pub fn get_etag(&self, user_id: &UserId, version: &str) -> Result<String> {
|
||||
|
@ -194,33 +196,37 @@ impl KeyBackups {
|
|||
|
||||
let mut rooms = BTreeMap::<RoomId, RoomKeyBackup>::new();
|
||||
|
||||
for result in self.backupkeyid_backup.scan_prefix(&prefix).map(|r| {
|
||||
let (key, value) = r?;
|
||||
let mut parts = key.rsplit(|&b| b == 0xff);
|
||||
for result in self
|
||||
.backupkeyid_backup
|
||||
.scan_prefix(prefix)
|
||||
.map(|(key, value)| {
|
||||
let mut parts = key.rsplit(|&b| b == 0xff);
|
||||
|
||||
let session_id = utils::string_from_bytes(
|
||||
&parts
|
||||
.next()
|
||||
.ok_or_else(|| Error::bad_database("backupkeyid_backup key is invalid."))?,
|
||||
)
|
||||
.map_err(|_| Error::bad_database("backupkeyid_backup session_id is invalid."))?;
|
||||
let session_id =
|
||||
utils::string_from_bytes(&parts.next().ok_or_else(|| {
|
||||
Error::bad_database("backupkeyid_backup key is invalid.")
|
||||
})?)
|
||||
.map_err(|_| {
|
||||
Error::bad_database("backupkeyid_backup session_id is invalid.")
|
||||
})?;
|
||||
|
||||
let room_id = RoomId::try_from(
|
||||
utils::string_from_bytes(
|
||||
&parts
|
||||
.next()
|
||||
.ok_or_else(|| Error::bad_database("backupkeyid_backup key is invalid."))?,
|
||||
let room_id = RoomId::try_from(
|
||||
utils::string_from_bytes(&parts.next().ok_or_else(|| {
|
||||
Error::bad_database("backupkeyid_backup key is invalid.")
|
||||
})?)
|
||||
.map_err(|_| Error::bad_database("backupkeyid_backup room_id is invalid."))?,
|
||||
)
|
||||
.map_err(|_| Error::bad_database("backupkeyid_backup room_id is invalid."))?,
|
||||
)
|
||||
.map_err(|_| Error::bad_database("backupkeyid_backup room_id is invalid room id."))?;
|
||||
.map_err(|_| {
|
||||
Error::bad_database("backupkeyid_backup room_id is invalid room id.")
|
||||
})?;
|
||||
|
||||
let key_data = serde_json::from_slice(&value).map_err(|_| {
|
||||
Error::bad_database("KeyBackupData in backupkeyid_backup is invalid.")
|
||||
})?;
|
||||
let key_data = serde_json::from_slice(&value).map_err(|_| {
|
||||
Error::bad_database("KeyBackupData in backupkeyid_backup is invalid.")
|
||||
})?;
|
||||
|
||||
Ok::<_, Error>((room_id, session_id, key_data))
|
||||
}) {
|
||||
Ok::<_, Error>((room_id, session_id, key_data))
|
||||
})
|
||||
{
|
||||
let (room_id, session_id, key_data) = result?;
|
||||
rooms
|
||||
.entry(room_id)
|
||||
|
@ -239,7 +245,7 @@ impl KeyBackups {
|
|||
user_id: &UserId,
|
||||
version: &str,
|
||||
room_id: &RoomId,
|
||||
) -> BTreeMap<String, KeyBackupData> {
|
||||
) -> Result<BTreeMap<String, KeyBackupData>> {
|
||||
let mut prefix = user_id.as_bytes().to_vec();
|
||||
prefix.push(0xff);
|
||||
prefix.extend_from_slice(version.as_bytes());
|
||||
|
@ -247,10 +253,10 @@ impl KeyBackups {
|
|||
prefix.extend_from_slice(room_id.as_bytes());
|
||||
prefix.push(0xff);
|
||||
|
||||
self.backupkeyid_backup
|
||||
.scan_prefix(&prefix)
|
||||
.map(|r| {
|
||||
let (key, value) = r?;
|
||||
Ok(self
|
||||
.backupkeyid_backup
|
||||
.scan_prefix(prefix)
|
||||
.map(|(key, value)| {
|
||||
let mut parts = key.rsplit(|&b| b == 0xff);
|
||||
|
||||
let session_id =
|
||||
|
@ -268,7 +274,7 @@ impl KeyBackups {
|
|||
Ok::<_, Error>((session_id, key_data))
|
||||
})
|
||||
.filter_map(|r| r.ok())
|
||||
.collect()
|
||||
.collect())
|
||||
}
|
||||
|
||||
pub fn get_session(
|
||||
|
@ -302,13 +308,8 @@ impl KeyBackups {
|
|||
key.extend_from_slice(&version.as_bytes());
|
||||
key.push(0xff);
|
||||
|
||||
for outdated_key in self
|
||||
.backupkeyid_backup
|
||||
.scan_prefix(&key)
|
||||
.keys()
|
||||
.filter_map(|r| r.ok())
|
||||
{
|
||||
self.backupkeyid_backup.remove(outdated_key)?;
|
||||
for (outdated_key, _) in self.backupkeyid_backup.scan_prefix(key) {
|
||||
self.backupkeyid_backup.remove(&outdated_key)?;
|
||||
}
|
||||
|
||||
Ok(())
|
||||
|
@ -327,13 +328,8 @@ impl KeyBackups {
|
|||
key.extend_from_slice(&room_id.as_bytes());
|
||||
key.push(0xff);
|
||||
|
||||
for outdated_key in self
|
||||
.backupkeyid_backup
|
||||
.scan_prefix(&key)
|
||||
.keys()
|
||||
.filter_map(|r| r.ok())
|
||||
{
|
||||
self.backupkeyid_backup.remove(outdated_key)?;
|
||||
for (outdated_key, _) in self.backupkeyid_backup.scan_prefix(key) {
|
||||
self.backupkeyid_backup.remove(&outdated_key)?;
|
||||
}
|
||||
|
||||
Ok(())
|
||||
|
@ -354,13 +350,8 @@ impl KeyBackups {
|
|||
key.push(0xff);
|
||||
key.extend_from_slice(&session_id.as_bytes());
|
||||
|
||||
for outdated_key in self
|
||||
.backupkeyid_backup
|
||||
.scan_prefix(&key)
|
||||
.keys()
|
||||
.filter_map(|r| r.ok())
|
||||
{
|
||||
self.backupkeyid_backup.remove(outdated_key)?;
|
||||
for (outdated_key, _) in self.backupkeyid_backup.scan_prefix(key) {
|
||||
self.backupkeyid_backup.remove(&outdated_key)?;
|
||||
}
|
||||
|
||||
Ok(())
|
||||
|
|
|
@ -1,7 +1,9 @@
|
|||
use image::{imageops::FilterType, GenericImageView};
|
||||
|
||||
use crate::{utils, Error, Result};
|
||||
use std::mem;
|
||||
use std::{mem, sync::Arc};
|
||||
|
||||
use super::abstraction::Tree;
|
||||
|
||||
pub struct FileMeta {
|
||||
pub content_disposition: Option<String>,
|
||||
|
@ -9,9 +11,8 @@ pub struct FileMeta {
|
|||
pub file: Vec<u8>,
|
||||
}
|
||||
|
||||
#[derive(Clone)]
|
||||
pub struct Media {
|
||||
pub(super) mediaid_file: sled::Tree, // MediaId = MXC + WidthHeight + ContentDisposition + ContentType
|
||||
pub(super) mediaid_file: Arc<dyn Tree>, // MediaId = MXC + WidthHeight + ContentDisposition + ContentType
|
||||
}
|
||||
|
||||
impl Media {
|
||||
|
@ -42,7 +43,7 @@ impl Media {
|
|||
.unwrap_or_default(),
|
||||
);
|
||||
|
||||
self.mediaid_file.insert(key, file)?;
|
||||
self.mediaid_file.insert(&key, file)?;
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
@ -76,7 +77,7 @@ impl Media {
|
|||
.unwrap_or_default(),
|
||||
);
|
||||
|
||||
self.mediaid_file.insert(key, file)?;
|
||||
self.mediaid_file.insert(&key, file)?;
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
@ -89,8 +90,7 @@ impl Media {
|
|||
prefix.extend_from_slice(&0_u32.to_be_bytes()); // Height = 0 if it's not a thumbnail
|
||||
prefix.push(0xff);
|
||||
|
||||
if let Some(r) = self.mediaid_file.scan_prefix(&prefix).next() {
|
||||
let (key, file) = r?;
|
||||
if let Some((key, file)) = self.mediaid_file.scan_prefix(prefix).next() {
|
||||
let mut parts = key.rsplit(|&b| b == 0xff);
|
||||
|
||||
let content_type = parts
|
||||
|
@ -169,9 +169,8 @@ impl Media {
|
|||
original_prefix.extend_from_slice(&0_u32.to_be_bytes()); // Height = 0 if it's not a thumbnail
|
||||
original_prefix.push(0xff);
|
||||
|
||||
if let Some(r) = self.mediaid_file.scan_prefix(&thumbnail_prefix).next() {
|
||||
if let Some((key, file)) = self.mediaid_file.scan_prefix(thumbnail_prefix).next() {
|
||||
// Using saved thumbnail
|
||||
let (key, file) = r?;
|
||||
let mut parts = key.rsplit(|&b| b == 0xff);
|
||||
|
||||
let content_type = parts
|
||||
|
@ -202,10 +201,8 @@ impl Media {
|
|||
content_type,
|
||||
file: file.to_vec(),
|
||||
}))
|
||||
} else if let Some(r) = self.mediaid_file.scan_prefix(&original_prefix).next() {
|
||||
} else if let Some((key, file)) = self.mediaid_file.scan_prefix(original_prefix).next() {
|
||||
// Generate a thumbnail
|
||||
|
||||
let (key, file) = r?;
|
||||
let mut parts = key.rsplit(|&b| b == 0xff);
|
||||
|
||||
let content_type = parts
|
||||
|
@ -302,7 +299,7 @@ impl Media {
|
|||
widthheight,
|
||||
);
|
||||
|
||||
self.mediaid_file.insert(thumbnail_key, &*thumbnail_bytes)?;
|
||||
self.mediaid_file.insert(&thumbnail_key, &thumbnail_bytes)?;
|
||||
|
||||
Ok(Some(FileMeta {
|
||||
content_disposition,
|
||||
|
|
|
@ -14,23 +14,17 @@ use ruma::{
|
|||
push::{Action, PushConditionRoomCtx, PushFormat, Ruleset, Tweak},
|
||||
uint, UInt, UserId,
|
||||
};
|
||||
use sled::IVec;
|
||||
|
||||
use std::{convert::TryFrom, fmt::Debug, mem};
|
||||
use std::{convert::TryFrom, fmt::Debug, mem, sync::Arc};
|
||||
|
||||
use super::abstraction::Tree;
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct PushData {
|
||||
/// UserId + pushkey -> Pusher
|
||||
pub(super) senderkey_pusher: sled::Tree,
|
||||
pub(super) senderkey_pusher: Arc<dyn Tree>,
|
||||
}
|
||||
|
||||
impl PushData {
|
||||
pub fn new(db: &sled::Db) -> Result<Self> {
|
||||
Ok(Self {
|
||||
senderkey_pusher: db.open_tree("senderkey_pusher")?,
|
||||
})
|
||||
}
|
||||
|
||||
pub fn set_pusher(&self, sender: &UserId, pusher: set_pusher::Pusher) -> Result<()> {
|
||||
let mut key = sender.as_bytes().to_vec();
|
||||
key.push(0xff);
|
||||
|
@ -40,14 +34,14 @@ impl PushData {
|
|||
if pusher.kind.is_none() {
|
||||
return self
|
||||
.senderkey_pusher
|
||||
.remove(key)
|
||||
.remove(&key)
|
||||
.map(|_| ())
|
||||
.map_err(Into::into);
|
||||
}
|
||||
|
||||
self.senderkey_pusher.insert(
|
||||
key,
|
||||
&*serde_json::to_string(&pusher).expect("Pusher is valid JSON string"),
|
||||
&key,
|
||||
&serde_json::to_vec(&pusher).expect("Pusher is valid JSON value"),
|
||||
)?;
|
||||
|
||||
Ok(())
|
||||
|
@ -69,23 +63,21 @@ impl PushData {
|
|||
|
||||
self.senderkey_pusher
|
||||
.scan_prefix(prefix)
|
||||
.values()
|
||||
.map(|push| {
|
||||
let push = push.map_err(|_| Error::bad_database("Invalid push bytes in db."))?;
|
||||
.map(|(_, push)| {
|
||||
Ok(serde_json::from_slice(&*push)
|
||||
.map_err(|_| Error::bad_database("Invalid Pusher in db."))?)
|
||||
})
|
||||
.collect()
|
||||
}
|
||||
|
||||
pub fn get_pusher_senderkeys(&self, sender: &UserId) -> impl Iterator<Item = Result<IVec>> {
|
||||
pub fn get_pusher_senderkeys<'a>(
|
||||
&'a self,
|
||||
sender: &UserId,
|
||||
) -> impl Iterator<Item = Box<[u8]>> + 'a {
|
||||
let mut prefix = sender.as_bytes().to_vec();
|
||||
prefix.push(0xff);
|
||||
|
||||
self.senderkey_pusher
|
||||
.scan_prefix(prefix)
|
||||
.keys()
|
||||
.map(|r| Ok(r?))
|
||||
self.senderkey_pusher.scan_prefix(prefix).map(|(k, _)| k)
|
||||
}
|
||||
}
|
||||
|
||||
|
|
File diff suppressed because it is too large
Load diff
|
@ -1,4 +1,4 @@
|
|||
use crate::{utils, Error, Result};
|
||||
use crate::{database::abstraction::Tree, utils, Error, Result};
|
||||
use ruma::{
|
||||
events::{
|
||||
presence::{PresenceEvent, PresenceEventContent},
|
||||
|
@ -13,17 +13,17 @@ use std::{
|
|||
collections::{HashMap, HashSet},
|
||||
convert::{TryFrom, TryInto},
|
||||
mem,
|
||||
sync::Arc,
|
||||
};
|
||||
|
||||
#[derive(Clone)]
|
||||
pub struct RoomEdus {
|
||||
pub(in super::super) readreceiptid_readreceipt: sled::Tree, // ReadReceiptId = RoomId + Count + UserId
|
||||
pub(in super::super) roomuserid_privateread: sled::Tree, // RoomUserId = Room + User, PrivateRead = Count
|
||||
pub(in super::super) roomuserid_lastprivatereadupdate: sled::Tree, // LastPrivateReadUpdate = Count
|
||||
pub(in super::super) typingid_userid: sled::Tree, // TypingId = RoomId + TimeoutTime + Count
|
||||
pub(in super::super) roomid_lasttypingupdate: sled::Tree, // LastRoomTypingUpdate = Count
|
||||
pub(in super::super) presenceid_presence: sled::Tree, // PresenceId = RoomId + Count + UserId
|
||||
pub(in super::super) userid_lastpresenceupdate: sled::Tree, // LastPresenceUpdate = Count
|
||||
pub(in super::super) readreceiptid_readreceipt: Arc<dyn Tree>, // ReadReceiptId = RoomId + Count + UserId
|
||||
pub(in super::super) roomuserid_privateread: Arc<dyn Tree>, // RoomUserId = Room + User, PrivateRead = Count
|
||||
pub(in super::super) roomuserid_lastprivatereadupdate: Arc<dyn Tree>, // LastPrivateReadUpdate = Count
|
||||
pub(in super::super) typingid_userid: Arc<dyn Tree>, // TypingId = RoomId + TimeoutTime + Count
|
||||
pub(in super::super) roomid_lasttypingupdate: Arc<dyn Tree>, // LastRoomTypingUpdate = Count
|
||||
pub(in super::super) presenceid_presence: Arc<dyn Tree>, // PresenceId = RoomId + Count + UserId
|
||||
pub(in super::super) userid_lastpresenceupdate: Arc<dyn Tree>, // LastPresenceUpdate = Count
|
||||
}
|
||||
|
||||
impl RoomEdus {
|
||||
|
@ -38,15 +38,15 @@ impl RoomEdus {
|
|||
let mut prefix = room_id.as_bytes().to_vec();
|
||||
prefix.push(0xff);
|
||||
|
||||
let mut last_possible_key = prefix.clone();
|
||||
last_possible_key.extend_from_slice(&u64::MAX.to_be_bytes());
|
||||
|
||||
// Remove old entry
|
||||
if let Some(old) = self
|
||||
if let Some((old, _)) = self
|
||||
.readreceiptid_readreceipt
|
||||
.scan_prefix(&prefix)
|
||||
.keys()
|
||||
.rev()
|
||||
.filter_map(|r| r.ok())
|
||||
.take_while(|key| key.starts_with(&prefix))
|
||||
.find(|key| {
|
||||
.iter_from(&last_possible_key, true)
|
||||
.take_while(|(key, _)| key.starts_with(&prefix))
|
||||
.find(|(key, _)| {
|
||||
key.rsplit(|&b| b == 0xff)
|
||||
.next()
|
||||
.expect("rsplit always returns an element")
|
||||
|
@ -54,7 +54,7 @@ impl RoomEdus {
|
|||
})
|
||||
{
|
||||
// This is the old room_latest
|
||||
self.readreceiptid_readreceipt.remove(old)?;
|
||||
self.readreceiptid_readreceipt.remove(&old)?;
|
||||
}
|
||||
|
||||
let mut room_latest_id = prefix;
|
||||
|
@ -63,8 +63,8 @@ impl RoomEdus {
|
|||
room_latest_id.extend_from_slice(&user_id.as_bytes());
|
||||
|
||||
self.readreceiptid_readreceipt.insert(
|
||||
room_latest_id,
|
||||
&*serde_json::to_string(&event).expect("EduEvent::to_string always works"),
|
||||
&room_latest_id,
|
||||
&serde_json::to_vec(&event).expect("EduEvent::to_string always works"),
|
||||
)?;
|
||||
|
||||
Ok(())
|
||||
|
@ -72,13 +72,12 @@ impl RoomEdus {
|
|||
|
||||
/// Returns an iterator over the most recent read_receipts in a room that happened after the event with id `since`.
|
||||
#[tracing::instrument(skip(self))]
|
||||
pub fn readreceipts_since(
|
||||
&self,
|
||||
pub fn readreceipts_since<'a>(
|
||||
&'a self,
|
||||
room_id: &RoomId,
|
||||
since: u64,
|
||||
) -> Result<
|
||||
impl Iterator<Item = Result<(UserId, u64, Raw<ruma::events::AnySyncEphemeralRoomEvent>)>>,
|
||||
> {
|
||||
) -> impl Iterator<Item = Result<(UserId, u64, Raw<ruma::events::AnySyncEphemeralRoomEvent>)>> + 'a
|
||||
{
|
||||
let mut prefix = room_id.as_bytes().to_vec();
|
||||
prefix.push(0xff);
|
||||
let prefix2 = prefix.clone();
|
||||
|
@ -86,10 +85,8 @@ impl RoomEdus {
|
|||
let mut first_possible_edu = prefix.clone();
|
||||
first_possible_edu.extend_from_slice(&(since + 1).to_be_bytes()); // +1 so we don't send the event at since
|
||||
|
||||
Ok(self
|
||||
.readreceiptid_readreceipt
|
||||
.range(&*first_possible_edu..)
|
||||
.filter_map(|r| r.ok())
|
||||
self.readreceiptid_readreceipt
|
||||
.iter_from(&first_possible_edu, false)
|
||||
.take_while(move |(k, _)| k.starts_with(&prefix2))
|
||||
.map(move |(k, v)| {
|
||||
let count =
|
||||
|
@ -115,7 +112,7 @@ impl RoomEdus {
|
|||
serde_json::value::to_raw_value(&json).expect("json is valid raw value"),
|
||||
),
|
||||
))
|
||||
}))
|
||||
})
|
||||
}
|
||||
|
||||
/// Sets a private read marker at `count`.
|
||||
|
@ -146,11 +143,13 @@ impl RoomEdus {
|
|||
key.push(0xff);
|
||||
key.extend_from_slice(&user_id.as_bytes());
|
||||
|
||||
self.roomuserid_privateread.get(key)?.map_or(Ok(None), |v| {
|
||||
Ok(Some(utils::u64_from_bytes(&v).map_err(|_| {
|
||||
Error::bad_database("Invalid private read marker bytes")
|
||||
})?))
|
||||
})
|
||||
self.roomuserid_privateread
|
||||
.get(&key)?
|
||||
.map_or(Ok(None), |v| {
|
||||
Ok(Some(utils::u64_from_bytes(&v).map_err(|_| {
|
||||
Error::bad_database("Invalid private read marker bytes")
|
||||
})?))
|
||||
})
|
||||
}
|
||||
|
||||
/// Returns the count of the last typing update in this room.
|
||||
|
@ -215,11 +214,10 @@ impl RoomEdus {
|
|||
// Maybe there are multiple ones from calling roomtyping_add multiple times
|
||||
for outdated_edu in self
|
||||
.typingid_userid
|
||||
.scan_prefix(&prefix)
|
||||
.filter_map(|r| r.ok())
|
||||
.filter(|(_, v)| v == user_id.as_bytes())
|
||||
.scan_prefix(prefix)
|
||||
.filter(|(_, v)| &**v == user_id.as_bytes())
|
||||
{
|
||||
self.typingid_userid.remove(outdated_edu.0)?;
|
||||
self.typingid_userid.remove(&outdated_edu.0)?;
|
||||
found_outdated = true;
|
||||
}
|
||||
|
||||
|
@ -247,10 +245,8 @@ impl RoomEdus {
|
|||
// Find all outdated edus before inserting a new one
|
||||
for outdated_edu in self
|
||||
.typingid_userid
|
||||
.scan_prefix(&prefix)
|
||||
.keys()
|
||||
.map(|key| {
|
||||
let key = key?;
|
||||
.scan_prefix(prefix)
|
||||
.map(|(key, _)| {
|
||||
Ok::<_, Error>((
|
||||
key.clone(),
|
||||
utils::u64_from_bytes(
|
||||
|
@ -265,7 +261,7 @@ impl RoomEdus {
|
|||
.take_while(|&(_, timestamp)| timestamp < current_timestamp)
|
||||
{
|
||||
// This is an outdated edu (time > timestamp)
|
||||
self.typingid_userid.remove(outdated_edu.0)?;
|
||||
self.typingid_userid.remove(&outdated_edu.0)?;
|
||||
found_outdated = true;
|
||||
}
|
||||
|
||||
|
@ -309,10 +305,9 @@ impl RoomEdus {
|
|||
for user_id in self
|
||||
.typingid_userid
|
||||
.scan_prefix(prefix)
|
||||
.values()
|
||||
.map(|user_id| {
|
||||
.map(|(_, user_id)| {
|
||||
Ok::<_, Error>(
|
||||
UserId::try_from(utils::string_from_bytes(&user_id?).map_err(|_| {
|
||||
UserId::try_from(utils::string_from_bytes(&user_id).map_err(|_| {
|
||||
Error::bad_database("User ID in typingid_userid is invalid unicode.")
|
||||
})?)
|
||||
.map_err(|_| Error::bad_database("User ID in typingid_userid is invalid."))?,
|
||||
|
@ -351,12 +346,12 @@ impl RoomEdus {
|
|||
presence_id.extend_from_slice(&presence.sender.as_bytes());
|
||||
|
||||
self.presenceid_presence.insert(
|
||||
presence_id,
|
||||
&*serde_json::to_string(&presence).expect("PresenceEvent can be serialized"),
|
||||
&presence_id,
|
||||
&serde_json::to_vec(&presence).expect("PresenceEvent can be serialized"),
|
||||
)?;
|
||||
|
||||
self.userid_lastpresenceupdate.insert(
|
||||
&user_id.as_bytes(),
|
||||
user_id.as_bytes(),
|
||||
&utils::millis_since_unix_epoch().to_be_bytes(),
|
||||
)?;
|
||||
|
||||
|
@ -403,7 +398,7 @@ impl RoomEdus {
|
|||
presence_id.extend_from_slice(&user_id.as_bytes());
|
||||
|
||||
self.presenceid_presence
|
||||
.get(presence_id)?
|
||||
.get(&presence_id)?
|
||||
.map(|value| {
|
||||
let mut presence = serde_json::from_slice::<PresenceEvent>(&value)
|
||||
.map_err(|_| Error::bad_database("Invalid presence event in db."))?;
|
||||
|
@ -438,7 +433,6 @@ impl RoomEdus {
|
|||
for (user_id_bytes, last_timestamp) in self
|
||||
.userid_lastpresenceupdate
|
||||
.iter()
|
||||
.filter_map(|r| r.ok())
|
||||
.filter_map(|(k, bytes)| {
|
||||
Some((
|
||||
k,
|
||||
|
@ -468,8 +462,8 @@ impl RoomEdus {
|
|||
presence_id.extend_from_slice(&user_id_bytes);
|
||||
|
||||
self.presenceid_presence.insert(
|
||||
presence_id,
|
||||
&*serde_json::to_string(&PresenceEvent {
|
||||
&presence_id,
|
||||
&serde_json::to_vec(&PresenceEvent {
|
||||
content: PresenceEventContent {
|
||||
avatar_url: None,
|
||||
currently_active: None,
|
||||
|
@ -515,8 +509,7 @@ impl RoomEdus {
|
|||
|
||||
for (key, value) in self
|
||||
.presenceid_presence
|
||||
.range(&*first_possible_edu..)
|
||||
.filter_map(|r| r.ok())
|
||||
.iter_from(&*first_possible_edu, false)
|
||||
.take_while(|(key, _)| key.starts_with(&prefix))
|
||||
{
|
||||
let user_id = UserId::try_from(
|
||||
|
|
|
@ -12,7 +12,10 @@ use crate::{
|
|||
use federation::transactions::send_transaction_message;
|
||||
use log::{error, warn};
|
||||
use ring::digest;
|
||||
use rocket::futures::stream::{FuturesUnordered, StreamExt};
|
||||
use rocket::futures::{
|
||||
channel::mpsc,
|
||||
stream::{FuturesUnordered, StreamExt},
|
||||
};
|
||||
use ruma::{
|
||||
api::{
|
||||
appservice,
|
||||
|
@ -27,9 +30,10 @@ use ruma::{
|
|||
receipt::ReceiptType,
|
||||
MilliSecondsSinceUnixEpoch, ServerName, UInt, UserId,
|
||||
};
|
||||
use sled::IVec;
|
||||
use tokio::{select, sync::Semaphore};
|
||||
|
||||
use super::abstraction::Tree;
|
||||
|
||||
#[derive(Clone, Debug, PartialEq, Eq, Hash)]
|
||||
pub enum OutgoingKind {
|
||||
Appservice(Box<ServerName>),
|
||||
|
@ -70,13 +74,13 @@ pub enum SendingEventType {
|
|||
Edu(Vec<u8>),
|
||||
}
|
||||
|
||||
#[derive(Clone)]
|
||||
pub struct Sending {
|
||||
/// The state for a given state hash.
|
||||
pub(super) servername_educount: sled::Tree, // EduCount: Count of last EDU sync
|
||||
pub(super) servernamepduids: sled::Tree, // ServernamePduId = (+ / $)SenderKey / ServerName / UserId + PduId
|
||||
pub(super) servercurrentevents: sled::Tree, // ServerCurrentEvents = (+ / $)ServerName / UserId + PduId / (*)EduEvent
|
||||
pub(super) servername_educount: Arc<dyn Tree>, // EduCount: Count of last EDU sync
|
||||
pub(super) servernamepduids: Arc<dyn Tree>, // ServernamePduId = (+ / $)SenderKey / ServerName / UserId + PduId
|
||||
pub(super) servercurrentevents: Arc<dyn Tree>, // ServerCurrentEvents = (+ / $)ServerName / UserId + PduId / (*)EduEvent
|
||||
pub(super) maximum_requests: Arc<Semaphore>,
|
||||
pub sender: mpsc::UnboundedSender<Vec<u8>>,
|
||||
}
|
||||
|
||||
enum TransactionStatus {
|
||||
|
@ -86,28 +90,25 @@ enum TransactionStatus {
|
|||
}
|
||||
|
||||
impl Sending {
|
||||
pub fn start_handler(&self, db: &Database) {
|
||||
let servernamepduids = self.servernamepduids.clone();
|
||||
let servercurrentevents = self.servercurrentevents.clone();
|
||||
|
||||
pub fn start_handler(&self, db: Arc<Database>, mut receiver: mpsc::UnboundedReceiver<Vec<u8>>) {
|
||||
let db = db.clone();
|
||||
|
||||
tokio::spawn(async move {
|
||||
let mut futures = FuturesUnordered::new();
|
||||
|
||||
// Retry requests we could not finish yet
|
||||
let mut subscriber = servernamepduids.watch_prefix(b"");
|
||||
let mut current_transaction_status = HashMap::<Vec<u8>, TransactionStatus>::new();
|
||||
|
||||
// Retry requests we could not finish yet
|
||||
let mut initial_transactions = HashMap::<OutgoingKind, Vec<SendingEventType>>::new();
|
||||
for (key, outgoing_kind, event) in servercurrentevents
|
||||
.iter()
|
||||
.filter_map(|r| r.ok())
|
||||
.filter_map(|(key, _)| {
|
||||
Self::parse_servercurrentevent(&key)
|
||||
.ok()
|
||||
.map(|(k, e)| (key, k, e))
|
||||
})
|
||||
for (key, outgoing_kind, event) in
|
||||
db.sending
|
||||
.servercurrentevents
|
||||
.iter()
|
||||
.filter_map(|(key, _)| {
|
||||
Self::parse_servercurrentevent(&key)
|
||||
.ok()
|
||||
.map(|(k, e)| (key, k, e))
|
||||
})
|
||||
{
|
||||
let entry = initial_transactions
|
||||
.entry(outgoing_kind.clone())
|
||||
|
@ -118,7 +119,7 @@ impl Sending {
|
|||
"Dropping some current events: {:?} {:?} {:?}",
|
||||
key, outgoing_kind, event
|
||||
);
|
||||
servercurrentevents.remove(key).unwrap();
|
||||
db.sending.servercurrentevents.remove(&key).unwrap();
|
||||
continue;
|
||||
}
|
||||
|
||||
|
@ -137,20 +138,16 @@ impl Sending {
|
|||
match response {
|
||||
Ok(outgoing_kind) => {
|
||||
let prefix = outgoing_kind.get_prefix();
|
||||
for key in servercurrentevents
|
||||
.scan_prefix(&prefix)
|
||||
.keys()
|
||||
.filter_map(|r| r.ok())
|
||||
for (key, _) in db.sending.servercurrentevents
|
||||
.scan_prefix(prefix.clone())
|
||||
{
|
||||
servercurrentevents.remove(key).unwrap();
|
||||
db.sending.servercurrentevents.remove(&key).unwrap();
|
||||
}
|
||||
|
||||
// Find events that have been added since starting the last request
|
||||
let new_events = servernamepduids
|
||||
.scan_prefix(&prefix)
|
||||
.keys()
|
||||
.filter_map(|r| r.ok())
|
||||
.map(|k| {
|
||||
let new_events = db.sending.servernamepduids
|
||||
.scan_prefix(prefix.clone())
|
||||
.map(|(k, _)| {
|
||||
SendingEventType::Pdu(k[prefix.len()..].to_vec())
|
||||
})
|
||||
.take(30)
|
||||
|
@ -166,8 +163,8 @@ impl Sending {
|
|||
SendingEventType::Pdu(b) |
|
||||
SendingEventType::Edu(b) => {
|
||||
current_key.extend_from_slice(&b);
|
||||
servercurrentevents.insert(¤t_key, &[]).unwrap();
|
||||
servernamepduids.remove(¤t_key).unwrap();
|
||||
db.sending.servercurrentevents.insert(¤t_key, &[]).unwrap();
|
||||
db.sending.servernamepduids.remove(¤t_key).unwrap();
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -195,18 +192,15 @@ impl Sending {
|
|||
}
|
||||
};
|
||||
},
|
||||
Some(event) = &mut subscriber => {
|
||||
// New sled version:
|
||||
//for (_tree, key, value_opt) in &event {
|
||||
// if value_opt.is_none() {
|
||||
// continue;
|
||||
// }
|
||||
|
||||
if let sled::Event::Insert { key, .. } = event {
|
||||
if let Ok((outgoing_kind, event)) = Self::parse_servercurrentevent(&key) {
|
||||
if let Some(events) = Self::select_events(&outgoing_kind, vec![(event, key)], &mut current_transaction_status, &servercurrentevents, &servernamepduids, &db) {
|
||||
futures.push(Self::handle_events(outgoing_kind, events, &db));
|
||||
}
|
||||
Some(key) = receiver.next() => {
|
||||
if let Ok((outgoing_kind, event)) = Self::parse_servercurrentevent(&key) {
|
||||
if let Ok(Some(events)) = Self::select_events(
|
||||
&outgoing_kind,
|
||||
vec![(event, key)],
|
||||
&mut current_transaction_status,
|
||||
&db
|
||||
) {
|
||||
futures.push(Self::handle_events(outgoing_kind, events, &db));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -217,12 +211,10 @@ impl Sending {
|
|||
|
||||
fn select_events(
|
||||
outgoing_kind: &OutgoingKind,
|
||||
new_events: Vec<(SendingEventType, IVec)>, // Events we want to send: event and full key
|
||||
new_events: Vec<(SendingEventType, Vec<u8>)>, // Events we want to send: event and full key
|
||||
current_transaction_status: &mut HashMap<Vec<u8>, TransactionStatus>,
|
||||
servercurrentevents: &sled::Tree,
|
||||
servernamepduids: &sled::Tree,
|
||||
db: &Database,
|
||||
) -> Option<Vec<SendingEventType>> {
|
||||
) -> Result<Option<Vec<SendingEventType>>> {
|
||||
let mut retry = false;
|
||||
let mut allow = true;
|
||||
|
||||
|
@ -252,29 +244,25 @@ impl Sending {
|
|||
.or_insert(TransactionStatus::Running);
|
||||
|
||||
if !allow {
|
||||
return None;
|
||||
return Ok(None);
|
||||
}
|
||||
|
||||
let mut events = Vec::new();
|
||||
|
||||
if retry {
|
||||
// We retry the previous transaction
|
||||
for key in servercurrentevents
|
||||
.scan_prefix(&prefix)
|
||||
.keys()
|
||||
.filter_map(|r| r.ok())
|
||||
{
|
||||
for (key, _) in db.sending.servercurrentevents.scan_prefix(prefix) {
|
||||
if let Ok((_, e)) = Self::parse_servercurrentevent(&key) {
|
||||
events.push(e);
|
||||
}
|
||||
}
|
||||
} else {
|
||||
for (e, full_key) in new_events {
|
||||
servercurrentevents.insert(&full_key, &[]).unwrap();
|
||||
db.sending.servercurrentevents.insert(&full_key, &[])?;
|
||||
|
||||
// If it was a PDU we have to unqueue it
|
||||
// TODO: don't try to unqueue EDUs
|
||||
servernamepduids.remove(&full_key).unwrap();
|
||||
db.sending.servernamepduids.remove(&full_key)?;
|
||||
|
||||
events.push(e);
|
||||
}
|
||||
|
@ -284,13 +272,12 @@ impl Sending {
|
|||
events.extend_from_slice(&select_edus);
|
||||
db.sending
|
||||
.servername_educount
|
||||
.insert(server_name.as_bytes(), &last_count.to_be_bytes())
|
||||
.unwrap();
|
||||
.insert(server_name.as_bytes(), &last_count.to_be_bytes())?;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
Some(events)
|
||||
Ok(Some(events))
|
||||
}
|
||||
|
||||
pub fn select_edus(db: &Database, server: &ServerName) -> Result<(Vec<SendingEventType>, u64)> {
|
||||
|
@ -307,7 +294,7 @@ impl Sending {
|
|||
let mut max_edu_count = since;
|
||||
'outer: for room_id in db.rooms.server_rooms(server) {
|
||||
let room_id = room_id?;
|
||||
for r in db.rooms.edus.readreceipts_since(&room_id, since)? {
|
||||
for r in db.rooms.edus.readreceipts_since(&room_id, since) {
|
||||
let (user_id, count, read_receipt) = r?;
|
||||
|
||||
if count > max_edu_count {
|
||||
|
@ -372,12 +359,13 @@ impl Sending {
|
|||
}
|
||||
|
||||
#[tracing::instrument(skip(self))]
|
||||
pub fn send_push_pdu(&self, pdu_id: &[u8], senderkey: IVec) -> Result<()> {
|
||||
pub fn send_push_pdu(&self, pdu_id: &[u8], senderkey: Box<[u8]>) -> Result<()> {
|
||||
let mut key = b"$".to_vec();
|
||||
key.extend_from_slice(&senderkey);
|
||||
key.push(0xff);
|
||||
key.extend_from_slice(pdu_id);
|
||||
self.servernamepduids.insert(key, b"")?;
|
||||
self.servernamepduids.insert(&key, b"")?;
|
||||
self.sender.unbounded_send(key).unwrap();
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
@ -387,7 +375,8 @@ impl Sending {
|
|||
let mut key = server.as_bytes().to_vec();
|
||||
key.push(0xff);
|
||||
key.extend_from_slice(pdu_id);
|
||||
self.servernamepduids.insert(key, b"")?;
|
||||
self.servernamepduids.insert(&key, b"")?;
|
||||
self.sender.unbounded_send(key).unwrap();
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
@ -398,7 +387,8 @@ impl Sending {
|
|||
key.extend_from_slice(appservice_id.as_bytes());
|
||||
key.push(0xff);
|
||||
key.extend_from_slice(pdu_id);
|
||||
self.servernamepduids.insert(key, b"")?;
|
||||
self.servernamepduids.insert(&key, b"")?;
|
||||
self.sender.unbounded_send(key).unwrap();
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
@ -641,7 +631,7 @@ impl Sending {
|
|||
}
|
||||
}
|
||||
|
||||
fn parse_servercurrentevent(key: &IVec) -> Result<(OutgoingKind, SendingEventType)> {
|
||||
fn parse_servercurrentevent(key: &[u8]) -> Result<(OutgoingKind, SendingEventType)> {
|
||||
// Appservices start with a plus
|
||||
Ok::<_, Error>(if key.starts_with(b"+") {
|
||||
let mut parts = key[1..].splitn(2, |&b| b == 0xff);
|
||||
|
|
|
@ -1,10 +1,12 @@
|
|||
use std::sync::Arc;
|
||||
|
||||
use crate::Result;
|
||||
use ruma::{DeviceId, UserId};
|
||||
use sled::IVec;
|
||||
|
||||
#[derive(Clone)]
|
||||
use super::abstraction::Tree;
|
||||
|
||||
pub struct TransactionIds {
|
||||
pub(super) userdevicetxnid_response: sled::Tree, // Response can be empty (/sendToDevice) or the event id (/send)
|
||||
pub(super) userdevicetxnid_response: Arc<dyn Tree>, // Response can be empty (/sendToDevice) or the event id (/send)
|
||||
}
|
||||
|
||||
impl TransactionIds {
|
||||
|
@ -21,7 +23,7 @@ impl TransactionIds {
|
|||
key.push(0xff);
|
||||
key.extend_from_slice(txn_id.as_bytes());
|
||||
|
||||
self.userdevicetxnid_response.insert(key, data)?;
|
||||
self.userdevicetxnid_response.insert(&key, data)?;
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
@ -31,7 +33,7 @@ impl TransactionIds {
|
|||
user_id: &UserId,
|
||||
device_id: Option<&DeviceId>,
|
||||
txn_id: &str,
|
||||
) -> Result<Option<IVec>> {
|
||||
) -> Result<Option<Vec<u8>>> {
|
||||
let mut key = user_id.as_bytes().to_vec();
|
||||
key.push(0xff);
|
||||
key.extend_from_slice(device_id.map(|d| d.as_bytes()).unwrap_or_default());
|
||||
|
@ -39,6 +41,6 @@ impl TransactionIds {
|
|||
key.extend_from_slice(txn_id.as_bytes());
|
||||
|
||||
// If there's no entry, this is a new transaction
|
||||
Ok(self.userdevicetxnid_response.get(key)?)
|
||||
Ok(self.userdevicetxnid_response.get(&key)?)
|
||||
}
|
||||
}
|
||||
|
|
|
@ -1,3 +1,5 @@
|
|||
use std::sync::Arc;
|
||||
|
||||
use crate::{client_server::SESSION_ID_LENGTH, utils, Error, Result};
|
||||
use ruma::{
|
||||
api::client::{
|
||||
|
@ -8,10 +10,11 @@ use ruma::{
|
|||
DeviceId, UserId,
|
||||
};
|
||||
|
||||
#[derive(Clone)]
|
||||
use super::abstraction::Tree;
|
||||
|
||||
pub struct Uiaa {
|
||||
pub(super) userdevicesessionid_uiaainfo: sled::Tree, // User-interactive authentication
|
||||
pub(super) userdevicesessionid_uiaarequest: sled::Tree, // UiaaRequest = canonical json value
|
||||
pub(super) userdevicesessionid_uiaainfo: Arc<dyn Tree>, // User-interactive authentication
|
||||
pub(super) userdevicesessionid_uiaarequest: Arc<dyn Tree>, // UiaaRequest = canonical json value
|
||||
}
|
||||
|
||||
impl Uiaa {
|
||||
|
@ -185,7 +188,7 @@ impl Uiaa {
|
|||
|
||||
self.userdevicesessionid_uiaarequest.insert(
|
||||
&userdevicesessionid,
|
||||
&*serde_json::to_string(request).expect("json value to string always works"),
|
||||
&serde_json::to_vec(request).expect("json value to vec always works"),
|
||||
)?;
|
||||
|
||||
Ok(())
|
||||
|
@ -233,7 +236,7 @@ impl Uiaa {
|
|||
if let Some(uiaainfo) = uiaainfo {
|
||||
self.userdevicesessionid_uiaainfo.insert(
|
||||
&userdevicesessionid,
|
||||
&*serde_json::to_string(&uiaainfo).expect("UiaaInfo::to_string always works"),
|
||||
&serde_json::to_vec(&uiaainfo).expect("UiaaInfo::to_vec always works"),
|
||||
)?;
|
||||
} else {
|
||||
self.userdevicesessionid_uiaainfo
|
||||
|
|
|
@ -7,40 +7,41 @@ use ruma::{
|
|||
serde::Raw,
|
||||
DeviceId, DeviceKeyAlgorithm, DeviceKeyId, MilliSecondsSinceUnixEpoch, UInt, UserId,
|
||||
};
|
||||
use std::{collections::BTreeMap, convert::TryFrom, mem};
|
||||
use std::{collections::BTreeMap, convert::TryFrom, mem, sync::Arc};
|
||||
|
||||
use super::abstraction::Tree;
|
||||
|
||||
#[derive(Clone)]
|
||||
pub struct Users {
|
||||
pub(super) userid_password: sled::Tree,
|
||||
pub(super) userid_displayname: sled::Tree,
|
||||
pub(super) userid_avatarurl: sled::Tree,
|
||||
pub(super) userdeviceid_token: sled::Tree,
|
||||
pub(super) userdeviceid_metadata: sled::Tree, // This is also used to check if a device exists
|
||||
pub(super) userid_devicelistversion: sled::Tree, // DevicelistVersion = u64
|
||||
pub(super) token_userdeviceid: sled::Tree,
|
||||
pub(super) userid_password: Arc<dyn Tree>,
|
||||
pub(super) userid_displayname: Arc<dyn Tree>,
|
||||
pub(super) userid_avatarurl: Arc<dyn Tree>,
|
||||
pub(super) userdeviceid_token: Arc<dyn Tree>,
|
||||
pub(super) userdeviceid_metadata: Arc<dyn Tree>, // This is also used to check if a device exists
|
||||
pub(super) userid_devicelistversion: Arc<dyn Tree>, // DevicelistVersion = u64
|
||||
pub(super) token_userdeviceid: Arc<dyn Tree>,
|
||||
|
||||
pub(super) onetimekeyid_onetimekeys: sled::Tree, // OneTimeKeyId = UserId + DeviceKeyId
|
||||
pub(super) userid_lastonetimekeyupdate: sled::Tree, // LastOneTimeKeyUpdate = Count
|
||||
pub(super) keychangeid_userid: sled::Tree, // KeyChangeId = UserId/RoomId + Count
|
||||
pub(super) keyid_key: sled::Tree, // KeyId = UserId + KeyId (depends on key type)
|
||||
pub(super) userid_masterkeyid: sled::Tree,
|
||||
pub(super) userid_selfsigningkeyid: sled::Tree,
|
||||
pub(super) userid_usersigningkeyid: sled::Tree,
|
||||
pub(super) onetimekeyid_onetimekeys: Arc<dyn Tree>, // OneTimeKeyId = UserId + DeviceKeyId
|
||||
pub(super) userid_lastonetimekeyupdate: Arc<dyn Tree>, // LastOneTimeKeyUpdate = Count
|
||||
pub(super) keychangeid_userid: Arc<dyn Tree>, // KeyChangeId = UserId/RoomId + Count
|
||||
pub(super) keyid_key: Arc<dyn Tree>, // KeyId = UserId + KeyId (depends on key type)
|
||||
pub(super) userid_masterkeyid: Arc<dyn Tree>,
|
||||
pub(super) userid_selfsigningkeyid: Arc<dyn Tree>,
|
||||
pub(super) userid_usersigningkeyid: Arc<dyn Tree>,
|
||||
|
||||
pub(super) todeviceid_events: sled::Tree, // ToDeviceId = UserId + DeviceId + Count
|
||||
pub(super) todeviceid_events: Arc<dyn Tree>, // ToDeviceId = UserId + DeviceId + Count
|
||||
}
|
||||
|
||||
impl Users {
|
||||
/// Check if a user has an account on this homeserver.
|
||||
pub fn exists(&self, user_id: &UserId) -> Result<bool> {
|
||||
Ok(self.userid_password.contains_key(user_id.to_string())?)
|
||||
Ok(self.userid_password.get(user_id.as_bytes())?.is_some())
|
||||
}
|
||||
|
||||
/// Check if account is deactivated
|
||||
pub fn is_deactivated(&self, user_id: &UserId) -> Result<bool> {
|
||||
Ok(self
|
||||
.userid_password
|
||||
.get(user_id.to_string())?
|
||||
.get(user_id.as_bytes())?
|
||||
.ok_or(Error::BadRequest(
|
||||
ErrorKind::InvalidParam,
|
||||
"User does not exist.",
|
||||
|
@ -55,14 +56,14 @@ impl Users {
|
|||
}
|
||||
|
||||
/// Returns the number of users registered on this server.
|
||||
pub fn count(&self) -> usize {
|
||||
self.userid_password.iter().count()
|
||||
pub fn count(&self) -> Result<usize> {
|
||||
Ok(self.userid_password.iter().count())
|
||||
}
|
||||
|
||||
/// Find out which user an access token belongs to.
|
||||
pub fn find_from_token(&self, token: &str) -> Result<Option<(UserId, String)>> {
|
||||
self.token_userdeviceid
|
||||
.get(token)?
|
||||
.get(token.as_bytes())?
|
||||
.map_or(Ok(None), |bytes| {
|
||||
let mut parts = bytes.split(|&b| b == 0xff);
|
||||
let user_bytes = parts.next().ok_or_else(|| {
|
||||
|
@ -87,10 +88,10 @@ impl Users {
|
|||
}
|
||||
|
||||
/// Returns an iterator over all users on this homeserver.
|
||||
pub fn iter(&self) -> impl Iterator<Item = Result<UserId>> {
|
||||
self.userid_password.iter().keys().map(|bytes| {
|
||||
pub fn iter<'a>(&'a self) -> impl Iterator<Item = Result<UserId>> + 'a {
|
||||
self.userid_password.iter().map(|(bytes, _)| {
|
||||
Ok(
|
||||
UserId::try_from(utils::string_from_bytes(&bytes?).map_err(|_| {
|
||||
UserId::try_from(utils::string_from_bytes(&bytes).map_err(|_| {
|
||||
Error::bad_database("User ID in userid_password is invalid unicode.")
|
||||
})?)
|
||||
.map_err(|_| Error::bad_database("User ID in userid_password is invalid."))?,
|
||||
|
@ -101,7 +102,7 @@ impl Users {
|
|||
/// Returns the password hash for the given user.
|
||||
pub fn password_hash(&self, user_id: &UserId) -> Result<Option<String>> {
|
||||
self.userid_password
|
||||
.get(user_id.to_string())?
|
||||
.get(user_id.as_bytes())?
|
||||
.map_or(Ok(None), |bytes| {
|
||||
Ok(Some(utils::string_from_bytes(&bytes).map_err(|_| {
|
||||
Error::bad_database("Password hash in db is not valid string.")
|
||||
|
@ -113,7 +114,8 @@ impl Users {
|
|||
pub fn set_password(&self, user_id: &UserId, password: Option<&str>) -> Result<()> {
|
||||
if let Some(password) = password {
|
||||
if let Ok(hash) = utils::calculate_hash(&password) {
|
||||
self.userid_password.insert(user_id.to_string(), &*hash)?;
|
||||
self.userid_password
|
||||
.insert(user_id.as_bytes(), hash.as_bytes())?;
|
||||
Ok(())
|
||||
} else {
|
||||
Err(Error::BadRequest(
|
||||
|
@ -122,7 +124,7 @@ impl Users {
|
|||
))
|
||||
}
|
||||
} else {
|
||||
self.userid_password.insert(user_id.to_string(), "")?;
|
||||
self.userid_password.insert(user_id.as_bytes(), b"")?;
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
@ -130,7 +132,7 @@ impl Users {
|
|||
/// Returns the displayname of a user on this homeserver.
|
||||
pub fn displayname(&self, user_id: &UserId) -> Result<Option<String>> {
|
||||
self.userid_displayname
|
||||
.get(user_id.to_string())?
|
||||
.get(user_id.as_bytes())?
|
||||
.map_or(Ok(None), |bytes| {
|
||||
Ok(Some(utils::string_from_bytes(&bytes).map_err(|_| {
|
||||
Error::bad_database("Displayname in db is invalid.")
|
||||
|
@ -142,9 +144,9 @@ impl Users {
|
|||
pub fn set_displayname(&self, user_id: &UserId, displayname: Option<String>) -> Result<()> {
|
||||
if let Some(displayname) = displayname {
|
||||
self.userid_displayname
|
||||
.insert(user_id.to_string(), &*displayname)?;
|
||||
.insert(user_id.as_bytes(), displayname.as_bytes())?;
|
||||
} else {
|
||||
self.userid_displayname.remove(user_id.to_string())?;
|
||||
self.userid_displayname.remove(user_id.as_bytes())?;
|
||||
}
|
||||
|
||||
Ok(())
|
||||
|
@ -153,7 +155,7 @@ impl Users {
|
|||
/// Get a the avatar_url of a user.
|
||||
pub fn avatar_url(&self, user_id: &UserId) -> Result<Option<MxcUri>> {
|
||||
self.userid_avatarurl
|
||||
.get(user_id.to_string())?
|
||||
.get(user_id.as_bytes())?
|
||||
.map(|bytes| {
|
||||
let s = utils::string_from_bytes(&bytes)
|
||||
.map_err(|_| Error::bad_database("Avatar URL in db is invalid."))?;
|
||||
|
@ -166,9 +168,9 @@ impl Users {
|
|||
pub fn set_avatar_url(&self, user_id: &UserId, avatar_url: Option<MxcUri>) -> Result<()> {
|
||||
if let Some(avatar_url) = avatar_url {
|
||||
self.userid_avatarurl
|
||||
.insert(user_id.to_string(), avatar_url.to_string().as_str())?;
|
||||
.insert(user_id.as_bytes(), avatar_url.to_string().as_bytes())?;
|
||||
} else {
|
||||
self.userid_avatarurl.remove(user_id.to_string())?;
|
||||
self.userid_avatarurl.remove(user_id.as_bytes())?;
|
||||
}
|
||||
|
||||
Ok(())
|
||||
|
@ -190,19 +192,17 @@ impl Users {
|
|||
userdeviceid.extend_from_slice(device_id.as_bytes());
|
||||
|
||||
self.userid_devicelistversion
|
||||
.update_and_fetch(&user_id.as_bytes(), utils::increment)?
|
||||
.expect("utils::increment will always put in a value");
|
||||
.increment(user_id.as_bytes())?;
|
||||
|
||||
self.userdeviceid_metadata.insert(
|
||||
userdeviceid,
|
||||
serde_json::to_string(&Device {
|
||||
&userdeviceid,
|
||||
&serde_json::to_vec(&Device {
|
||||
device_id: device_id.into(),
|
||||
display_name: initial_device_display_name,
|
||||
last_seen_ip: None, // TODO
|
||||
last_seen_ts: Some(MilliSecondsSinceUnixEpoch::now()),
|
||||
})
|
||||
.expect("Device::to_string never fails.")
|
||||
.as_bytes(),
|
||||
.expect("Device::to_string never fails."),
|
||||
)?;
|
||||
|
||||
self.set_token(user_id, &device_id, token)?;
|
||||
|
@ -217,7 +217,8 @@ impl Users {
|
|||
userdeviceid.extend_from_slice(device_id.as_bytes());
|
||||
|
||||
// Remove tokens
|
||||
if let Some(old_token) = self.userdeviceid_token.remove(&userdeviceid)? {
|
||||
if let Some(old_token) = self.userdeviceid_token.get(&userdeviceid)? {
|
||||
self.userdeviceid_token.remove(&userdeviceid)?;
|
||||
self.token_userdeviceid.remove(&old_token)?;
|
||||
}
|
||||
|
||||
|
@ -225,15 +226,14 @@ impl Users {
|
|||
let mut prefix = userdeviceid.clone();
|
||||
prefix.push(0xff);
|
||||
|
||||
for key in self.todeviceid_events.scan_prefix(&prefix).keys() {
|
||||
self.todeviceid_events.remove(key?)?;
|
||||
for (key, _) in self.todeviceid_events.scan_prefix(prefix) {
|
||||
self.todeviceid_events.remove(&key)?;
|
||||
}
|
||||
|
||||
// TODO: Remove onetimekeys
|
||||
|
||||
self.userid_devicelistversion
|
||||
.update_and_fetch(&user_id.as_bytes(), utils::increment)?
|
||||
.expect("utils::increment will always put in a value");
|
||||
.increment(user_id.as_bytes())?;
|
||||
|
||||
self.userdeviceid_metadata.remove(&userdeviceid)?;
|
||||
|
||||
|
@ -241,16 +241,18 @@ impl Users {
|
|||
}
|
||||
|
||||
/// Returns an iterator over all device ids of this user.
|
||||
pub fn all_device_ids(&self, user_id: &UserId) -> impl Iterator<Item = Result<Box<DeviceId>>> {
|
||||
pub fn all_device_ids<'a>(
|
||||
&'a self,
|
||||
user_id: &UserId,
|
||||
) -> impl Iterator<Item = Result<Box<DeviceId>>> + 'a {
|
||||
let mut prefix = user_id.as_bytes().to_vec();
|
||||
prefix.push(0xff);
|
||||
// All devices have metadata
|
||||
self.userdeviceid_metadata
|
||||
.scan_prefix(prefix)
|
||||
.keys()
|
||||
.map(|bytes| {
|
||||
.map(|(bytes, _)| {
|
||||
Ok(utils::string_from_bytes(
|
||||
&*bytes?
|
||||
&bytes
|
||||
.rsplit(|&b| b == 0xff)
|
||||
.next()
|
||||
.ok_or_else(|| Error::bad_database("UserDevice ID in db is invalid."))?,
|
||||
|
@ -271,13 +273,15 @@ impl Users {
|
|||
|
||||
// Remove old token
|
||||
if let Some(old_token) = self.userdeviceid_token.get(&userdeviceid)? {
|
||||
self.token_userdeviceid.remove(old_token)?;
|
||||
self.token_userdeviceid.remove(&old_token)?;
|
||||
// It will be removed from userdeviceid_token by the insert later
|
||||
}
|
||||
|
||||
// Assign token to user device combination
|
||||
self.userdeviceid_token.insert(&userdeviceid, &*token)?;
|
||||
self.token_userdeviceid.insert(token, userdeviceid)?;
|
||||
self.userdeviceid_token
|
||||
.insert(&userdeviceid, token.as_bytes())?;
|
||||
self.token_userdeviceid
|
||||
.insert(token.as_bytes(), &userdeviceid)?;
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
@ -309,8 +313,7 @@ impl Users {
|
|||
|
||||
self.onetimekeyid_onetimekeys.insert(
|
||||
&key,
|
||||
&*serde_json::to_string(&one_time_key_value)
|
||||
.expect("OneTimeKey::to_string always works"),
|
||||
&serde_json::to_vec(&one_time_key_value).expect("OneTimeKey::to_vec always works"),
|
||||
)?;
|
||||
|
||||
self.userid_lastonetimekeyupdate
|
||||
|
@ -350,10 +353,9 @@ impl Users {
|
|||
.insert(&user_id.as_bytes(), &globals.next_count()?.to_be_bytes())?;
|
||||
|
||||
self.onetimekeyid_onetimekeys
|
||||
.scan_prefix(&prefix)
|
||||
.scan_prefix(prefix)
|
||||
.next()
|
||||
.map(|r| {
|
||||
let (key, value) = r?;
|
||||
.map(|(key, value)| {
|
||||
self.onetimekeyid_onetimekeys.remove(&key)?;
|
||||
|
||||
Ok((
|
||||
|
@ -383,21 +385,20 @@ impl Users {
|
|||
|
||||
let mut counts = BTreeMap::new();
|
||||
|
||||
for algorithm in self
|
||||
.onetimekeyid_onetimekeys
|
||||
.scan_prefix(&userdeviceid)
|
||||
.keys()
|
||||
.map(|bytes| {
|
||||
Ok::<_, Error>(
|
||||
serde_json::from_slice::<DeviceKeyId>(
|
||||
&*bytes?.rsplit(|&b| b == 0xff).next().ok_or_else(|| {
|
||||
Error::bad_database("OneTimeKey ID in db is invalid.")
|
||||
})?,
|
||||
for algorithm in
|
||||
self.onetimekeyid_onetimekeys
|
||||
.scan_prefix(userdeviceid)
|
||||
.map(|(bytes, _)| {
|
||||
Ok::<_, Error>(
|
||||
serde_json::from_slice::<DeviceKeyId>(
|
||||
&*bytes.rsplit(|&b| b == 0xff).next().ok_or_else(|| {
|
||||
Error::bad_database("OneTimeKey ID in db is invalid.")
|
||||
})?,
|
||||
)
|
||||
.map_err(|_| Error::bad_database("DeviceKeyId in db is invalid."))?
|
||||
.algorithm(),
|
||||
)
|
||||
.map_err(|_| Error::bad_database("DeviceKeyId in db is invalid."))?
|
||||
.algorithm(),
|
||||
)
|
||||
})
|
||||
})
|
||||
{
|
||||
*counts.entry(algorithm?).or_default() += UInt::from(1_u32);
|
||||
}
|
||||
|
@ -419,7 +420,7 @@ impl Users {
|
|||
|
||||
self.keyid_key.insert(
|
||||
&userdeviceid,
|
||||
&*serde_json::to_string(&device_keys).expect("DeviceKeys::to_string always works"),
|
||||
&serde_json::to_vec(&device_keys).expect("DeviceKeys::to_vec always works"),
|
||||
)?;
|
||||
|
||||
self.mark_device_key_update(user_id, rooms, globals)?;
|
||||
|
@ -460,11 +461,11 @@ impl Users {
|
|||
|
||||
self.keyid_key.insert(
|
||||
&master_key_key,
|
||||
&*serde_json::to_string(&master_key).expect("CrossSigningKey::to_string always works"),
|
||||
&serde_json::to_vec(&master_key).expect("CrossSigningKey::to_vec always works"),
|
||||
)?;
|
||||
|
||||
self.userid_masterkeyid
|
||||
.insert(&*user_id.to_string(), master_key_key)?;
|
||||
.insert(user_id.as_bytes(), &master_key_key)?;
|
||||
|
||||
// Self-signing key
|
||||
if let Some(self_signing_key) = self_signing_key {
|
||||
|
@ -486,12 +487,12 @@ impl Users {
|
|||
|
||||
self.keyid_key.insert(
|
||||
&self_signing_key_key,
|
||||
&*serde_json::to_string(&self_signing_key)
|
||||
.expect("CrossSigningKey::to_string always works"),
|
||||
&serde_json::to_vec(&self_signing_key)
|
||||
.expect("CrossSigningKey::to_vec always works"),
|
||||
)?;
|
||||
|
||||
self.userid_selfsigningkeyid
|
||||
.insert(&*user_id.to_string(), self_signing_key_key)?;
|
||||
.insert(user_id.as_bytes(), &self_signing_key_key)?;
|
||||
}
|
||||
|
||||
// User-signing key
|
||||
|
@ -514,12 +515,12 @@ impl Users {
|
|||
|
||||
self.keyid_key.insert(
|
||||
&user_signing_key_key,
|
||||
&*serde_json::to_string(&user_signing_key)
|
||||
.expect("CrossSigningKey::to_string always works"),
|
||||
&serde_json::to_vec(&user_signing_key)
|
||||
.expect("CrossSigningKey::to_vec always works"),
|
||||
)?;
|
||||
|
||||
self.userid_usersigningkeyid
|
||||
.insert(&*user_id.to_string(), user_signing_key_key)?;
|
||||
.insert(user_id.as_bytes(), &user_signing_key_key)?;
|
||||
}
|
||||
|
||||
self.mark_device_key_update(user_id, rooms, globals)?;
|
||||
|
@ -561,8 +562,7 @@ impl Users {
|
|||
|
||||
self.keyid_key.insert(
|
||||
&key,
|
||||
&*serde_json::to_string(&cross_signing_key)
|
||||
.expect("CrossSigningKey::to_string always works"),
|
||||
&serde_json::to_vec(&cross_signing_key).expect("CrossSigningKey::to_vec always works"),
|
||||
)?;
|
||||
|
||||
// TODO: Should we notify about this change?
|
||||
|
@ -572,24 +572,20 @@ impl Users {
|
|||
}
|
||||
|
||||
#[tracing::instrument(skip(self))]
|
||||
pub fn keys_changed(
|
||||
&self,
|
||||
pub fn keys_changed<'a>(
|
||||
&'a self,
|
||||
user_or_room_id: &str,
|
||||
from: u64,
|
||||
to: Option<u64>,
|
||||
) -> impl Iterator<Item = Result<UserId>> {
|
||||
) -> impl Iterator<Item = Result<UserId>> + 'a {
|
||||
let mut prefix = user_or_room_id.as_bytes().to_vec();
|
||||
prefix.push(0xff);
|
||||
|
||||
let mut start = prefix.clone();
|
||||
start.extend_from_slice(&(from + 1).to_be_bytes());
|
||||
|
||||
let mut end = prefix.clone();
|
||||
end.extend_from_slice(&to.unwrap_or(u64::MAX).to_be_bytes());
|
||||
|
||||
self.keychangeid_userid
|
||||
.range(start..end)
|
||||
.filter_map(|r| r.ok())
|
||||
.iter_from(&start, false)
|
||||
.take_while(move |(k, _)| k.starts_with(&prefix))
|
||||
.map(|(_, bytes)| {
|
||||
Ok(
|
||||
|
@ -625,13 +621,13 @@ impl Users {
|
|||
key.push(0xff);
|
||||
key.extend_from_slice(&count);
|
||||
|
||||
self.keychangeid_userid.insert(key, &*user_id.to_string())?;
|
||||
self.keychangeid_userid.insert(&key, user_id.as_bytes())?;
|
||||
}
|
||||
|
||||
let mut key = user_id.as_bytes().to_vec();
|
||||
key.push(0xff);
|
||||
key.extend_from_slice(&count);
|
||||
self.keychangeid_userid.insert(key, &*user_id.to_string())?;
|
||||
self.keychangeid_userid.insert(&key, user_id.as_bytes())?;
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
@ -645,7 +641,7 @@ impl Users {
|
|||
key.push(0xff);
|
||||
key.extend_from_slice(device_id.as_bytes());
|
||||
|
||||
self.keyid_key.get(key)?.map_or(Ok(None), |bytes| {
|
||||
self.keyid_key.get(&key)?.map_or(Ok(None), |bytes| {
|
||||
Ok(Some(serde_json::from_slice(&bytes).map_err(|_| {
|
||||
Error::bad_database("DeviceKeys in db are invalid.")
|
||||
})?))
|
||||
|
@ -658,9 +654,9 @@ impl Users {
|
|||
allowed_signatures: F,
|
||||
) -> Result<Option<CrossSigningKey>> {
|
||||
self.userid_masterkeyid
|
||||
.get(user_id.to_string())?
|
||||
.get(user_id.as_bytes())?
|
||||
.map_or(Ok(None), |key| {
|
||||
self.keyid_key.get(key)?.map_or(Ok(None), |bytes| {
|
||||
self.keyid_key.get(&key)?.map_or(Ok(None), |bytes| {
|
||||
let mut cross_signing_key = serde_json::from_slice::<CrossSigningKey>(&bytes)
|
||||
.map_err(|_| {
|
||||
Error::bad_database("CrossSigningKey in db is invalid.")
|
||||
|
@ -685,9 +681,9 @@ impl Users {
|
|||
allowed_signatures: F,
|
||||
) -> Result<Option<CrossSigningKey>> {
|
||||
self.userid_selfsigningkeyid
|
||||
.get(user_id.to_string())?
|
||||
.get(user_id.as_bytes())?
|
||||
.map_or(Ok(None), |key| {
|
||||
self.keyid_key.get(key)?.map_or(Ok(None), |bytes| {
|
||||
self.keyid_key.get(&key)?.map_or(Ok(None), |bytes| {
|
||||
let mut cross_signing_key = serde_json::from_slice::<CrossSigningKey>(&bytes)
|
||||
.map_err(|_| {
|
||||
Error::bad_database("CrossSigningKey in db is invalid.")
|
||||
|
@ -708,9 +704,9 @@ impl Users {
|
|||
|
||||
pub fn get_user_signing_key(&self, user_id: &UserId) -> Result<Option<CrossSigningKey>> {
|
||||
self.userid_usersigningkeyid
|
||||
.get(user_id.to_string())?
|
||||
.get(user_id.as_bytes())?
|
||||
.map_or(Ok(None), |key| {
|
||||
self.keyid_key.get(key)?.map_or(Ok(None), |bytes| {
|
||||
self.keyid_key.get(&key)?.map_or(Ok(None), |bytes| {
|
||||
Ok(Some(serde_json::from_slice(&bytes).map_err(|_| {
|
||||
Error::bad_database("CrossSigningKey in db is invalid.")
|
||||
})?))
|
||||
|
@ -740,7 +736,7 @@ impl Users {
|
|||
|
||||
self.todeviceid_events.insert(
|
||||
&key,
|
||||
&*serde_json::to_string(&json).expect("Map::to_string always works"),
|
||||
&serde_json::to_vec(&json).expect("Map::to_vec always works"),
|
||||
)?;
|
||||
|
||||
Ok(())
|
||||
|
@ -759,9 +755,9 @@ impl Users {
|
|||
prefix.extend_from_slice(device_id.as_bytes());
|
||||
prefix.push(0xff);
|
||||
|
||||
for value in self.todeviceid_events.scan_prefix(&prefix).values() {
|
||||
for (_, value) in self.todeviceid_events.scan_prefix(prefix) {
|
||||
events.push(
|
||||
serde_json::from_slice(&*value?)
|
||||
serde_json::from_slice(&value)
|
||||
.map_err(|_| Error::bad_database("Event in todeviceid_events is invalid."))?,
|
||||
);
|
||||
}
|
||||
|
@ -786,10 +782,9 @@ impl Users {
|
|||
|
||||
for (key, _) in self
|
||||
.todeviceid_events
|
||||
.range(&*prefix..=&*last)
|
||||
.keys()
|
||||
.map(|key| {
|
||||
let key = key?;
|
||||
.iter_from(&last, true)
|
||||
.take_while(move |(k, _)| k.starts_with(&prefix))
|
||||
.map(|(key, _)| {
|
||||
Ok::<_, Error>((
|
||||
key.clone(),
|
||||
utils::u64_from_bytes(&key[key.len() - mem::size_of::<u64>()..key.len()])
|
||||
|
@ -799,7 +794,7 @@ impl Users {
|
|||
.filter_map(|r| r.ok())
|
||||
.take_while(|&(_, count)| count <= until)
|
||||
{
|
||||
self.todeviceid_events.remove(key)?;
|
||||
self.todeviceid_events.remove(&key)?;
|
||||
}
|
||||
|
||||
Ok(())
|
||||
|
@ -819,14 +814,11 @@ impl Users {
|
|||
assert!(self.userdeviceid_metadata.get(&userdeviceid)?.is_some());
|
||||
|
||||
self.userid_devicelistversion
|
||||
.update_and_fetch(&user_id.as_bytes(), utils::increment)?
|
||||
.expect("utils::increment will always put in a value");
|
||||
.increment(user_id.as_bytes())?;
|
||||
|
||||
self.userdeviceid_metadata.insert(
|
||||
userdeviceid,
|
||||
serde_json::to_string(device)
|
||||
.expect("Device::to_string always works")
|
||||
.as_bytes(),
|
||||
&userdeviceid,
|
||||
&serde_json::to_vec(device).expect("Device::to_string always works"),
|
||||
)?;
|
||||
|
||||
Ok(())
|
||||
|
@ -861,15 +853,17 @@ impl Users {
|
|||
})
|
||||
}
|
||||
|
||||
pub fn all_devices_metadata(&self, user_id: &UserId) -> impl Iterator<Item = Result<Device>> {
|
||||
pub fn all_devices_metadata<'a>(
|
||||
&'a self,
|
||||
user_id: &UserId,
|
||||
) -> impl Iterator<Item = Result<Device>> + 'a {
|
||||
let mut key = user_id.as_bytes().to_vec();
|
||||
key.push(0xff);
|
||||
|
||||
self.userdeviceid_metadata
|
||||
.scan_prefix(key)
|
||||
.values()
|
||||
.map(|bytes| {
|
||||
Ok(serde_json::from_slice::<Device>(&bytes?).map_err(|_| {
|
||||
.map(|(_, bytes)| {
|
||||
Ok(serde_json::from_slice::<Device>(&bytes).map_err(|_| {
|
||||
Error::bad_database("Device in userdeviceid_metadata is invalid.")
|
||||
})?)
|
||||
})
|
||||
|
@ -885,7 +879,7 @@ impl Users {
|
|||
// Set the password to "" to indicate a deactivated account. Hashes will never result in an
|
||||
// empty string, so the user will not be able to log in again. Systems like changing the
|
||||
// password without logging in should check if the account is deactivated.
|
||||
self.userid_password.insert(user_id.to_string(), "")?;
|
||||
self.userid_password.insert(user_id.as_bytes(), &[])?;
|
||||
|
||||
// TODO: Unhook 3PID
|
||||
Ok(())
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue