messing with trait objects

This commit is contained in:
Timo Kösters 2022-10-05 12:45:54 +02:00 committed by Nyaaori
parent 8708cd3b63
commit face766e0f
No known key found for this signature in database
GPG key ID: E7819C3ED4D1F82E
61 changed files with 623 additions and 544 deletions

View file

@ -1,14 +1,16 @@
/// An async function that can recursively call itself.
type AsyncRecursiveType<'a, T> = Pin<Box<dyn Future<Output = T> + 'a + Send>>;
use ruma::{RoomVersionId, signatures::CanonicalJsonObject, api::federation::discovery::{get_server_keys, get_remote_server_keys}};
use tokio::sync::Semaphore;
use std::{
collections::{btree_map, hash_map, BTreeMap, HashMap, HashSet},
pin::Pin,
sync::{Arc, RwLock},
time::{Duration, Instant},
sync::{Arc, RwLock, RwLockWriteGuard},
time::{Duration, Instant, SystemTime},
};
use futures_util::{Future, stream::FuturesUnordered};
use futures_util::{Future, stream::FuturesUnordered, StreamExt};
use ruma::{
api::{
client::error::ErrorKind,
@ -22,7 +24,7 @@ use ruma::{
uint, EventId, MilliSecondsSinceUnixEpoch, RoomId, ServerName, ServerSigningKeyId,
};
use serde_json::value::{to_raw_value, RawValue as RawJsonValue};
use tracing::{error, info, trace, warn};
use tracing::{error, info, trace, warn, debug};
use crate::{service::*, services, Result, Error, PduEvent};
@ -53,7 +55,7 @@ impl Service {
/// it
/// 14. Use state resolution to find new room state
// We use some AsyncRecursiveType hacks here so we can call this async funtion recursively
#[tracing::instrument(skip(value, is_timeline_event, pub_key_map))]
#[tracing::instrument(skip(self, value, is_timeline_event, pub_key_map))]
pub(crate) async fn handle_incoming_pdu<'a>(
&self,
origin: &'a ServerName,
@ -64,10 +66,11 @@ impl Service {
pub_key_map: &'a RwLock<BTreeMap<String, BTreeMap<String, Base64>>>,
) -> Result<Option<Vec<u8>>> {
if !services().rooms.metadata.exists(room_id)? {
return Error::BadRequest(
return Err(Error::BadRequest(
ErrorKind::NotFound,
"Room is unknown to this server",
)};
));
}
services()
.rooms
@ -732,7 +735,7 @@ impl Service {
&incoming_pdu.sender,
incoming_pdu.state_key.as_deref(),
&incoming_pdu.content,
)?
)?;
let soft_fail = !state_res::event_auth::auth_check(
&room_version,
@ -821,7 +824,7 @@ impl Service {
let shortstatekey = services()
.rooms
.short
.get_or_create_shortstatekey(&incoming_pdu.kind.to_string().into(), state_key)?
.get_or_create_shortstatekey(&incoming_pdu.kind.to_string().into(), state_key)?;
state_after.insert(shortstatekey, Arc::from(&*incoming_pdu.event_id));
}
@ -1236,7 +1239,7 @@ impl Service {
let signature_ids = signature_object.keys().cloned().collect::<Vec<_>>();
let fetch_res = fetch_signing_keys(
let fetch_res = self.fetch_signing_keys(
signature_server.as_str().try_into().map_err(|_| {
Error::BadServerResponse("Invalid servername in signatures of server response pdu.")
})?,
@ -1481,4 +1484,168 @@ impl Service {
))
}
}
/// Search the DB for the signing keys of the given server, if we don't have them
/// fetch them from the server and save to our DB.
#[tracing::instrument(skip_all)]
pub async fn fetch_signing_keys(
&self,
origin: &ServerName,
signature_ids: Vec<String>,
) -> Result<BTreeMap<String, Base64>> {
let contains_all_ids =
|keys: &BTreeMap<String, Base64>| signature_ids.iter().all(|id| keys.contains_key(id));
let permit = services()
.globals
.servername_ratelimiter
.read()
.unwrap()
.get(origin)
.map(|s| Arc::clone(s).acquire_owned());
let permit = match permit {
Some(p) => p,
None => {
let mut write = services().globals.servername_ratelimiter.write().unwrap();
let s = Arc::clone(
write
.entry(origin.to_owned())
.or_insert_with(|| Arc::new(Semaphore::new(1))),
);
s.acquire_owned()
}
}
.await;
let back_off = |id| match services()
.globals
.bad_signature_ratelimiter
.write()
.unwrap()
.entry(id)
{
hash_map::Entry::Vacant(e) => {
e.insert((Instant::now(), 1));
}
hash_map::Entry::Occupied(mut e) => *e.get_mut() = (Instant::now(), e.get().1 + 1),
};
if let Some((time, tries)) = services()
.globals
.bad_signature_ratelimiter
.read()
.unwrap()
.get(&signature_ids)
{
// Exponential backoff
let mut min_elapsed_duration = Duration::from_secs(30) * (*tries) * (*tries);
if min_elapsed_duration > Duration::from_secs(60 * 60 * 24) {
min_elapsed_duration = Duration::from_secs(60 * 60 * 24);
}
if time.elapsed() < min_elapsed_duration {
debug!("Backing off from {:?}", signature_ids);
return Err(Error::BadServerResponse("bad signature, still backing off"));
}
}
trace!("Loading signing keys for {}", origin);
let mut result: BTreeMap<_, _> = services()
.globals
.signing_keys_for(origin)?
.into_iter()
.map(|(k, v)| (k.to_string(), v.key))
.collect();
if contains_all_ids(&result) {
return Ok(result);
}
debug!("Fetching signing keys for {} over federation", origin);
if let Some(server_key) = services()
.sending
.send_federation_request(origin, get_server_keys::v2::Request::new())
.await
.ok()
.and_then(|resp| resp.server_key.deserialize().ok())
{
services().globals.add_signing_key(origin, server_key.clone())?;
result.extend(
server_key
.verify_keys
.into_iter()
.map(|(k, v)| (k.to_string(), v.key)),
);
result.extend(
server_key
.old_verify_keys
.into_iter()
.map(|(k, v)| (k.to_string(), v.key)),
);
if contains_all_ids(&result) {
return Ok(result);
}
}
for server in services().globals.trusted_servers() {
debug!("Asking {} for {}'s signing key", server, origin);
if let Some(server_keys) = services()
.sending
.send_federation_request(
server,
get_remote_server_keys::v2::Request::new(
origin,
MilliSecondsSinceUnixEpoch::from_system_time(
SystemTime::now()
.checked_add(Duration::from_secs(3600))
.expect("SystemTime to large"),
)
.expect("time is valid"),
),
)
.await
.ok()
.map(|resp| {
resp.server_keys
.into_iter()
.filter_map(|e| e.deserialize().ok())
.collect::<Vec<_>>()
})
{
trace!("Got signing keys: {:?}", server_keys);
for k in server_keys {
services().globals.add_signing_key(origin, k.clone())?;
result.extend(
k.verify_keys
.into_iter()
.map(|(k, v)| (k.to_string(), v.key)),
);
result.extend(
k.old_verify_keys
.into_iter()
.map(|(k, v)| (k.to_string(), v.key)),
);
}
if contains_all_ids(&result) {
return Ok(result);
}
}
}
drop(permit);
back_off(signature_ids);
warn!("Failed to find public key for server: {}", origin);
Err(Error::BadServerResponse(
"Failed to find public key for server",
))
}
}