revert: remove dependency on async_recursion

This commit is contained in:
Matthias Ahouansou 2024-03-05 19:58:39 +00:00
parent becaad677f
commit c58af8485d
No known key found for this signature in database
3 changed files with 159 additions and 169 deletions

12
Cargo.lock generated
View file

@ -80,17 +80,6 @@ version = "1.1.1"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "5f093eed78becd229346bf859eec0aa4dd7ddde0757287b2b4107a1f09c80002" checksum = "5f093eed78becd229346bf859eec0aa4dd7ddde0757287b2b4107a1f09c80002"
[[package]]
name = "async-recursion"
version = "1.0.5"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "5fd55a5ba1179988837d24ab4c7cc8ed6efdeff578ede0416b4225a5fca35bd0"
dependencies = [
"proc-macro2",
"quote",
"syn 2.0.48",
]
[[package]] [[package]]
name = "async-trait" name = "async-trait"
version = "0.1.77" version = "0.1.77"
@ -385,7 +374,6 @@ checksum = "3d7b894f5411737b7867f4827955924d7c254fc9f4d91a6aad6b097804b1018b"
name = "conduit" name = "conduit"
version = "0.7.0-alpha" version = "0.7.0-alpha"
dependencies = [ dependencies = [
"async-recursion",
"async-trait", "async-trait",
"axum", "axum",
"axum-server", "axum-server",

View file

@ -115,7 +115,6 @@ lazy_static = "1.4.0"
async-trait = "0.1.68" async-trait = "0.1.68"
sd-notify = { version = "0.4.1", optional = true } sd-notify = { version = "0.4.1", optional = true }
async-recursion = "1.0.5"
[target.'cfg(unix)'.dependencies] [target.'cfg(unix)'.dependencies]
nix = { version = "0.26.2", features = ["resource"] } nix = { version = "0.26.2", features = ["resource"] }

View file

@ -8,7 +8,6 @@ use std::{
time::{Duration, Instant, SystemTime}, time::{Duration, Instant, SystemTime},
}; };
use async_recursion::async_recursion;
use futures_util::{stream::FuturesUnordered, Future, StreamExt}; use futures_util::{stream::FuturesUnordered, Future, StreamExt};
use ruma::{ use ruma::{
api::{ api::{
@ -1044,8 +1043,7 @@ impl Service {
/// d. TODO: Ask other servers over federation? /// d. TODO: Ask other servers over federation?
#[allow(clippy::type_complexity)] #[allow(clippy::type_complexity)]
#[tracing::instrument(skip_all)] #[tracing::instrument(skip_all)]
#[async_recursion] pub(crate) fn fetch_and_handle_outliers<'a>(
pub(crate) async fn fetch_and_handle_outliers<'a>(
&'a self, &'a self,
origin: &'a ServerName, origin: &'a ServerName,
events: &'a [Arc<EventId>], events: &'a [Arc<EventId>],
@ -1053,175 +1051,180 @@ impl Service {
room_id: &'a RoomId, room_id: &'a RoomId,
room_version_id: &'a RoomVersionId, room_version_id: &'a RoomVersionId,
pub_key_map: &'a RwLock<BTreeMap<String, BTreeMap<String, Base64>>>, pub_key_map: &'a RwLock<BTreeMap<String, BTreeMap<String, Base64>>>,
) -> Vec<(Arc<PduEvent>, Option<BTreeMap<String, CanonicalJsonValue>>)> { ) -> AsyncRecursiveType<'a, Vec<(Arc<PduEvent>, Option<BTreeMap<String, CanonicalJsonValue>>)>>
let back_off = |id| async move { {
match services() Box::pin(async move {
.globals let back_off = |id| async move {
.bad_event_ratelimiter
.write()
.await
.entry(id)
{
hash_map::Entry::Vacant(e) => {
e.insert((Instant::now(), 1));
}
hash_map::Entry::Occupied(mut e) => *e.get_mut() = (Instant::now(), e.get().1 + 1),
}
};
let mut pdus = vec![];
for id in events {
// a. Look in the main timeline (pduid_pdu tree)
// b. Look at outlier pdu tree
// (get_pdu_json checks both)
if let Ok(Some(local_pdu)) = services().rooms.timeline.get_pdu(id) {
trace!("Found {} in db", id);
pdus.push((local_pdu, None));
continue;
}
// c. Ask origin server over federation
// We also handle its auth chain here so we don't get a stack overflow in
// handle_outlier_pdu.
let mut todo_auth_events = vec![Arc::clone(id)];
let mut events_in_reverse_order = Vec::new();
let mut events_all = HashSet::new();
let mut i = 0;
while let Some(next_id) = todo_auth_events.pop() {
if let Some((time, tries)) = services()
.globals
.bad_event_ratelimiter
.read()
.await
.get(&*next_id)
{
// Exponential backoff
let mut min_elapsed_duration =
Duration::from_secs(5 * 60) * (*tries) * (*tries);
if min_elapsed_duration > Duration::from_secs(60 * 60 * 24) {
min_elapsed_duration = Duration::from_secs(60 * 60 * 24);
}
if time.elapsed() < min_elapsed_duration {
info!("Backing off from {}", next_id);
continue;
}
}
if events_all.contains(&next_id) {
continue;
}
i += 1;
if i % 100 == 0 {
tokio::task::yield_now().await;
}
if let Ok(Some(_)) = services().rooms.timeline.get_pdu(&next_id) {
trace!("Found {} in db", next_id);
continue;
}
info!("Fetching {} over federation.", next_id);
match services() match services()
.sending
.send_federation_request(
origin,
get_event::v1::Request {
event_id: (*next_id).to_owned(),
},
)
.await
{
Ok(res) => {
info!("Got {} over federation", next_id);
let (calculated_event_id, value) =
match pdu::gen_event_id_canonical_json(&res.pdu, room_version_id) {
Ok(t) => t,
Err(_) => {
back_off((*next_id).to_owned()).await;
continue;
}
};
if calculated_event_id != *next_id {
warn!("Server didn't return event id we requested: requested: {}, we got {}. Event: {:?}",
next_id, calculated_event_id, &res.pdu);
}
if let Some(auth_events) =
value.get("auth_events").and_then(|c| c.as_array())
{
for auth_event in auth_events {
if let Ok(auth_event) =
serde_json::from_value(auth_event.clone().into())
{
let a: Arc<EventId> = auth_event;
todo_auth_events.push(a);
} else {
warn!("Auth event id is not valid");
}
}
} else {
warn!("Auth event list invalid");
}
events_in_reverse_order.push((next_id.clone(), value));
events_all.insert(next_id);
}
Err(_) => {
warn!("Failed to fetch event: {}", next_id);
back_off((*next_id).to_owned()).await;
}
}
}
for (next_id, value) in events_in_reverse_order.iter().rev() {
if let Some((time, tries)) = services()
.globals .globals
.bad_event_ratelimiter .bad_event_ratelimiter
.read() .write()
.await .await
.get(&**next_id) .entry(id)
{ {
// Exponential backoff hash_map::Entry::Vacant(e) => {
let mut min_elapsed_duration = e.insert((Instant::now(), 1));
Duration::from_secs(5 * 60) * (*tries) * (*tries); }
if min_elapsed_duration > Duration::from_secs(60 * 60 * 24) { hash_map::Entry::Occupied(mut e) => {
min_elapsed_duration = Duration::from_secs(60 * 60 * 24); *e.get_mut() = (Instant::now(), e.get().1 + 1)
}
}
};
let mut pdus = vec![];
for id in events {
// a. Look in the main timeline (pduid_pdu tree)
// b. Look at outlier pdu tree
// (get_pdu_json checks both)
if let Ok(Some(local_pdu)) = services().rooms.timeline.get_pdu(id) {
trace!("Found {} in db", id);
pdus.push((local_pdu, None));
continue;
}
// c. Ask origin server over federation
// We also handle its auth chain here so we don't get a stack overflow in
// handle_outlier_pdu.
let mut todo_auth_events = vec![Arc::clone(id)];
let mut events_in_reverse_order = Vec::new();
let mut events_all = HashSet::new();
let mut i = 0;
while let Some(next_id) = todo_auth_events.pop() {
if let Some((time, tries)) = services()
.globals
.bad_event_ratelimiter
.read()
.await
.get(&*next_id)
{
// Exponential backoff
let mut min_elapsed_duration =
Duration::from_secs(5 * 60) * (*tries) * (*tries);
if min_elapsed_duration > Duration::from_secs(60 * 60 * 24) {
min_elapsed_duration = Duration::from_secs(60 * 60 * 24);
}
if time.elapsed() < min_elapsed_duration {
info!("Backing off from {}", next_id);
continue;
}
} }
if time.elapsed() < min_elapsed_duration { if events_all.contains(&next_id) {
info!("Backing off from {}", next_id);
continue; continue;
} }
i += 1;
if i % 100 == 0 {
tokio::task::yield_now().await;
}
if let Ok(Some(_)) = services().rooms.timeline.get_pdu(&next_id) {
trace!("Found {} in db", next_id);
continue;
}
info!("Fetching {} over federation.", next_id);
match services()
.sending
.send_federation_request(
origin,
get_event::v1::Request {
event_id: (*next_id).to_owned(),
},
)
.await
{
Ok(res) => {
info!("Got {} over federation", next_id);
let (calculated_event_id, value) =
match pdu::gen_event_id_canonical_json(&res.pdu, room_version_id) {
Ok(t) => t,
Err(_) => {
back_off((*next_id).to_owned()).await;
continue;
}
};
if calculated_event_id != *next_id {
warn!("Server didn't return event id we requested: requested: {}, we got {}. Event: {:?}",
next_id, calculated_event_id, &res.pdu);
}
if let Some(auth_events) =
value.get("auth_events").and_then(|c| c.as_array())
{
for auth_event in auth_events {
if let Ok(auth_event) =
serde_json::from_value(auth_event.clone().into())
{
let a: Arc<EventId> = auth_event;
todo_auth_events.push(a);
} else {
warn!("Auth event id is not valid");
}
}
} else {
warn!("Auth event list invalid");
}
events_in_reverse_order.push((next_id.clone(), value));
events_all.insert(next_id);
}
Err(_) => {
warn!("Failed to fetch event: {}", next_id);
back_off((*next_id).to_owned()).await;
}
}
} }
match self for (next_id, value) in events_in_reverse_order.iter().rev() {
.handle_outlier_pdu( if let Some((time, tries)) = services()
origin, .globals
create_event, .bad_event_ratelimiter
next_id, .read()
room_id, .await
value.clone(), .get(&**next_id)
true, {
pub_key_map, // Exponential backoff
) let mut min_elapsed_duration =
.await Duration::from_secs(5 * 60) * (*tries) * (*tries);
{ if min_elapsed_duration > Duration::from_secs(60 * 60 * 24) {
Ok((pdu, json)) => { min_elapsed_duration = Duration::from_secs(60 * 60 * 24);
if next_id == id { }
pdus.push((pdu, Some(json)));
if time.elapsed() < min_elapsed_duration {
info!("Backing off from {}", next_id);
continue;
} }
} }
Err(e) => {
warn!("Authentication of event {} failed: {:?}", next_id, e); match self
back_off((**next_id).to_owned()).await; .handle_outlier_pdu(
origin,
create_event,
next_id,
room_id,
value.clone(),
true,
pub_key_map,
)
.await
{
Ok((pdu, json)) => {
if next_id == id {
pdus.push((pdu, Some(json)));
}
}
Err(e) => {
warn!("Authentication of event {} failed: {:?}", next_id, e);
back_off((**next_id).to_owned()).await;
}
} }
} }
} }
} pdus
pdus })
} }
async fn fetch_unknown_prev_events( async fn fetch_unknown_prev_events(