improvement: auth chain cache

This commit is contained in:
Timo Kösters 2021-07-18 20:43:39 +02:00
parent f5273f7eb1
commit cfaa900e83
No known key found for this signature in database
GPG key ID: 24DA7517711A2BA4
10 changed files with 201 additions and 176 deletions

View file

@ -6,6 +6,7 @@ use crate::{
use get_profile_information::v1::ProfileField;
use http::header::{HeaderValue, AUTHORIZATION, HOST};
use log::{debug, error, info, trace, warn};
use lru_cache::LruCache;
use regex::Regex;
use rocket::response::content::Json;
use ruma::{
@ -52,7 +53,7 @@ use ruma::{
ServerSigningKeyId, UserId,
};
use std::{
collections::{btree_map::Entry, BTreeMap, BTreeSet, HashSet},
collections::{hash_map::Entry, BTreeMap, HashMap, HashSet},
convert::{TryFrom, TryInto},
fmt::Debug,
future::Future,
@ -931,7 +932,7 @@ pub fn handle_incoming_pdu<'a>(
);
// Build map of auth events
let mut auth_events = BTreeMap::new();
let mut auth_events = HashMap::new();
for id in &incoming_pdu.auth_events {
let auth_event = db
.rooms
@ -1097,7 +1098,7 @@ pub fn handle_incoming_pdu<'a>(
Err(_) => return Err("Failed to fetch state events.".to_owned()),
};
let mut state = BTreeMap::new();
let mut state = HashMap::new();
for pdu in state_vec {
match state.entry((pdu.kind.clone(), pdu.state_key.clone().ok_or_else(|| "Found non-state pdu in state events.".to_owned())?)) {
Entry::Vacant(v) => {
@ -1173,7 +1174,8 @@ pub fn handle_incoming_pdu<'a>(
}
}
let mut fork_states = BTreeSet::new();
let mut extremity_statehashes = Vec::new();
for id in &extremities {
match db
.rooms
@ -1181,30 +1183,19 @@ pub fn handle_incoming_pdu<'a>(
.map_err(|_| "Failed to ask db for pdu.".to_owned())?
{
Some(leaf_pdu) => {
let pdu_shortstatehash = db
.rooms
.pdu_shortstatehash(&leaf_pdu.event_id)
.map_err(|_| "Failed to ask db for pdu state hash.".to_owned())?
.ok_or_else(|| {
error!(
"Found extremity pdu with no statehash in db: {:?}",
leaf_pdu
);
"Found pdu with no statehash in db.".to_owned()
})?;
let mut leaf_state = db
.rooms
.state_full(pdu_shortstatehash)
.map_err(|_| "Failed to ask db for room state.".to_owned())?;
if let Some(state_key) = &leaf_pdu.state_key {
// Now it's the state after
let key = (leaf_pdu.kind.clone(), state_key.clone());
leaf_state.insert(key, leaf_pdu);
}
fork_states.insert(leaf_state);
extremity_statehashes.push((
db.rooms
.pdu_shortstatehash(&leaf_pdu.event_id)
.map_err(|_| "Failed to ask db for pdu state hash.".to_owned())?
.ok_or_else(|| {
error!(
"Found extremity pdu with no statehash in db: {:?}",
leaf_pdu
);
"Found pdu with no statehash in db.".to_owned()
})?,
Some(leaf_pdu),
));
}
_ => {
error!("Missing state snapshot for {:?}", id);
@ -1218,12 +1209,36 @@ pub fn handle_incoming_pdu<'a>(
// don't just trust a set of state we got from a remote).
// We do this by adding the current state to the list of fork states
let current_statehash = db
.rooms
.current_shortstatehash(&room_id)
.map_err(|_| "Failed to load current state hash.".to_owned())?
.expect("every room has state");
let current_state = db
.rooms
.room_state_full(&room_id)
.map_err(|_| "Failed to load room state.".to_owned())?;
.state_full(current_statehash)
.map_err(|_| "Failed to load room state.")?;
fork_states.insert(current_state.clone());
extremity_statehashes.push((current_statehash.clone(), None));
let mut fork_states = Vec::new();
for (statehash, leaf_pdu) in extremity_statehashes {
let mut leaf_state = db
.rooms
.state_full(statehash)
.map_err(|_| "Failed to ask db for room state.".to_owned())?;
if let Some(leaf_pdu) = leaf_pdu {
if let Some(state_key) = &leaf_pdu.state_key {
// Now it's the state after
let key = (leaf_pdu.kind.clone(), state_key.clone());
leaf_state.insert(key, leaf_pdu);
}
}
fork_states.push(leaf_state);
}
// We also add state after incoming event to the fork states
extremities.insert(incoming_pdu.event_id.clone());
@ -1234,9 +1249,7 @@ pub fn handle_incoming_pdu<'a>(
incoming_pdu.clone(),
);
}
fork_states.insert(state_after.clone());
let fork_states = fork_states.into_iter().collect::<Vec<_>>();
fork_states.push(state_after.clone());
let mut update_state = false;
// 14. Use state resolution to find new room state
@ -1254,17 +1267,31 @@ pub fn handle_incoming_pdu<'a>(
// We do need to force an update to this room's state
update_state = true;
match state_res::StateResolution::resolve(
let fork_states = &fork_states
.into_iter()
.map(|map| {
map.into_iter()
.map(|(k, v)| (k, v.event_id.clone()))
.collect::<StateMap<_>>()
})
.collect::<Vec<_>>();
let auth_chain_t = Instant::now();
let mut auth_chain_sets = Vec::new();
for state in fork_states {
auth_chain_sets.push(
get_auth_chain(state.iter().map(|(_, id)| id.clone()).collect(), db)
.map_err(|_| "Failed to load auth chain.".to_owned())?,
);
}
dbg!(auth_chain_t.elapsed());
let state_res_t = Instant::now();
let state = match state_res::StateResolution::resolve(
&room_id,
room_version_id,
&fork_states
.into_iter()
.map(|map| {
map.into_iter()
.map(|(k, v)| (k, v.event_id.clone()))
.collect::<StateMap<_>>()
})
.collect::<Vec<_>>(),
fork_states,
auth_chain_sets,
|id| {
let res = db.rooms.get_pdu(id);
if let Err(e) = &res {
@ -1277,7 +1304,9 @@ pub fn handle_incoming_pdu<'a>(
Err(_) => {
return Err("State resolution failed, either an event could not be found or deserialization".into());
}
}
};
dbg!(state_res_t.elapsed());
state
};
// 13. Check if the event passes auth based on the "current state" of the room, if not "soft fail" it
@ -1696,6 +1725,42 @@ async fn append_incoming_pdu(
Ok(pdu_id)
}
fn get_auth_chain(starting_events: Vec<EventId>, db: &Database) -> Result<HashSet<EventId>> {
let mut auth_chain_cache = db.rooms.auth_chain_cache();
let mut auth_chain = HashSet::new();
for event in starting_events {
auth_chain.extend(get_auth_chain_recursive(&event, &mut auth_chain_cache, db)?);
}
Ok(auth_chain)
}
fn get_auth_chain_recursive(
event_id: &EventId,
auth_chain_cache: &mut std::sync::MutexGuard<'_, LruCache<EventId, HashSet<EventId>>>,
db: &Database,
) -> Result<HashSet<EventId>> {
if let Some(cached) = auth_chain_cache.get_mut(event_id) {
return Ok(cached.clone());
}
let mut auth_chain = HashSet::new();
if let Some(pdu) = db.rooms.get_pdu(&event_id)? {
for auth_event in &pdu.auth_events {
auth_chain.extend(get_auth_chain_recursive(&auth_event, auth_chain_cache, db)?);
}
} else {
warn!("Could not find pdu mentioned in auth events.");
}
auth_chain_cache.insert(event_id.clone(), auth_chain.clone());
Ok(auth_chain)
}
#[cfg_attr(
feature = "conduit_bin",
get("/_matrix/federation/v1/event/<_>", data = "<body>")
@ -1783,35 +1848,20 @@ pub fn get_event_authorization_route(
return Err(Error::bad_config("Federation is disabled."));
}
let mut auth_chain = Vec::new();
let mut auth_chain_ids = BTreeSet::<EventId>::new();
let mut todo = BTreeSet::new();
todo.insert(body.event_id.clone());
let auth_chain_ids = get_auth_chain(vec![body.event_id.clone()], &db)?;
while let Some(event_id) = todo.iter().next().cloned() {
if let Some(pdu) = db.rooms.get_pdu(&event_id)? {
todo.extend(
pdu.auth_events
.clone()
.into_iter()
.collect::<BTreeSet<_>>()
.difference(&auth_chain_ids)
.cloned(),
);
auth_chain_ids.extend(pdu.auth_events.clone().into_iter());
let pdu_json = PduEvent::convert_to_outgoing_federation_event(
db.rooms.get_pdu_json(&event_id)?.unwrap(),
);
auth_chain.push(pdu_json);
} else {
warn!("Could not find pdu mentioned in auth events.");
}
todo.remove(&event_id);
Ok(get_event_authorization::v1::Response {
auth_chain: auth_chain_ids
.into_iter()
.map(|id| {
Ok::<_, Error>(PduEvent::convert_to_outgoing_federation_event(
db.rooms.get_pdu_json(&id)?.unwrap(),
))
})
.filter_map(|r| r.ok())
.collect(),
}
Ok(get_event_authorization::v1::Response { auth_chain }.into())
.into())
}
#[cfg_attr(
@ -1846,35 +1896,21 @@ pub fn get_room_state_route(
})
.collect();
let mut auth_chain = Vec::new();
let mut auth_chain_ids = BTreeSet::<EventId>::new();
let mut todo = BTreeSet::new();
todo.insert(body.event_id.clone());
let auth_chain_ids = get_auth_chain(vec![body.event_id.clone()], &db)?;
while let Some(event_id) = todo.iter().next().cloned() {
if let Some(pdu) = db.rooms.get_pdu(&event_id)? {
todo.extend(
pdu.auth_events
.clone()
.into_iter()
.collect::<BTreeSet<_>>()
.difference(&auth_chain_ids)
.cloned(),
);
auth_chain_ids.extend(pdu.auth_events.clone().into_iter());
let pdu_json = PduEvent::convert_to_outgoing_federation_event(
db.rooms.get_pdu_json(&event_id)?.unwrap(),
);
auth_chain.push(pdu_json);
} else {
warn!("Could not find pdu mentioned in auth events.");
}
todo.remove(&event_id);
Ok(get_room_state::v1::Response {
auth_chain: auth_chain_ids
.into_iter()
.map(|id| {
Ok::<_, Error>(PduEvent::convert_to_outgoing_federation_event(
db.rooms.get_pdu_json(&id)?.unwrap(),
))
})
.filter_map(|r| r.ok())
.collect(),
pdus,
}
Ok(get_room_state::v1::Response { auth_chain, pdus }.into())
.into())
}
#[cfg_attr(
@ -1904,27 +1940,7 @@ pub fn get_room_state_ids_route(
.into_iter()
.collect();
let mut auth_chain_ids = BTreeSet::<EventId>::new();
let mut todo = BTreeSet::new();
todo.insert(body.event_id.clone());
while let Some(event_id) = todo.iter().next().cloned() {
if let Some(pdu) = db.rooms.get_pdu(&event_id)? {
todo.extend(
pdu.auth_events
.clone()
.into_iter()
.collect::<BTreeSet<_>>()
.difference(&auth_chain_ids)
.cloned(),
);
auth_chain_ids.extend(pdu.auth_events.clone().into_iter());
} else {
warn!("Could not find pdu mentioned in auth events.");
}
todo.remove(&event_id);
}
let auth_chain_ids = get_auth_chain(vec![body.event_id.clone()], &db)?;
Ok(get_room_state_ids::v1::Response {
auth_chain_ids: auth_chain_ids.into_iter().collect(),
@ -2182,8 +2198,8 @@ pub async fn create_join_event_route(
let state_ids = db.rooms.state_full_ids(shortstatehash)?;
let mut auth_chain_ids = BTreeSet::<EventId>::new();
let mut todo = state_ids.iter().cloned().collect::<BTreeSet<_>>();
let mut auth_chain_ids = HashSet::<EventId>::new();
let mut todo = state_ids.iter().cloned().collect::<HashSet<_>>();
while let Some(event_id) = todo.iter().next().cloned() {
if let Some(pdu) = db.rooms.get_pdu(&event_id)? {
@ -2191,7 +2207,7 @@ pub async fn create_join_event_route(
pdu.auth_events
.clone()
.into_iter()
.collect::<BTreeSet<_>>()
.collect::<HashSet<_>>()
.difference(&auth_chain_ids)
.cloned(),
);