improvement: better E2EE over federation, faster incoming event handling

Timo Kösters 2021-08-24 19:10:31 +02:00
parent 72dd95f500
commit 81e056417c
9 changed files with 407 additions and 256 deletions
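Several of the rooms changes below revolve around CompressedStateEvent values. A minimal sketch of the layout implied by parse_compressed_state_event further down, assuming the type is a fixed 16-byte array (the helper name is illustrative, not from the diff):

// Sketch only: 8 big-endian bytes of shortstatekey followed by 8 big-endian
// bytes of shorteventid, matching the slicing done in parse_compressed_state_event.
fn compress_state_event_sketch(shortstatekey: u64, shorteventid: u64) -> [u8; 16] {
    let mut v = [0u8; 16];
    v[..8].copy_from_slice(&shortstatekey.to_be_bytes());
    v[8..].copy_from_slice(&shorteventid.to_be_bytes());
    v
}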


@@ -23,13 +23,13 @@ use ruma::{
uint, EventId, RoomAliasId, RoomId, RoomVersionId, ServerName, UserId,
};
use std::{
collections::{BTreeMap, BTreeSet, HashMap, HashSet},
collections::{BTreeMap, HashMap, HashSet},
convert::{TryFrom, TryInto},
mem::size_of,
sync::{Arc, Mutex},
};
use tokio::sync::MutexGuard;
use tracing::{debug, error, warn};
use tracing::{error, warn};
use super::{abstraction::Tree, admin::AdminCommand, pusher};
@@ -73,8 +73,8 @@ pub struct Rooms {
pub(super) shorteventid_shortstatehash: Arc<dyn Tree>,
/// StateKey = EventType + StateKey, ShortStateKey = Count
pub(super) statekey_shortstatekey: Arc<dyn Tree>,
pub(super) shortstatekey_statekey: Arc<dyn Tree>,
pub(super) shortroomid_roomid: Arc<dyn Tree>,
pub(super) roomid_shortroomid: Arc<dyn Tree>,
pub(super) shorteventid_eventid: Arc<dyn Tree>,
@@ -95,6 +95,7 @@ pub struct Rooms {
pub(super) shorteventid_cache: Mutex<LruCache<u64, EventId>>,
pub(super) eventidshort_cache: Mutex<LruCache<EventId, u64>>,
pub(super) statekeyshort_cache: Mutex<LruCache<(EventType, String), u64>>,
pub(super) shortstatekey_cache: Mutex<LruCache<u64, (EventType, String)>>,
pub(super) stateinfo_cache: Mutex<
LruCache<
u64,
@@ -112,7 +113,7 @@ impl Rooms {
/// Builds a StateMap by iterating over all keys that start
/// with state_hash, this gives the full state for the given state_hash.
#[tracing::instrument(skip(self))]
pub fn state_full_ids(&self, shortstatehash: u64) -> Result<BTreeSet<EventId>> {
pub fn state_full_ids(&self, shortstatehash: u64) -> Result<BTreeMap<u64, EventId>> {
let full_state = self
.load_shortstatehash_info(shortstatehash)?
.pop()
@@ -138,7 +139,7 @@
.into_iter()
.map(|compressed| self.parse_compressed_state_event(compressed))
.filter_map(|r| r.ok())
.map(|eventid| self.get_pdu(&eventid))
.map(|(_, eventid)| self.get_pdu(&eventid))
.filter_map(|r| r.ok().flatten())
.map(|pdu| {
Ok::<_, Error>((
@@ -176,7 +177,11 @@
Ok(full_state
.into_iter()
.find(|bytes| bytes.starts_with(&shortstatekey.to_be_bytes()))
.and_then(|compressed| self.parse_compressed_state_event(compressed).ok()))
.and_then(|compressed| {
self.parse_compressed_state_event(compressed)
.ok()
.map(|(_, id)| id)
}))
}
/// Returns a single PDU from `room_id` with key (`event_type`, `state_key`).
@@ -232,6 +237,13 @@
state_key: Option<&str>,
content: &serde_json::Value,
) -> Result<StateMap<Arc<PduEvent>>> {
let shortstatehash =
if let Some(current_shortstatehash) = self.current_shortstatehash(room_id)? {
current_shortstatehash
} else {
return Ok(HashMap::new());
};
let auth_events = state_res::auth_types_for_event(
kind,
sender,
@@ -239,19 +251,30 @@
content.clone(),
);
let mut events = StateMap::new();
for (event_type, state_key) in auth_events {
if let Some(pdu) = self.room_state_get(room_id, &event_type, &state_key)? {
events.insert((event_type, state_key), pdu);
} else {
// This is okay because when creating a new room some events were not created yet
debug!(
"{:?}: Could not find {} {:?} in state",
content, event_type, state_key
);
}
}
Ok(events)
let mut sauthevents = auth_events
.into_iter()
.filter_map(|(event_type, state_key)| {
self.get_shortstatekey(&event_type, &state_key)
.ok()
.flatten()
.map(|s| (s, (event_type, state_key)))
})
.collect::<HashMap<_, _>>();
let full_state = self
.load_shortstatehash_info(shortstatehash)?
.pop()
.expect("there is always one layer")
.1;
Ok(full_state
.into_iter()
.filter_map(|compressed| self.parse_compressed_state_event(compressed).ok())
.filter_map(|(shortstatekey, event_id)| {
sauthevents.remove(&shortstatekey).map(|k| (k, event_id))
})
.filter_map(|(k, event_id)| self.get_pdu(&event_id).ok().flatten().map(|pdu| (k, pdu)))
.collect())
}
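The rewritten auth-event lookup above resolves each wanted (event type, state key) pair to its numeric shortstatekey once and then decodes the room's compressed state in a single pass, instead of issuing one room_state_get lookup per auth event. A minimal sketch of that pattern, with generic types standing in for the diff's EventType and EventId:

use std::collections::HashMap;

// Sketch only: keep an entry only if its shortstatekey is still wanted,
// removing it from the map so every key is matched at most once.
fn pick_wanted<K, V>(
    entries: impl IntoIterator<Item = (u64, V)>,
    mut wanted: HashMap<u64, K>,
) -> Vec<(K, V)> {
    entries
        .into_iter()
        .filter_map(|(short, value)| wanted.remove(&short).map(|key| (key, value)))
        .collect()
}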
/// Generate a new StateHash.
@@ -306,32 +329,19 @@ impl Rooms {
/// Force the creation of a new StateHash and insert it into the db.
///
/// Whatever `state` is supplied to `force_state` becomes the new current room state snapshot.
#[tracing::instrument(skip(self, new_state, db))]
#[tracing::instrument(skip(self, new_state_ids_compressed, db))]
pub fn force_state(
&self,
room_id: &RoomId,
new_state: HashMap<(EventType, String), EventId>,
new_state_ids_compressed: HashSet<CompressedStateEvent>,
db: &Database,
) -> Result<()> {
let previous_shortstatehash = self.current_shortstatehash(&room_id)?;
let new_state_ids_compressed = new_state
.iter()
.filter_map(|((event_type, state_key), event_id)| {
let shortstatekey = self
.get_or_create_shortstatekey(event_type, state_key, &db.globals)
.ok()?;
Some(
self.compress_state_event(shortstatekey, event_id, &db.globals)
.ok()?,
)
})
.collect::<HashSet<_>>();
let state_hash = self.calculate_hash(
&new_state
.values()
.map(|event_id| event_id.as_bytes())
&new_state_ids_compressed
.iter()
.map(|bytes| &bytes[..])
.collect::<Vec<_>>(),
);
@@ -373,10 +383,11 @@
)?;
};
for event_id in statediffnew
.into_iter()
.filter_map(|new| self.parse_compressed_state_event(new).ok())
{
for event_id in statediffnew.into_iter().filter_map(|new| {
self.parse_compressed_state_event(new)
.ok()
.map(|(_, id)| id)
}) {
if let Some(pdu) = self.get_pdu_json(&event_id)? {
if pdu.get("type").and_then(|val| val.as_str()) == Some("m.room.member") {
if let Ok(pdu) = serde_json::from_value::<PduEvent>(
@@ -504,15 +515,20 @@
Ok(v.try_into().expect("we checked the size above"))
}
/// Returns shortstatekey, event id
#[tracing::instrument(skip(self, compressed_event))]
pub fn parse_compressed_state_event(
&self,
compressed_event: CompressedStateEvent,
) -> Result<EventId> {
self.get_eventid_from_short(
utils::u64_from_bytes(&compressed_event[size_of::<u64>()..])
) -> Result<(u64, EventId)> {
Ok((
utils::u64_from_bytes(&compressed_event[0..size_of::<u64>()])
.expect("bytes have right length"),
)
self.get_eventid_from_short(
utils::u64_from_bytes(&compressed_event[size_of::<u64>()..])
.expect("bytes have right length"),
)?,
))
}
/// Creates a new shortstatehash that often is just a diff to an already existing
@@ -805,6 +821,8 @@ impl Rooms {
let shortstatekey = globals.next_count()?;
self.statekey_shortstatekey
.insert(&statekey, &shortstatekey.to_be_bytes())?;
self.shortstatekey_statekey
.insert(&shortstatekey.to_be_bytes(), &statekey)?;
shortstatekey
}
};
@@ -833,11 +851,10 @@
.get(&shorteventid.to_be_bytes())?
.ok_or_else(|| Error::bad_database("Shorteventid does not exist"))?;
let event_id =
EventId::try_from(utils::string_from_bytes(&bytes).map_err(|_| {
Error::bad_database("EventID in roomid_pduleaves is invalid unicode.")
})?)
.map_err(|_| Error::bad_database("EventId in roomid_pduleaves is invalid."))?;
let event_id = EventId::try_from(utils::string_from_bytes(&bytes).map_err(|_| {
Error::bad_database("EventID in shorteventid_eventid is invalid unicode.")
})?)
.map_err(|_| Error::bad_database("EventId in shorteventid_eventid is invalid."))?;
self.shorteventid_cache
.lock()
@@ -847,6 +864,48 @@
Ok(event_id)
}
#[tracing::instrument(skip(self))]
pub fn get_statekey_from_short(&self, shortstatekey: u64) -> Result<(EventType, String)> {
if let Some(id) = self
.shortstatekey_cache
.lock()
.unwrap()
.get_mut(&shortstatekey)
{
return Ok(id.clone());
}
let bytes = self
.shortstatekey_statekey
.get(&shortstatekey.to_be_bytes())?
.ok_or_else(|| Error::bad_database("Shortstatekey does not exist"))?;
let mut parts = bytes.splitn(2, |&b| b == 0xff);
let eventtype_bytes = parts.next().expect("split always returns one entry");
let statekey_bytes = parts
.next()
.ok_or_else(|| Error::bad_database("Invalid statekey in shortstatekey_statekey."))?;
let event_type =
EventType::try_from(utils::string_from_bytes(&eventtype_bytes).map_err(|_| {
Error::bad_database("Event type in shortstatekey_statekey is invalid unicode.")
})?)
.map_err(|_| Error::bad_database("Event type in shortstatekey_statekey is invalid."))?;
let state_key = utils::string_from_bytes(&statekey_bytes).map_err(|_| {
Error::bad_database("Statekey in shortstatekey_statekey is invalid unicode.")
})?;
let result = (event_type, state_key);
self.shortstatekey_cache
.lock()
.unwrap()
.insert(shortstatekey, result.clone());
Ok(result)
}
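get_statekey_from_short above reverses the statekey encoding that get_or_create_shortstatekey writes: the event type bytes, a 0xff separator, then the state key bytes. The same encoding as a standalone sketch, with plain &str standing in for ruma's EventType:

// Sketch only: the byte layout that splitn(2, |&b| b == 0xff) above undoes.
fn encode_statekey(event_type: &str, state_key: &str) -> Vec<u8> {
    let mut key = event_type.as_bytes().to_vec();
    key.push(0xff);
    key.extend_from_slice(state_key.as_bytes());
    key
}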
/// Returns the full room state.
#[tracing::instrument(skip(self))]
pub fn room_state_full(
@@ -1106,6 +1165,17 @@ impl Rooms {
.collect()
}
#[tracing::instrument(skip(self, room_id, event_ids))]
pub fn mark_as_referenced(&self, room_id: &RoomId, event_ids: &[EventId]) -> Result<()> {
for prev in event_ids {
let mut key = room_id.as_bytes().to_vec();
key.extend_from_slice(prev.as_bytes());
self.referencedevents.insert(&key, &[])?;
}
Ok(())
}
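mark_as_referenced above keys the referencedevents tree with the room id bytes immediately followed by the referenced event id bytes, storing an empty value; a later hunk replaces the previous inline loop with a call to this helper. The same key construction over plain byte slices:

// Sketch only: referencedevents key = room id bytes ++ event id bytes, empty value.
fn referenced_key(room_id: &[u8], event_id: &[u8]) -> Vec<u8> {
    let mut key = room_id.to_vec();
    key.extend_from_slice(event_id);
    key
}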
/// Replace the leaves of a room.
///
/// The provided `event_ids` become the new leaves, this allows a room to have multiple
@@ -1202,12 +1272,7 @@
}
// We must keep track of all events that have been referenced.
for prev in &pdu.prev_events {
let mut key = pdu.room_id().as_bytes().to_vec();
key.extend_from_slice(prev.as_bytes());
self.referencedevents.insert(&key, &[])?;
}
self.mark_as_referenced(&pdu.room_id, &pdu.prev_events)?;
self.replace_pdu_leaves(&pdu.room_id, leaves)?;
let mutex_insert = Arc::clone(
@@ -1565,35 +1630,22 @@ impl Rooms {
///
/// This adds all current state events (not including the incoming event)
/// to `stateid_pduid` and adds the incoming event to `eventid_statehash`.
#[tracing::instrument(skip(self, state, globals))]
#[tracing::instrument(skip(self, state_ids_compressed, globals))]
pub fn set_event_state(
&self,
event_id: &EventId,
room_id: &RoomId,
state: &StateMap<Arc<PduEvent>>,
state_ids_compressed: HashSet<CompressedStateEvent>,
globals: &super::globals::Globals,
) -> Result<()> {
let shorteventid = self.get_or_create_shorteventid(&event_id, globals)?;
let previous_shortstatehash = self.current_shortstatehash(&room_id)?;
let state_ids_compressed = state
.iter()
.filter_map(|((event_type, state_key), pdu)| {
let shortstatekey = self
.get_or_create_shortstatekey(event_type, state_key, globals)
.ok()?;
Some(
self.compress_state_event(shortstatekey, &pdu.event_id, globals)
.ok()?,
)
})
.collect::<HashSet<_>>();
let state_hash = self.calculate_hash(
&state
.values()
.map(|pdu| pdu.event_id.as_bytes())
&state_ids_compressed
.iter()
.map(|s| &s[..])
.collect::<Vec<_>>(),
);
@@ -1857,8 +1909,8 @@ impl Rooms {
&room_version,
&Arc::new(pdu.clone()),
create_prev_event,
&auth_events,
None, // TODO: third_party_invite
|k, s| auth_events.get(&(k.clone(), s.to_owned())).map(Arc::clone),
)
.map_err(|e| {
error!("{:?}", e);


@@ -1,5 +1,5 @@
use std::{
collections::{BTreeMap, HashMap},
collections::{BTreeMap, HashMap, HashSet},
convert::{TryFrom, TryInto},
fmt::Debug,
sync::Arc,
@@ -20,14 +20,17 @@ use ruma::{
appservice,
federation::{
self,
transactions::edu::{Edu, ReceiptContent, ReceiptData, ReceiptMap},
transactions::edu::{
DeviceListUpdateContent, Edu, ReceiptContent, ReceiptData, ReceiptMap,
},
},
OutgoingRequest,
},
device_id,
events::{push_rules, AnySyncEphemeralRoomEvent, EventType},
push,
receipt::ReceiptType,
MilliSecondsSinceUnixEpoch, ServerName, UInt, UserId,
uint, MilliSecondsSinceUnixEpoch, ServerName, UInt, UserId,
};
use tokio::{
select,
@@ -317,8 +320,19 @@ impl Sending {
})?;
let mut events = Vec::new();
let mut max_edu_count = since;
let mut device_list_changes = HashSet::new();
'outer: for room_id in db.rooms.server_rooms(server) {
let room_id = room_id?;
// Look for device list updates in this room
device_list_changes.extend(
db.users
.keys_changed(&room_id.to_string(), since, None)
.filter_map(|r| r.ok())
.filter(|user_id| user_id.server_name() == db.globals.server_name()),
);
// Look for read receipts in this room
for r in db.rooms.edus.readreceipts_since(&room_id, since) {
let (user_id, count, read_receipt) = r?;
@@ -378,6 +392,22 @@
}
}
for user_id in device_list_changes {
// Empty prev id forces synapse to resync: https://github.com/matrix-org/synapse/blob/98aec1cc9da2bd6b8e34ffb282c85abf9b8b42ca/synapse/handlers/device.py#L767
// Because synapse resyncs, we can just insert dummy data
let edu = Edu::DeviceListUpdate(DeviceListUpdateContent {
user_id,
device_id: device_id!("dummy"),
device_display_name: "Dummy".to_owned(),
stream_id: uint!(1),
prev_id: Vec::new(),
deleted: None,
keys: None,
});
events.push(serde_json::to_vec(&edu).expect("json can be serialized"));
}
Ok((events, max_edu_count))
}
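For reference, the dummy EDU queued in the loop above serializes to roughly the following m.device_list_update payload; field names follow the Matrix federation spec, and the exact JSON ruma emits may differ in detail:

// {
//   "edu_type": "m.device_list_update",
//   "content": {
//     "user_id": "@user:example.org",
//     "device_id": "dummy",
//     "device_display_name": "Dummy",
//     "stream_id": 1,
//     "prev_id": []   // empty, so the receiving server cannot chain it and falls back to a resync
//   }
// }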


@@ -673,7 +673,7 @@ impl Users {
}
#[tracing::instrument(skip(self, user_id, rooms, globals))]
fn mark_device_key_update(
pub fn mark_device_key_update(
&self,
user_id: &UserId,
rooms: &super::rooms::Rooms,