refactor: use async-aware RwLocks and Mutexes where possible

This commit is contained in:
Matthias Ahouansou 2024-03-05 14:22:54 +00:00
parent 57575b7c6f
commit becaad677f
No known key found for this signature in database
21 changed files with 1171 additions and 1006 deletions

View file

@ -1,7 +1,7 @@
use std::{
collections::BTreeMap,
convert::{TryFrom, TryInto},
sync::{Arc, RwLock},
sync::Arc,
time::Instant,
};
@ -26,7 +26,7 @@ use ruma::{
EventId, OwnedRoomAliasId, OwnedRoomId, RoomAliasId, RoomId, RoomVersionId, ServerName, UserId,
};
use serde_json::value::to_raw_value;
use tokio::sync::{mpsc, Mutex, MutexGuard};
use tokio::sync::{mpsc, Mutex, RwLock};
use crate::{
api::client_server::{leave_all_rooms, AUTO_GEN_PASSWORD_LENGTH},
@ -215,27 +215,6 @@ impl Service {
.expect("@conduit:server_name is valid");
if let Ok(Some(conduit_room)) = services().admin.get_admin_room() {
let send_message = |message: RoomMessageEventContent,
mutex_lock: &MutexGuard<'_, ()>| {
services()
.rooms
.timeline
.build_and_append_pdu(
PduBuilder {
event_type: TimelineEventType::RoomMessage,
content: to_raw_value(&message)
.expect("event is valid, we just created it"),
unsigned: None,
state_key: None,
redacts: None,
},
&conduit_user,
&conduit_room,
mutex_lock,
)
.unwrap();
};
loop {
tokio::select! {
Some(event) = receiver.recv() => {
@ -248,16 +227,30 @@ impl Service {
services().globals
.roomid_mutex_state
.write()
.unwrap()
.await
.entry(conduit_room.to_owned())
.or_default(),
);
let state_lock = mutex_state.lock().await;
send_message(message_content, &state_lock);
drop(state_lock);
services()
.rooms
.timeline
.build_and_append_pdu(
PduBuilder {
event_type: TimelineEventType::RoomMessage,
content: to_raw_value(&message_content)
.expect("event is valid, we just created it"),
unsigned: None,
state_key: None,
redacts: None,
},
&conduit_user,
&conduit_room,
&state_lock,
)
.await.unwrap();
}
}
}
@ -425,11 +418,7 @@ impl Service {
Err(e) => RoomMessageEventContent::text_plain(e.to_string()),
},
AdminCommand::IncomingFederation => {
let map = services()
.globals
.roomid_federationhandletime
.read()
.unwrap();
let map = services().globals.roomid_federationhandletime.read().await;
let mut msg: String = format!("Handling {} incoming pdus:\n", map.len());
for (r, (e, i)) in map.iter() {
@ -543,7 +532,7 @@ impl Service {
}
}
AdminCommand::MemoryUsage => {
let response1 = services().memory_usage();
let response1 = services().memory_usage().await;
let response2 = services().globals.db.memory_usage();
RoomMessageEventContent::text_plain(format!(
@ -556,7 +545,7 @@ impl Service {
RoomMessageEventContent::text_plain("Done.")
}
AdminCommand::ClearServiceCaches { amount } => {
services().clear_caches(amount);
services().clear_caches(amount).await;
RoomMessageEventContent::text_plain("Done.")
}
@ -797,7 +786,7 @@ impl Service {
.fetch_required_signing_keys(&value, &pub_key_map)
.await?;
let pub_key_map = pub_key_map.read().unwrap();
let pub_key_map = pub_key_map.read().await;
match ruma::signatures::verify_json(&pub_key_map, &value) {
Ok(_) => RoomMessageEventContent::text_plain("Signature correct"),
Err(e) => RoomMessageEventContent::text_plain(format!(
@ -913,7 +902,7 @@ impl Service {
.globals
.roomid_mutex_state
.write()
.unwrap()
.await
.entry(room_id.clone())
.or_default(),
);
@ -932,164 +921,202 @@ impl Service {
content.room_version = services().globals.default_room_version();
// 1. The room create event
services().rooms.timeline.build_and_append_pdu(
PduBuilder {
event_type: TimelineEventType::RoomCreate,
content: to_raw_value(&content).expect("event is valid, we just created it"),
unsigned: None,
state_key: Some("".to_owned()),
redacts: None,
},
&conduit_user,
&room_id,
&state_lock,
)?;
services()
.rooms
.timeline
.build_and_append_pdu(
PduBuilder {
event_type: TimelineEventType::RoomCreate,
content: to_raw_value(&content).expect("event is valid, we just created it"),
unsigned: None,
state_key: Some("".to_owned()),
redacts: None,
},
&conduit_user,
&room_id,
&state_lock,
)
.await?;
// 2. Make conduit bot join
services().rooms.timeline.build_and_append_pdu(
PduBuilder {
event_type: TimelineEventType::RoomMember,
content: to_raw_value(&RoomMemberEventContent {
membership: MembershipState::Join,
displayname: None,
avatar_url: None,
is_direct: None,
third_party_invite: None,
blurhash: None,
reason: None,
join_authorized_via_users_server: None,
})
.expect("event is valid, we just created it"),
unsigned: None,
state_key: Some(conduit_user.to_string()),
redacts: None,
},
&conduit_user,
&room_id,
&state_lock,
)?;
services()
.rooms
.timeline
.build_and_append_pdu(
PduBuilder {
event_type: TimelineEventType::RoomMember,
content: to_raw_value(&RoomMemberEventContent {
membership: MembershipState::Join,
displayname: None,
avatar_url: None,
is_direct: None,
third_party_invite: None,
blurhash: None,
reason: None,
join_authorized_via_users_server: None,
})
.expect("event is valid, we just created it"),
unsigned: None,
state_key: Some(conduit_user.to_string()),
redacts: None,
},
&conduit_user,
&room_id,
&state_lock,
)
.await?;
// 3. Power levels
let mut users = BTreeMap::new();
users.insert(conduit_user.clone(), 100.into());
services().rooms.timeline.build_and_append_pdu(
PduBuilder {
event_type: TimelineEventType::RoomPowerLevels,
content: to_raw_value(&RoomPowerLevelsEventContent {
users,
..Default::default()
})
.expect("event is valid, we just created it"),
unsigned: None,
state_key: Some("".to_owned()),
redacts: None,
},
&conduit_user,
&room_id,
&state_lock,
)?;
services()
.rooms
.timeline
.build_and_append_pdu(
PduBuilder {
event_type: TimelineEventType::RoomPowerLevels,
content: to_raw_value(&RoomPowerLevelsEventContent {
users,
..Default::default()
})
.expect("event is valid, we just created it"),
unsigned: None,
state_key: Some("".to_owned()),
redacts: None,
},
&conduit_user,
&room_id,
&state_lock,
)
.await?;
// 4.1 Join Rules
services().rooms.timeline.build_and_append_pdu(
PduBuilder {
event_type: TimelineEventType::RoomJoinRules,
content: to_raw_value(&RoomJoinRulesEventContent::new(JoinRule::Invite))
.expect("event is valid, we just created it"),
unsigned: None,
state_key: Some("".to_owned()),
redacts: None,
},
&conduit_user,
&room_id,
&state_lock,
)?;
services()
.rooms
.timeline
.build_and_append_pdu(
PduBuilder {
event_type: TimelineEventType::RoomJoinRules,
content: to_raw_value(&RoomJoinRulesEventContent::new(JoinRule::Invite))
.expect("event is valid, we just created it"),
unsigned: None,
state_key: Some("".to_owned()),
redacts: None,
},
&conduit_user,
&room_id,
&state_lock,
)
.await?;
// 4.2 History Visibility
services().rooms.timeline.build_and_append_pdu(
PduBuilder {
event_type: TimelineEventType::RoomHistoryVisibility,
content: to_raw_value(&RoomHistoryVisibilityEventContent::new(
HistoryVisibility::Shared,
))
.expect("event is valid, we just created it"),
unsigned: None,
state_key: Some("".to_owned()),
redacts: None,
},
&conduit_user,
&room_id,
&state_lock,
)?;
services()
.rooms
.timeline
.build_and_append_pdu(
PduBuilder {
event_type: TimelineEventType::RoomHistoryVisibility,
content: to_raw_value(&RoomHistoryVisibilityEventContent::new(
HistoryVisibility::Shared,
))
.expect("event is valid, we just created it"),
unsigned: None,
state_key: Some("".to_owned()),
redacts: None,
},
&conduit_user,
&room_id,
&state_lock,
)
.await?;
// 4.3 Guest Access
services().rooms.timeline.build_and_append_pdu(
PduBuilder {
event_type: TimelineEventType::RoomGuestAccess,
content: to_raw_value(&RoomGuestAccessEventContent::new(GuestAccess::Forbidden))
services()
.rooms
.timeline
.build_and_append_pdu(
PduBuilder {
event_type: TimelineEventType::RoomGuestAccess,
content: to_raw_value(&RoomGuestAccessEventContent::new(
GuestAccess::Forbidden,
))
.expect("event is valid, we just created it"),
unsigned: None,
state_key: Some("".to_owned()),
redacts: None,
},
&conduit_user,
&room_id,
&state_lock,
)?;
unsigned: None,
state_key: Some("".to_owned()),
redacts: None,
},
&conduit_user,
&room_id,
&state_lock,
)
.await?;
// 5. Events implied by name and topic
let room_name = format!("{} Admin Room", services().globals.server_name());
services().rooms.timeline.build_and_append_pdu(
PduBuilder {
event_type: TimelineEventType::RoomName,
content: to_raw_value(&RoomNameEventContent::new(room_name))
.expect("event is valid, we just created it"),
unsigned: None,
state_key: Some("".to_owned()),
redacts: None,
},
&conduit_user,
&room_id,
&state_lock,
)?;
services()
.rooms
.timeline
.build_and_append_pdu(
PduBuilder {
event_type: TimelineEventType::RoomName,
content: to_raw_value(&RoomNameEventContent::new(room_name))
.expect("event is valid, we just created it"),
unsigned: None,
state_key: Some("".to_owned()),
redacts: None,
},
&conduit_user,
&room_id,
&state_lock,
)
.await?;
services().rooms.timeline.build_and_append_pdu(
PduBuilder {
event_type: TimelineEventType::RoomTopic,
content: to_raw_value(&RoomTopicEventContent {
topic: format!("Manage {}", services().globals.server_name()),
})
.expect("event is valid, we just created it"),
unsigned: None,
state_key: Some("".to_owned()),
redacts: None,
},
&conduit_user,
&room_id,
&state_lock,
)?;
services()
.rooms
.timeline
.build_and_append_pdu(
PduBuilder {
event_type: TimelineEventType::RoomTopic,
content: to_raw_value(&RoomTopicEventContent {
topic: format!("Manage {}", services().globals.server_name()),
})
.expect("event is valid, we just created it"),
unsigned: None,
state_key: Some("".to_owned()),
redacts: None,
},
&conduit_user,
&room_id,
&state_lock,
)
.await?;
// 6. Room alias
let alias: OwnedRoomAliasId = format!("#admins:{}", services().globals.server_name())
.try_into()
.expect("#admins:server_name is a valid alias name");
services().rooms.timeline.build_and_append_pdu(
PduBuilder {
event_type: TimelineEventType::RoomCanonicalAlias,
content: to_raw_value(&RoomCanonicalAliasEventContent {
alias: Some(alias.clone()),
alt_aliases: Vec::new(),
})
.expect("event is valid, we just created it"),
unsigned: None,
state_key: Some("".to_owned()),
redacts: None,
},
&conduit_user,
&room_id,
&state_lock,
)?;
services()
.rooms
.timeline
.build_and_append_pdu(
PduBuilder {
event_type: TimelineEventType::RoomCanonicalAlias,
content: to_raw_value(&RoomCanonicalAliasEventContent {
alias: Some(alias.clone()),
alt_aliases: Vec::new(),
})
.expect("event is valid, we just created it"),
unsigned: None,
state_key: Some("".to_owned()),
redacts: None,
},
&conduit_user,
&room_id,
&state_lock,
)
.await?;
services().rooms.alias.set_alias(&alias, &room_id)?;
@ -1125,7 +1152,7 @@ impl Service {
.globals
.roomid_mutex_state
.write()
.unwrap()
.await
.entry(room_id.clone())
.or_default(),
);
@ -1137,72 +1164,84 @@ impl Service {
.expect("@conduit:server_name is valid");
// Invite and join the real user
services().rooms.timeline.build_and_append_pdu(
PduBuilder {
event_type: TimelineEventType::RoomMember,
content: to_raw_value(&RoomMemberEventContent {
membership: MembershipState::Invite,
displayname: None,
avatar_url: None,
is_direct: None,
third_party_invite: None,
blurhash: None,
reason: None,
join_authorized_via_users_server: None,
})
.expect("event is valid, we just created it"),
unsigned: None,
state_key: Some(user_id.to_string()),
redacts: None,
},
&conduit_user,
&room_id,
&state_lock,
)?;
services().rooms.timeline.build_and_append_pdu(
PduBuilder {
event_type: TimelineEventType::RoomMember,
content: to_raw_value(&RoomMemberEventContent {
membership: MembershipState::Join,
displayname: Some(displayname),
avatar_url: None,
is_direct: None,
third_party_invite: None,
blurhash: None,
reason: None,
join_authorized_via_users_server: None,
})
.expect("event is valid, we just created it"),
unsigned: None,
state_key: Some(user_id.to_string()),
redacts: None,
},
user_id,
&room_id,
&state_lock,
)?;
services()
.rooms
.timeline
.build_and_append_pdu(
PduBuilder {
event_type: TimelineEventType::RoomMember,
content: to_raw_value(&RoomMemberEventContent {
membership: MembershipState::Invite,
displayname: None,
avatar_url: None,
is_direct: None,
third_party_invite: None,
blurhash: None,
reason: None,
join_authorized_via_users_server: None,
})
.expect("event is valid, we just created it"),
unsigned: None,
state_key: Some(user_id.to_string()),
redacts: None,
},
&conduit_user,
&room_id,
&state_lock,
)
.await?;
services()
.rooms
.timeline
.build_and_append_pdu(
PduBuilder {
event_type: TimelineEventType::RoomMember,
content: to_raw_value(&RoomMemberEventContent {
membership: MembershipState::Join,
displayname: Some(displayname),
avatar_url: None,
is_direct: None,
third_party_invite: None,
blurhash: None,
reason: None,
join_authorized_via_users_server: None,
})
.expect("event is valid, we just created it"),
unsigned: None,
state_key: Some(user_id.to_string()),
redacts: None,
},
user_id,
&room_id,
&state_lock,
)
.await?;
// Set power level
let mut users = BTreeMap::new();
users.insert(conduit_user.to_owned(), 100.into());
users.insert(user_id.to_owned(), 100.into());
services().rooms.timeline.build_and_append_pdu(
PduBuilder {
event_type: TimelineEventType::RoomPowerLevels,
content: to_raw_value(&RoomPowerLevelsEventContent {
users,
..Default::default()
})
.expect("event is valid, we just created it"),
unsigned: None,
state_key: Some("".to_owned()),
redacts: None,
},
&conduit_user,
&room_id,
&state_lock,
)?;
services()
.rooms
.timeline
.build_and_append_pdu(
PduBuilder {
event_type: TimelineEventType::RoomPowerLevels,
content: to_raw_value(&RoomPowerLevelsEventContent {
users,
..Default::default()
})
.expect("event is valid, we just created it"),
unsigned: None,
state_key: Some("".to_owned()),
redacts: None,
},
&conduit_user,
&room_id,
&state_lock,
)
.await?;
// Send welcome message
services().rooms.timeline.build_and_append_pdu(
@ -1220,7 +1259,7 @@ impl Service {
&conduit_user,
&room_id,
&state_lock,
)?;
).await?;
}
Ok(())
}

View file

@ -31,11 +31,11 @@ use std::{
path::PathBuf,
sync::{
atomic::{self, AtomicBool},
Arc, Mutex, RwLock,
Arc, RwLock as SyncRwLock,
},
time::{Duration, Instant},
};
use tokio::sync::{broadcast, watch::Receiver, Mutex as TokioMutex, Semaphore};
use tokio::sync::{broadcast, watch::Receiver, Mutex, RwLock, Semaphore};
use tracing::{error, info};
use trust_dns_resolver::TokioAsyncResolver;
@ -53,7 +53,7 @@ pub struct Service {
pub db: &'static dyn Data,
pub actual_destination_cache: Arc<RwLock<WellKnownMap>>, // actual_destination, host
pub tls_name_override: Arc<RwLock<TlsNameMap>>,
pub tls_name_override: Arc<SyncRwLock<TlsNameMap>>,
pub config: Config,
keypair: Arc<ruma::signatures::Ed25519KeyPair>,
dns_resolver: TokioAsyncResolver,
@ -68,8 +68,8 @@ pub struct Service {
pub servername_ratelimiter: Arc<RwLock<HashMap<OwnedServerName, Arc<Semaphore>>>>,
pub sync_receivers: RwLock<HashMap<(OwnedUserId, OwnedDeviceId), SyncHandle>>,
pub roomid_mutex_insert: RwLock<HashMap<OwnedRoomId, Arc<Mutex<()>>>>,
pub roomid_mutex_state: RwLock<HashMap<OwnedRoomId, Arc<TokioMutex<()>>>>,
pub roomid_mutex_federation: RwLock<HashMap<OwnedRoomId, Arc<TokioMutex<()>>>>, // this lock will be held longer
pub roomid_mutex_state: RwLock<HashMap<OwnedRoomId, Arc<Mutex<()>>>>,
pub roomid_mutex_federation: RwLock<HashMap<OwnedRoomId, Arc<Mutex<()>>>>, // this lock will be held longer
pub roomid_federationhandletime: RwLock<HashMap<OwnedRoomId, (OwnedEventId, Instant)>>,
pub stateres_mutex: Arc<Mutex<()>>,
pub rotate: RotationHandler,
@ -109,11 +109,11 @@ impl Default for RotationHandler {
pub struct Resolver {
inner: GaiResolver,
overrides: Arc<RwLock<TlsNameMap>>,
overrides: Arc<SyncRwLock<TlsNameMap>>,
}
impl Resolver {
pub fn new(overrides: Arc<RwLock<TlsNameMap>>) -> Self {
pub fn new(overrides: Arc<SyncRwLock<TlsNameMap>>) -> Self {
Resolver {
inner: GaiResolver::new(),
overrides,
@ -125,7 +125,7 @@ impl Resolve for Resolver {
fn resolve(&self, name: Name) -> Resolving {
self.overrides
.read()
.expect("lock should not be poisoned")
.unwrap()
.get(name.as_str())
.and_then(|(override_name, port)| {
override_name.first().map(|first_name| {
@ -159,7 +159,7 @@ impl Service {
}
};
let tls_name_override = Arc::new(RwLock::new(TlsNameMap::new()));
let tls_name_override = Arc::new(SyncRwLock::new(TlsNameMap::new()));
let jwt_decoding_key = config
.jwt_secret

View file

@ -1,9 +1,10 @@
use std::{
collections::{BTreeMap, HashMap},
sync::{Arc, Mutex},
sync::{Arc, Mutex as SyncMutex},
};
use lru_cache::LruCache;
use tokio::sync::Mutex;
use crate::{Config, Result};
@ -79,17 +80,17 @@ impl Services {
state: rooms::state::Service { db },
state_accessor: rooms::state_accessor::Service {
db,
server_visibility_cache: Mutex::new(LruCache::new(
server_visibility_cache: SyncMutex::new(LruCache::new(
(100.0 * config.conduit_cache_capacity_modifier) as usize,
)),
user_visibility_cache: Mutex::new(LruCache::new(
user_visibility_cache: SyncMutex::new(LruCache::new(
(100.0 * config.conduit_cache_capacity_modifier) as usize,
)),
},
state_cache: rooms::state_cache::Service { db },
state_compressor: rooms::state_compressor::Service {
db,
stateinfo_cache: Mutex::new(LruCache::new(
stateinfo_cache: SyncMutex::new(LruCache::new(
(100.0 * config.conduit_cache_capacity_modifier) as usize,
)),
},
@ -107,7 +108,7 @@ impl Services {
uiaa: uiaa::Service { db },
users: users::Service {
db,
connections: Mutex::new(BTreeMap::new()),
connections: SyncMutex::new(BTreeMap::new()),
},
account_data: account_data::Service { db },
admin: admin::Service::build(),
@ -118,14 +119,8 @@ impl Services {
globals: globals::Service::load(db, config)?,
})
}
fn memory_usage(&self) -> String {
let lazy_load_waiting = self
.rooms
.lazy_loading
.lazy_load_waiting
.lock()
.unwrap()
.len();
async fn memory_usage(&self) -> String {
let lazy_load_waiting = self.rooms.lazy_loading.lazy_load_waiting.lock().await.len();
let server_visibility_cache = self
.rooms
.state_accessor
@ -152,15 +147,9 @@ impl Services {
.timeline
.lasttimelinecount_cache
.lock()
.unwrap()
.len();
let roomid_spacechunk_cache = self
.rooms
.spaces
.roomid_spacechunk_cache
.lock()
.unwrap()
.await
.len();
let roomid_spacechunk_cache = self.rooms.spaces.roomid_spacechunk_cache.lock().await.len();
format!(
"\
@ -173,13 +162,13 @@ roomid_spacechunk_cache: {roomid_spacechunk_cache}\
"
)
}
fn clear_caches(&self, amount: u32) {
async fn clear_caches(&self, amount: u32) {
if amount > 0 {
self.rooms
.lazy_loading
.lazy_load_waiting
.lock()
.unwrap()
.await
.clear();
}
if amount > 1 {
@ -211,7 +200,7 @@ roomid_spacechunk_cache: {roomid_spacechunk_cache}\
.timeline
.lasttimelinecount_cache
.lock()
.unwrap()
.await
.clear();
}
if amount > 5 {
@ -219,7 +208,7 @@ roomid_spacechunk_cache: {roomid_spacechunk_cache}\
.spaces
.roomid_spacechunk_cache
.lock()
.unwrap()
.await
.clear();
}
}

View file

@ -1,25 +1,24 @@
/// An async function that can recursively call itself.
type AsyncRecursiveType<'a, T> = Pin<Box<dyn Future<Output = T> + 'a + Send>>;
use ruma::{
api::federation::discovery::{get_remote_server_keys, get_server_keys},
CanonicalJsonObject, CanonicalJsonValue, OwnedServerName, OwnedServerSigningKeyId,
RoomVersionId,
};
use std::{
collections::{hash_map, BTreeMap, HashMap, HashSet},
pin::Pin,
sync::{Arc, RwLock, RwLockWriteGuard},
sync::Arc,
time::{Duration, Instant, SystemTime},
};
use tokio::sync::Semaphore;
use async_recursion::async_recursion;
use futures_util::{stream::FuturesUnordered, Future, StreamExt};
use ruma::{
api::{
client::error::ErrorKind,
federation::{
discovery::get_remote_server_keys_batch::{self, v2::QueryCriteria},
discovery::{
get_remote_server_keys,
get_remote_server_keys_batch::{self, v2::QueryCriteria},
get_server_keys,
},
event::{get_event, get_room_state_ids},
membership::create_join_event,
},
@ -31,9 +30,11 @@ use ruma::{
int,
serde::Base64,
state_res::{self, RoomVersion, StateMap},
uint, EventId, MilliSecondsSinceUnixEpoch, RoomId, ServerName,
uint, CanonicalJsonObject, CanonicalJsonValue, EventId, MilliSecondsSinceUnixEpoch,
OwnedServerName, OwnedServerSigningKeyId, RoomId, RoomVersionId, ServerName,
};
use serde_json::value::RawValue as RawJsonValue;
use tokio::sync::{RwLock, RwLockWriteGuard, Semaphore};
use tracing::{debug, error, info, trace, warn};
use crate::{service::*, services, Error, PduEvent, Result};
@ -168,7 +169,7 @@ impl Service {
.globals
.bad_event_ratelimiter
.read()
.unwrap()
.await
.get(&*prev_id)
{
// Exponential backoff
@ -189,7 +190,7 @@ impl Service {
.globals
.bad_event_ratelimiter
.write()
.unwrap()
.await
.entry((*prev_id).to_owned())
{
hash_map::Entry::Vacant(e) => {
@ -213,7 +214,7 @@ impl Service {
.globals
.roomid_federationhandletime
.write()
.unwrap()
.await
.insert(room_id.to_owned(), ((*prev_id).to_owned(), start_time));
if let Err(e) = self
@ -233,7 +234,7 @@ impl Service {
.globals
.bad_event_ratelimiter
.write()
.unwrap()
.await
.entry((*prev_id).to_owned())
{
hash_map::Entry::Vacant(e) => {
@ -249,7 +250,7 @@ impl Service {
.globals
.roomid_federationhandletime
.write()
.unwrap()
.await
.remove(&room_id.to_owned());
debug!(
"Handling prev event {} took {}m{}s",
@ -267,7 +268,7 @@ impl Service {
.globals
.roomid_federationhandletime
.write()
.unwrap()
.await
.insert(room_id.to_owned(), (event_id.to_owned(), start_time));
let r = services()
.rooms
@ -285,7 +286,7 @@ impl Service {
.globals
.roomid_federationhandletime
.write()
.unwrap()
.await
.remove(&room_id.to_owned());
r
@ -326,11 +327,8 @@ impl Service {
let room_version =
RoomVersion::new(room_version_id).expect("room version is supported");
let mut val = match ruma::signatures::verify_event(
&pub_key_map.read().expect("RwLock is poisoned."),
&value,
room_version_id,
) {
let guard = pub_key_map.read().await;
let mut val = match ruma::signatures::verify_event(&guard, &value, room_version_id) {
Err(e) => {
// Drop
warn!("Dropping bad event {}: {}", event_id, e,);
@ -365,6 +363,8 @@ impl Service {
Ok(ruma::signatures::Verified::All) => value,
};
drop(guard);
// Now that we have checked the signature and hashes we can add the eventID and convert
// to our PduEvent type
val.insert(
@ -692,13 +692,15 @@ impl Service {
{
Ok(res) => {
debug!("Fetching state events at event.");
let collect = res
.pdu_ids
.iter()
.map(|x| Arc::from(&**x))
.collect::<Vec<_>>();
let state_vec = self
.fetch_and_handle_outliers(
origin,
&res.pdu_ids
.iter()
.map(|x| Arc::from(&**x))
.collect::<Vec<_>>(),
&collect,
create_event,
room_id,
room_version_id,
@ -805,7 +807,7 @@ impl Service {
.globals
.roomid_mutex_state
.write()
.unwrap()
.await
.entry(room_id.to_owned())
.or_default(),
);
@ -884,14 +886,18 @@ impl Service {
debug!("Starting soft fail auth check");
if soft_fail {
services().rooms.timeline.append_incoming_pdu(
&incoming_pdu,
val,
extremities.iter().map(|e| (**e).to_owned()).collect(),
state_ids_compressed,
soft_fail,
&state_lock,
)?;
services()
.rooms
.timeline
.append_incoming_pdu(
&incoming_pdu,
val,
extremities.iter().map(|e| (**e).to_owned()).collect(),
state_ids_compressed,
soft_fail,
&state_lock,
)
.await?;
// Soft fail, we keep the event as an outlier but don't add it to the timeline
warn!("Event was soft failed: {:?}", incoming_pdu);
@ -912,14 +918,18 @@ impl Service {
// We use the `state_at_event` instead of `state_after` so we accurately
// represent the state for this event.
let pdu_id = services().rooms.timeline.append_incoming_pdu(
&incoming_pdu,
val,
extremities.iter().map(|e| (**e).to_owned()).collect(),
state_ids_compressed,
soft_fail,
&state_lock,
)?;
let pdu_id = services()
.rooms
.timeline
.append_incoming_pdu(
&incoming_pdu,
val,
extremities.iter().map(|e| (**e).to_owned()).collect(),
state_ids_compressed,
soft_fail,
&state_lock,
)
.await?;
debug!("Appended incoming pdu");
@ -1034,7 +1044,8 @@ impl Service {
/// d. TODO: Ask other servers over federation?
#[allow(clippy::type_complexity)]
#[tracing::instrument(skip_all)]
pub(crate) fn fetch_and_handle_outliers<'a>(
#[async_recursion]
pub(crate) async fn fetch_and_handle_outliers<'a>(
&'a self,
origin: &'a ServerName,
events: &'a [Arc<EventId>],
@ -1042,176 +1053,175 @@ impl Service {
room_id: &'a RoomId,
room_version_id: &'a RoomVersionId,
pub_key_map: &'a RwLock<BTreeMap<String, BTreeMap<String, Base64>>>,
) -> AsyncRecursiveType<'a, Vec<(Arc<PduEvent>, Option<BTreeMap<String, CanonicalJsonValue>>)>>
{
Box::pin(async move {
let back_off = |id| match services()
) -> Vec<(Arc<PduEvent>, Option<BTreeMap<String, CanonicalJsonValue>>)> {
let back_off = |id| async move {
match services()
.globals
.bad_event_ratelimiter
.write()
.unwrap()
.await
.entry(id)
{
hash_map::Entry::Vacant(e) => {
e.insert((Instant::now(), 1));
}
hash_map::Entry::Occupied(mut e) => *e.get_mut() = (Instant::now(), e.get().1 + 1),
};
}
};
let mut pdus = vec![];
for id in events {
// a. Look in the main timeline (pduid_pdu tree)
// b. Look at outlier pdu tree
// (get_pdu_json checks both)
if let Ok(Some(local_pdu)) = services().rooms.timeline.get_pdu(id) {
trace!("Found {} in db", id);
pdus.push((local_pdu, None));
let mut pdus = vec![];
for id in events {
// a. Look in the main timeline (pduid_pdu tree)
// b. Look at outlier pdu tree
// (get_pdu_json checks both)
if let Ok(Some(local_pdu)) = services().rooms.timeline.get_pdu(id) {
trace!("Found {} in db", id);
pdus.push((local_pdu, None));
continue;
}
// c. Ask origin server over federation
// We also handle its auth chain here so we don't get a stack overflow in
// handle_outlier_pdu.
let mut todo_auth_events = vec![Arc::clone(id)];
let mut events_in_reverse_order = Vec::new();
let mut events_all = HashSet::new();
let mut i = 0;
while let Some(next_id) = todo_auth_events.pop() {
if let Some((time, tries)) = services()
.globals
.bad_event_ratelimiter
.read()
.await
.get(&*next_id)
{
// Exponential backoff
let mut min_elapsed_duration =
Duration::from_secs(5 * 60) * (*tries) * (*tries);
if min_elapsed_duration > Duration::from_secs(60 * 60 * 24) {
min_elapsed_duration = Duration::from_secs(60 * 60 * 24);
}
if time.elapsed() < min_elapsed_duration {
info!("Backing off from {}", next_id);
continue;
}
}
if events_all.contains(&next_id) {
continue;
}
// c. Ask origin server over federation
// We also handle its auth chain here so we don't get a stack overflow in
// handle_outlier_pdu.
let mut todo_auth_events = vec![Arc::clone(id)];
let mut events_in_reverse_order = Vec::new();
let mut events_all = HashSet::new();
let mut i = 0;
while let Some(next_id) = todo_auth_events.pop() {
if let Some((time, tries)) = services()
.globals
.bad_event_ratelimiter
.read()
.unwrap()
.get(&*next_id)
{
// Exponential backoff
let mut min_elapsed_duration =
Duration::from_secs(5 * 60) * (*tries) * (*tries);
if min_elapsed_duration > Duration::from_secs(60 * 60 * 24) {
min_elapsed_duration = Duration::from_secs(60 * 60 * 24);
}
if time.elapsed() < min_elapsed_duration {
info!("Backing off from {}", next_id);
continue;
}
}
if events_all.contains(&next_id) {
continue;
}
i += 1;
if i % 100 == 0 {
tokio::task::yield_now().await;
}
if let Ok(Some(_)) = services().rooms.timeline.get_pdu(&next_id) {
trace!("Found {} in db", next_id);
continue;
}
info!("Fetching {} over federation.", next_id);
match services()
.sending
.send_federation_request(
origin,
get_event::v1::Request {
event_id: (*next_id).to_owned(),
},
)
.await
{
Ok(res) => {
info!("Got {} over federation", next_id);
let (calculated_event_id, value) =
match pdu::gen_event_id_canonical_json(&res.pdu, room_version_id) {
Ok(t) => t,
Err(_) => {
back_off((*next_id).to_owned());
continue;
}
};
if calculated_event_id != *next_id {
warn!("Server didn't return event id we requested: requested: {}, we got {}. Event: {:?}",
next_id, calculated_event_id, &res.pdu);
}
if let Some(auth_events) =
value.get("auth_events").and_then(|c| c.as_array())
{
for auth_event in auth_events {
if let Ok(auth_event) =
serde_json::from_value(auth_event.clone().into())
{
let a: Arc<EventId> = auth_event;
todo_auth_events.push(a);
} else {
warn!("Auth event id is not valid");
}
}
} else {
warn!("Auth event list invalid");
}
events_in_reverse_order.push((next_id.clone(), value));
events_all.insert(next_id);
}
Err(_) => {
warn!("Failed to fetch event: {}", next_id);
back_off((*next_id).to_owned());
}
}
i += 1;
if i % 100 == 0 {
tokio::task::yield_now().await;
}
for (next_id, value) in events_in_reverse_order.iter().rev() {
if let Some((time, tries)) = services()
.globals
.bad_event_ratelimiter
.read()
.unwrap()
.get(&**next_id)
{
// Exponential backoff
let mut min_elapsed_duration =
Duration::from_secs(5 * 60) * (*tries) * (*tries);
if min_elapsed_duration > Duration::from_secs(60 * 60 * 24) {
min_elapsed_duration = Duration::from_secs(60 * 60 * 24);
if let Ok(Some(_)) = services().rooms.timeline.get_pdu(&next_id) {
trace!("Found {} in db", next_id);
continue;
}
info!("Fetching {} over federation.", next_id);
match services()
.sending
.send_federation_request(
origin,
get_event::v1::Request {
event_id: (*next_id).to_owned(),
},
)
.await
{
Ok(res) => {
info!("Got {} over federation", next_id);
let (calculated_event_id, value) =
match pdu::gen_event_id_canonical_json(&res.pdu, room_version_id) {
Ok(t) => t,
Err(_) => {
back_off((*next_id).to_owned()).await;
continue;
}
};
if calculated_event_id != *next_id {
warn!("Server didn't return event id we requested: requested: {}, we got {}. Event: {:?}",
next_id, calculated_event_id, &res.pdu);
}
if time.elapsed() < min_elapsed_duration {
info!("Backing off from {}", next_id);
continue;
}
}
match self
.handle_outlier_pdu(
origin,
create_event,
next_id,
room_id,
value.clone(),
true,
pub_key_map,
)
.await
{
Ok((pdu, json)) => {
if next_id == id {
pdus.push((pdu, Some(json)));
if let Some(auth_events) =
value.get("auth_events").and_then(|c| c.as_array())
{
for auth_event in auth_events {
if let Ok(auth_event) =
serde_json::from_value(auth_event.clone().into())
{
let a: Arc<EventId> = auth_event;
todo_auth_events.push(a);
} else {
warn!("Auth event id is not valid");
}
}
} else {
warn!("Auth event list invalid");
}
Err(e) => {
warn!("Authentication of event {} failed: {:?}", next_id, e);
back_off((**next_id).to_owned());
}
events_in_reverse_order.push((next_id.clone(), value));
events_all.insert(next_id);
}
Err(_) => {
warn!("Failed to fetch event: {}", next_id);
back_off((*next_id).to_owned()).await;
}
}
}
pdus
})
for (next_id, value) in events_in_reverse_order.iter().rev() {
if let Some((time, tries)) = services()
.globals
.bad_event_ratelimiter
.read()
.await
.get(&**next_id)
{
// Exponential backoff
let mut min_elapsed_duration =
Duration::from_secs(5 * 60) * (*tries) * (*tries);
if min_elapsed_duration > Duration::from_secs(60 * 60 * 24) {
min_elapsed_duration = Duration::from_secs(60 * 60 * 24);
}
if time.elapsed() < min_elapsed_duration {
info!("Backing off from {}", next_id);
continue;
}
}
match self
.handle_outlier_pdu(
origin,
create_event,
next_id,
room_id,
value.clone(),
true,
pub_key_map,
)
.await
{
Ok((pdu, json)) => {
if next_id == id {
pdus.push((pdu, Some(json)));
}
}
Err(e) => {
warn!("Authentication of event {} failed: {:?}", next_id, e);
back_off((**next_id).to_owned()).await;
}
}
}
}
pdus
}
async fn fetch_unknown_prev_events(
@ -1360,7 +1370,7 @@ impl Service {
pub_key_map
.write()
.map_err(|_| Error::bad_database("RwLock is poisoned."))?
.await
.insert(signature_server.clone(), keys);
}
@ -1369,7 +1379,7 @@ impl Service {
// Gets a list of servers for which we don't have the signing key yet. We go over
// the PDUs and either cache the key or add it to the list that needs to be retrieved.
fn get_server_keys_from_cache(
async fn get_server_keys_from_cache(
&self,
pdu: &RawJsonValue,
servers: &mut BTreeMap<OwnedServerName, BTreeMap<OwnedServerSigningKeyId, QueryCriteria>>,
@ -1393,7 +1403,7 @@ impl Service {
.globals
.bad_event_ratelimiter
.read()
.unwrap()
.await
.get(event_id)
{
// Exponential backoff
@ -1469,17 +1479,19 @@ impl Service {
> = BTreeMap::new();
{
let mut pkm = pub_key_map
.write()
.map_err(|_| Error::bad_database("RwLock is poisoned."))?;
let mut pkm = pub_key_map.write().await;
// Try to fetch keys, failure is okay
// Servers we couldn't find in the cache will be added to `servers`
for pdu in &event.room_state.state {
let _ = self.get_server_keys_from_cache(pdu, &mut servers, room_version, &mut pkm);
let _ = self
.get_server_keys_from_cache(pdu, &mut servers, room_version, &mut pkm)
.await;
}
for pdu in &event.room_state.auth_chain {
let _ = self.get_server_keys_from_cache(pdu, &mut servers, room_version, &mut pkm);
let _ = self
.get_server_keys_from_cache(pdu, &mut servers, room_version, &mut pkm)
.await;
}
drop(pkm);
@ -1503,9 +1515,7 @@ impl Service {
.await
{
trace!("Got signing keys: {:?}", keys);
let mut pkm = pub_key_map
.write()
.map_err(|_| Error::bad_database("RwLock is poisoned."))?;
let mut pkm = pub_key_map.write().await;
for k in keys.server_keys {
let k = match k.deserialize() {
Ok(key) => key,
@ -1564,10 +1574,7 @@ impl Service {
.into_iter()
.map(|(k, v)| (k.to_string(), v.key))
.collect();
pub_key_map
.write()
.map_err(|_| Error::bad_database("RwLock is poisoned."))?
.insert(origin.to_string(), result);
pub_key_map.write().await.insert(origin.to_string(), result);
}
}
info!("Done handling result");
@ -1632,14 +1639,14 @@ impl Service {
.globals
.servername_ratelimiter
.read()
.unwrap()
.await
.get(origin)
.map(|s| Arc::clone(s).acquire_owned());
let permit = match permit {
Some(p) => p,
None => {
let mut write = services().globals.servername_ratelimiter.write().unwrap();
let mut write = services().globals.servername_ratelimiter.write().await;
let s = Arc::clone(
write
.entry(origin.to_owned())
@ -1651,24 +1658,26 @@ impl Service {
}
.await;
let back_off = |id| match services()
.globals
.bad_signature_ratelimiter
.write()
.unwrap()
.entry(id)
{
hash_map::Entry::Vacant(e) => {
e.insert((Instant::now(), 1));
let back_off = |id| async {
match services()
.globals
.bad_signature_ratelimiter
.write()
.await
.entry(id)
{
hash_map::Entry::Vacant(e) => {
e.insert((Instant::now(), 1));
}
hash_map::Entry::Occupied(mut e) => *e.get_mut() = (Instant::now(), e.get().1 + 1),
}
hash_map::Entry::Occupied(mut e) => *e.get_mut() = (Instant::now(), e.get().1 + 1),
};
if let Some((time, tries)) = services()
.globals
.bad_signature_ratelimiter
.read()
.unwrap()
.await
.get(&signature_ids)
{
// Exponential backoff
@ -1775,7 +1784,7 @@ impl Service {
drop(permit);
back_off(signature_ids);
back_off(signature_ids).await;
warn!("Failed to find public key for server: {}", origin);
Err(Error::BadServerResponse(

View file

@ -1,11 +1,9 @@
mod data;
use std::{
collections::{HashMap, HashSet},
sync::Mutex,
};
use std::collections::{HashMap, HashSet};
pub use data::Data;
use ruma::{DeviceId, OwnedDeviceId, OwnedRoomId, OwnedUserId, RoomId, UserId};
use tokio::sync::Mutex;
use crate::Result;
@ -33,7 +31,7 @@ impl Service {
}
#[tracing::instrument(skip(self))]
pub fn lazy_load_mark_sent(
pub async fn lazy_load_mark_sent(
&self,
user_id: &UserId,
device_id: &DeviceId,
@ -41,7 +39,7 @@ impl Service {
lazy_load: HashSet<OwnedUserId>,
count: PduCount,
) {
self.lazy_load_waiting.lock().unwrap().insert(
self.lazy_load_waiting.lock().await.insert(
(
user_id.to_owned(),
device_id.to_owned(),
@ -53,14 +51,14 @@ impl Service {
}
#[tracing::instrument(skip(self))]
pub fn lazy_load_confirm_delivery(
pub async fn lazy_load_confirm_delivery(
&self,
user_id: &UserId,
device_id: &DeviceId,
room_id: &RoomId,
since: PduCount,
) -> Result<()> {
if let Some(user_ids) = self.lazy_load_waiting.lock().unwrap().remove(&(
if let Some(user_ids) = self.lazy_load_waiting.lock().await.remove(&(
user_id.to_owned(),
device_id.to_owned(),
room_id.to_owned(),

View file

@ -1,4 +1,4 @@
use std::sync::{Arc, Mutex};
use std::sync::Arc;
use lru_cache::LruCache;
use ruma::{
@ -25,6 +25,7 @@ use ruma::{
space::SpaceRoomJoinRule,
OwnedRoomId, RoomId, UserId,
};
use tokio::sync::Mutex;
use tracing::{debug, error, warn};
@ -79,7 +80,7 @@ impl Service {
if let Some(cached) = self
.roomid_spacechunk_cache
.lock()
.unwrap()
.await
.get_mut(&current_room.to_owned())
.as_ref()
{
@ -171,7 +172,7 @@ impl Service {
.transpose()?
.unwrap_or(JoinRule::Invite);
self.roomid_spacechunk_cache.lock().unwrap().insert(
self.roomid_spacechunk_cache.lock().await.insert(
current_room.clone(),
Some(CachedSpaceChunk {
chunk,
@ -265,7 +266,7 @@ impl Service {
}
}
self.roomid_spacechunk_cache.lock().unwrap().insert(
self.roomid_spacechunk_cache.lock().await.insert(
current_room.clone(),
Some(CachedSpaceChunk {
chunk,
@ -289,7 +290,7 @@ impl Service {
} else {
self.roomid_spacechunk_cache
.lock()
.unwrap()
.await
.insert(current_room.clone(), None);
}
}

View file

@ -95,7 +95,7 @@ impl Service {
.spaces
.roomid_spacechunk_cache
.lock()
.unwrap()
.await
.remove(&pdu.room_id);
}
_ => continue,

View file

@ -2,12 +2,8 @@ mod data;
use std::{
cmp::Ordering,
collections::{BTreeMap, HashMap},
};
use std::{
collections::HashSet,
sync::{Arc, Mutex, RwLock},
collections::{BTreeMap, HashMap, HashSet},
sync::Arc,
};
pub use data::Data;
@ -32,7 +28,7 @@ use ruma::{
};
use serde::Deserialize;
use serde_json::value::{to_raw_value, RawValue as RawJsonValue};
use tokio::sync::MutexGuard;
use tokio::sync::{Mutex, MutexGuard, RwLock};
use tracing::{error, info, warn};
use crate::{
@ -201,7 +197,7 @@ impl Service {
///
/// Returns pdu id
#[tracing::instrument(skip(self, pdu, pdu_json, leaves))]
pub fn append_pdu<'a>(
pub async fn append_pdu<'a>(
&self,
pdu: &PduEvent,
mut pdu_json: CanonicalJsonObject,
@ -263,11 +259,11 @@ impl Service {
.globals
.roomid_mutex_insert
.write()
.unwrap()
.await
.entry(pdu.room_id.clone())
.or_default(),
);
let insert_lock = mutex_insert.lock().unwrap();
let insert_lock = mutex_insert.lock().await;
let count1 = services().globals.next_count()?;
// Mark as read first so the sending client doesn't get a notification even if appending
@ -395,7 +391,7 @@ impl Service {
.spaces
.roomid_spacechunk_cache
.lock()
.unwrap()
.await
.remove(&pdu.room_id);
}
}
@ -806,7 +802,7 @@ impl Service {
/// Creates a new persisted data unit and adds it to a room. This function takes a
/// roomid_mutex_state, meaning that only this function is able to mutate the room state.
#[tracing::instrument(skip(self, state_lock))]
pub fn build_and_append_pdu(
pub async fn build_and_append_pdu(
&self,
pdu_builder: PduBuilder,
sender: &UserId,
@ -902,14 +898,16 @@ impl Service {
// pdu without it's state. This is okay because append_pdu can't fail.
let statehashid = services().rooms.state.append_to_state(&pdu)?;
let pdu_id = self.append_pdu(
&pdu,
pdu_json,
// Since this PDU references all pdu_leaves we can update the leaves
// of the room
vec![(*pdu.event_id).to_owned()],
state_lock,
)?;
let pdu_id = self
.append_pdu(
&pdu,
pdu_json,
// Since this PDU references all pdu_leaves we can update the leaves
// of the room
vec![(*pdu.event_id).to_owned()],
state_lock,
)
.await?;
// We set the room state after inserting the pdu, so that we never have a moment in time
// where events in the current room state do not exist
@ -947,7 +945,7 @@ impl Service {
/// Append the incoming event setting the state snapshot to the state from the
/// server that sent the event.
#[tracing::instrument(skip_all)]
pub fn append_incoming_pdu<'a>(
pub async fn append_incoming_pdu<'a>(
&self,
pdu: &PduEvent,
pdu_json: CanonicalJsonObject,
@ -977,11 +975,11 @@ impl Service {
return Ok(None);
}
let pdu_id =
services()
.rooms
.timeline
.append_pdu(pdu, pdu_json, new_room_leaves, state_lock)?;
let pdu_id = services()
.rooms
.timeline
.append_pdu(pdu, pdu_json, new_room_leaves, state_lock)
.await?;
Ok(Some(pdu_id))
}
@ -1118,7 +1116,7 @@ impl Service {
.globals
.roomid_mutex_federation
.write()
.unwrap()
.await
.entry(room_id.to_owned())
.or_default(),
);
@ -1150,11 +1148,11 @@ impl Service {
.globals
.roomid_mutex_insert
.write()
.unwrap()
.await
.entry(room_id.clone())
.or_default(),
);
let insert_lock = mutex_insert.lock().unwrap();
let insert_lock = mutex_insert.lock().await;
let count = services().globals.next_count()?;
let mut pdu_id = shortroomid.to_be_bytes().to_vec();

View file

@ -45,7 +45,7 @@ impl Service {
self.db.exists(user_id)
}
pub fn forget_sync_request_connection(
pub async fn forget_sync_request_connection(
&self,
user_id: OwnedUserId,
device_id: OwnedDeviceId,
@ -186,7 +186,7 @@ impl Service {
cached.known_rooms.clone()
}
pub fn update_sync_subscriptions(
pub async fn update_sync_subscriptions(
&self,
user_id: OwnedUserId,
device_id: OwnedDeviceId,
@ -212,7 +212,7 @@ impl Service {
cached.subscriptions = subscriptions;
}
pub fn update_sync_known_rooms(
pub async fn update_sync_known_rooms(
&self,
user_id: OwnedUserId,
device_id: OwnedDeviceId,