improvement: faster incoming transaction handling
This commit is contained in:
parent
bf7e019a68
commit
46d8a46e1f
12 changed files with 365 additions and 280 deletions
|
@ -110,6 +110,7 @@ pub struct Rooms {
|
|||
impl Rooms {
|
||||
/// Builds a StateMap by iterating over all keys that start
|
||||
/// with state_hash, this gives the full state for the given state_hash.
|
||||
#[tracing::instrument(skip(self))]
|
||||
pub fn state_full_ids(&self, shortstatehash: u64) -> Result<BTreeSet<EventId>> {
|
||||
let full_state = self
|
||||
.load_shortstatehash_info(shortstatehash)?
|
||||
|
@ -122,6 +123,7 @@ impl Rooms {
|
|||
.collect()
|
||||
}
|
||||
|
||||
#[tracing::instrument(skip(self))]
|
||||
pub fn state_full(
|
||||
&self,
|
||||
shortstatehash: u64,
|
||||
|
@ -220,6 +222,7 @@ impl Rooms {
|
|||
}
|
||||
|
||||
/// This fetches auth events from the current state.
|
||||
#[tracing::instrument(skip(self))]
|
||||
pub fn get_auth_events(
|
||||
&self,
|
||||
room_id: &RoomId,
|
||||
|
@ -261,6 +264,7 @@ impl Rooms {
|
|||
}
|
||||
|
||||
/// Checks if a room exists.
|
||||
#[tracing::instrument(skip(self))]
|
||||
pub fn exists(&self, room_id: &RoomId) -> Result<bool> {
|
||||
let prefix = match self.get_shortroomid(room_id)? {
|
||||
Some(b) => b.to_be_bytes().to_vec(),
|
||||
|
@ -277,6 +281,7 @@ impl Rooms {
|
|||
}
|
||||
|
||||
/// Checks if a room exists.
|
||||
#[tracing::instrument(skip(self))]
|
||||
pub fn first_pdu_in_room(&self, room_id: &RoomId) -> Result<Option<Arc<PduEvent>>> {
|
||||
let prefix = self
|
||||
.get_shortroomid(room_id)?
|
||||
|
@ -300,6 +305,7 @@ impl Rooms {
|
|||
/// Force the creation of a new StateHash and insert it into the db.
|
||||
///
|
||||
/// Whatever `state` is supplied to `force_state` becomes the new current room state snapshot.
|
||||
#[tracing::instrument(skip(self, new_state, db))]
|
||||
pub fn force_state(
|
||||
&self,
|
||||
room_id: &RoomId,
|
||||
|
@ -412,6 +418,7 @@ impl Rooms {
|
|||
}
|
||||
|
||||
/// Returns a stack with info on shortstatehash, full state, added diff and removed diff for the selected shortstatehash and each parent layer.
|
||||
#[tracing::instrument(skip(self))]
|
||||
pub fn load_shortstatehash_info(
|
||||
&self,
|
||||
shortstatehash: u64,
|
||||
|
@ -480,6 +487,7 @@ impl Rooms {
|
|||
}
|
||||
}
|
||||
|
||||
#[tracing::instrument(skip(self, globals))]
|
||||
pub fn compress_state_event(
|
||||
&self,
|
||||
shortstatekey: u64,
|
||||
|
@ -495,6 +503,7 @@ impl Rooms {
|
|||
Ok(v.try_into().expect("we checked the size above"))
|
||||
}
|
||||
|
||||
#[tracing::instrument(skip(self, compressed_event))]
|
||||
pub fn parse_compressed_state_event(
|
||||
&self,
|
||||
compressed_event: CompressedStateEvent,
|
||||
|
@ -518,6 +527,13 @@ impl Rooms {
|
|||
/// * `statediffremoved` - Removed from base. Each vec is shortstatekey+shorteventid
|
||||
/// * `diff_to_sibling` - Approximately how much the diff grows each time for this layer
|
||||
/// * `parent_states` - A stack with info on shortstatehash, full state, added diff and removed diff for each parent layer
|
||||
#[tracing::instrument(skip(
|
||||
self,
|
||||
statediffnew,
|
||||
statediffremoved,
|
||||
diff_to_sibling,
|
||||
parent_states
|
||||
))]
|
||||
pub fn save_state_from_diff(
|
||||
&self,
|
||||
shortstatehash: u64,
|
||||
|
@ -642,6 +658,7 @@ impl Rooms {
|
|||
}
|
||||
|
||||
/// Returns (shortstatehash, already_existed)
|
||||
#[tracing::instrument(skip(self, globals))]
|
||||
fn get_or_create_shortstatehash(
|
||||
&self,
|
||||
state_hash: &StateHashId,
|
||||
|
@ -662,6 +679,7 @@ impl Rooms {
|
|||
})
|
||||
}
|
||||
|
||||
#[tracing::instrument(skip(self, globals))]
|
||||
pub fn get_or_create_shorteventid(
|
||||
&self,
|
||||
event_id: &EventId,
|
||||
|
@ -692,6 +710,7 @@ impl Rooms {
|
|||
Ok(short)
|
||||
}
|
||||
|
||||
#[tracing::instrument(skip(self))]
|
||||
pub fn get_shortroomid(&self, room_id: &RoomId) -> Result<Option<u64>> {
|
||||
self.roomid_shortroomid
|
||||
.get(&room_id.as_bytes())?
|
||||
|
@ -702,6 +721,7 @@ impl Rooms {
|
|||
.transpose()
|
||||
}
|
||||
|
||||
#[tracing::instrument(skip(self))]
|
||||
pub fn get_shortstatekey(
|
||||
&self,
|
||||
event_type: &EventType,
|
||||
|
@ -739,6 +759,7 @@ impl Rooms {
|
|||
Ok(short)
|
||||
}
|
||||
|
||||
#[tracing::instrument(skip(self, globals))]
|
||||
pub fn get_or_create_shortroomid(
|
||||
&self,
|
||||
room_id: &RoomId,
|
||||
|
@ -756,6 +777,7 @@ impl Rooms {
|
|||
})
|
||||
}
|
||||
|
||||
#[tracing::instrument(skip(self, globals))]
|
||||
pub fn get_or_create_shortstatekey(
|
||||
&self,
|
||||
event_type: &EventType,
|
||||
|
@ -794,6 +816,7 @@ impl Rooms {
|
|||
Ok(short)
|
||||
}
|
||||
|
||||
#[tracing::instrument(skip(self))]
|
||||
pub fn get_eventid_from_short(&self, shorteventid: u64) -> Result<EventId> {
|
||||
if let Some(id) = self
|
||||
.shorteventid_cache
|
||||
|
@ -876,12 +899,14 @@ impl Rooms {
|
|||
}
|
||||
|
||||
/// Returns the `count` of this pdu's id.
|
||||
#[tracing::instrument(skip(self))]
|
||||
pub fn get_pdu_count(&self, event_id: &EventId) -> Result<Option<u64>> {
|
||||
self.eventid_pduid
|
||||
.get(event_id.as_bytes())?
|
||||
.map_or(Ok(None), |pdu_id| self.pdu_count(&pdu_id).map(Some))
|
||||
}
|
||||
|
||||
#[tracing::instrument(skip(self))]
|
||||
pub fn latest_pdu_count(&self, room_id: &RoomId) -> Result<u64> {
|
||||
let prefix = self
|
||||
.get_shortroomid(room_id)?
|
||||
|
@ -902,6 +927,7 @@ impl Rooms {
|
|||
}
|
||||
|
||||
/// Returns the json of a pdu.
|
||||
#[tracing::instrument(skip(self))]
|
||||
pub fn get_pdu_json(&self, event_id: &EventId) -> Result<Option<CanonicalJsonObject>> {
|
||||
self.eventid_pduid
|
||||
.get(event_id.as_bytes())?
|
||||
|
@ -920,6 +946,7 @@ impl Rooms {
|
|||
}
|
||||
|
||||
/// Returns the json of a pdu.
|
||||
#[tracing::instrument(skip(self))]
|
||||
pub fn get_outlier_pdu_json(&self, event_id: &EventId) -> Result<Option<CanonicalJsonObject>> {
|
||||
self.eventid_outlierpdu
|
||||
.get(event_id.as_bytes())?
|
||||
|
@ -930,6 +957,7 @@ impl Rooms {
|
|||
}
|
||||
|
||||
/// Returns the json of a pdu.
|
||||
#[tracing::instrument(skip(self))]
|
||||
pub fn get_non_outlier_pdu_json(
|
||||
&self,
|
||||
event_id: &EventId,
|
||||
|
@ -951,6 +979,7 @@ impl Rooms {
|
|||
}
|
||||
|
||||
/// Returns the pdu's id.
|
||||
#[tracing::instrument(skip(self))]
|
||||
pub fn get_pdu_id(&self, event_id: &EventId) -> Result<Option<Vec<u8>>> {
|
||||
self.eventid_pduid
|
||||
.get(event_id.as_bytes())?
|
||||
|
@ -960,6 +989,7 @@ impl Rooms {
|
|||
/// Returns the pdu.
|
||||
///
|
||||
/// Checks the `eventid_outlierpdu` Tree if not found in the timeline.
|
||||
#[tracing::instrument(skip(self))]
|
||||
pub fn get_non_outlier_pdu(&self, event_id: &EventId) -> Result<Option<PduEvent>> {
|
||||
self.eventid_pduid
|
||||
.get(event_id.as_bytes())?
|
||||
|
@ -980,6 +1010,7 @@ impl Rooms {
|
|||
/// Returns the pdu.
|
||||
///
|
||||
/// Checks the `eventid_outlierpdu` Tree if not found in the timeline.
|
||||
#[tracing::instrument(skip(self))]
|
||||
pub fn get_pdu(&self, event_id: &EventId) -> Result<Option<Arc<PduEvent>>> {
|
||||
if let Some(p) = self.pdu_cache.lock().unwrap().get_mut(&event_id) {
|
||||
return Ok(Some(Arc::clone(p)));
|
||||
|
@ -1019,6 +1050,7 @@ impl Rooms {
|
|||
/// Returns the pdu.
|
||||
///
|
||||
/// This does __NOT__ check the outliers `Tree`.
|
||||
#[tracing::instrument(skip(self))]
|
||||
pub fn get_pdu_from_id(&self, pdu_id: &[u8]) -> Result<Option<PduEvent>> {
|
||||
self.pduid_pdu.get(pdu_id)?.map_or(Ok(None), |pdu| {
|
||||
Ok(Some(
|
||||
|
@ -1029,6 +1061,7 @@ impl Rooms {
|
|||
}
|
||||
|
||||
/// Returns the pdu as a `BTreeMap<String, CanonicalJsonValue>`.
|
||||
#[tracing::instrument(skip(self))]
|
||||
pub fn get_pdu_json_from_id(&self, pdu_id: &[u8]) -> Result<Option<CanonicalJsonObject>> {
|
||||
self.pduid_pdu.get(pdu_id)?.map_or(Ok(None), |pdu| {
|
||||
Ok(Some(
|
||||
|
@ -1039,6 +1072,7 @@ impl Rooms {
|
|||
}
|
||||
|
||||
/// Removes a pdu and creates a new one with the same id.
|
||||
#[tracing::instrument(skip(self))]
|
||||
fn replace_pdu(&self, pdu_id: &[u8], pdu: &PduEvent) -> Result<()> {
|
||||
if self.pduid_pdu.get(&pdu_id)?.is_some() {
|
||||
self.pduid_pdu.insert(
|
||||
|
@ -2298,6 +2332,7 @@ impl Rooms {
|
|||
Ok(())
|
||||
}
|
||||
|
||||
#[tracing::instrument(skip(self))]
|
||||
pub fn update_joined_count(&self, room_id: &RoomId) -> Result<()> {
|
||||
let mut joinedcount = 0_u64;
|
||||
let mut joined_servers = HashSet::new();
|
||||
|
@ -2347,6 +2382,7 @@ impl Rooms {
|
|||
Ok(())
|
||||
}
|
||||
|
||||
#[tracing::instrument(skip(self, db))]
|
||||
pub async fn leave_room(
|
||||
&self,
|
||||
user_id: &UserId,
|
||||
|
@ -2419,6 +2455,7 @@ impl Rooms {
|
|||
Ok(())
|
||||
}
|
||||
|
||||
#[tracing::instrument(skip(self, db))]
|
||||
async fn remote_leave_room(
|
||||
&self,
|
||||
user_id: &UserId,
|
||||
|
@ -2650,6 +2687,7 @@ impl Rooms {
|
|||
})
|
||||
}
|
||||
|
||||
#[tracing::instrument(skip(self))]
|
||||
pub fn search_pdus<'a>(
|
||||
&'a self,
|
||||
room_id: &RoomId,
|
||||
|
@ -2809,6 +2847,7 @@ impl Rooms {
|
|||
})
|
||||
}
|
||||
|
||||
#[tracing::instrument(skip(self))]
|
||||
pub fn room_joined_count(&self, room_id: &RoomId) -> Result<Option<u64>> {
|
||||
Ok(self
|
||||
.roomid_joinedcount
|
||||
|
|
|
@ -4,11 +4,14 @@ use crate::{client_server::SESSION_ID_LENGTH, utils, Error, Result};
|
|||
use ruma::{
|
||||
api::client::{
|
||||
error::ErrorKind,
|
||||
r0::uiaa::{IncomingAuthData, UiaaInfo},
|
||||
r0::uiaa::{
|
||||
IncomingAuthData, IncomingPassword, IncomingUserIdentifier::MatrixId, UiaaInfo,
|
||||
},
|
||||
},
|
||||
signatures::CanonicalJsonValue,
|
||||
DeviceId, UserId,
|
||||
};
|
||||
use tracing::error;
|
||||
|
||||
use super::abstraction::Tree;
|
||||
|
||||
|
@ -49,126 +52,91 @@ impl Uiaa {
|
|||
users: &super::users::Users,
|
||||
globals: &super::globals::Globals,
|
||||
) -> Result<(bool, UiaaInfo)> {
|
||||
if let IncomingAuthData::DirectRequest {
|
||||
kind,
|
||||
session,
|
||||
auth_parameters,
|
||||
} = &auth
|
||||
{
|
||||
let mut uiaainfo = session
|
||||
.as_ref()
|
||||
.map(|session| self.get_uiaa_session(&user_id, &device_id, session))
|
||||
.unwrap_or_else(|| Ok(uiaainfo.clone()))?;
|
||||
let mut uiaainfo = auth
|
||||
.session()
|
||||
.map(|session| self.get_uiaa_session(&user_id, &device_id, session))
|
||||
.unwrap_or_else(|| Ok(uiaainfo.clone()))?;
|
||||
|
||||
if uiaainfo.session.is_none() {
|
||||
uiaainfo.session = Some(utils::random_string(SESSION_ID_LENGTH));
|
||||
}
|
||||
if uiaainfo.session.is_none() {
|
||||
uiaainfo.session = Some(utils::random_string(SESSION_ID_LENGTH));
|
||||
}
|
||||
|
||||
match auth {
|
||||
// Find out what the user completed
|
||||
match &**kind {
|
||||
"m.login.password" => {
|
||||
let identifier = auth_parameters.get("identifier").ok_or(Error::BadRequest(
|
||||
ErrorKind::MissingParam,
|
||||
"m.login.password needs identifier.",
|
||||
))?;
|
||||
|
||||
let identifier_type = identifier.get("type").ok_or(Error::BadRequest(
|
||||
ErrorKind::MissingParam,
|
||||
"Identifier needs a type.",
|
||||
))?;
|
||||
|
||||
if identifier_type != "m.id.user" {
|
||||
IncomingAuthData::Password(IncomingPassword {
|
||||
identifier,
|
||||
password,
|
||||
..
|
||||
}) => {
|
||||
let username = match identifier {
|
||||
MatrixId(username) => username,
|
||||
_ => {
|
||||
return Err(Error::BadRequest(
|
||||
ErrorKind::Unrecognized,
|
||||
"Identifier type not recognized.",
|
||||
));
|
||||
))
|
||||
}
|
||||
};
|
||||
|
||||
let username = identifier
|
||||
.get("user")
|
||||
.ok_or(Error::BadRequest(
|
||||
ErrorKind::MissingParam,
|
||||
"Identifier needs user field.",
|
||||
))?
|
||||
.as_str()
|
||||
.ok_or(Error::BadRequest(
|
||||
ErrorKind::BadJson,
|
||||
"User is not a string.",
|
||||
))?;
|
||||
|
||||
let user_id = UserId::parse_with_server_name(username, globals.server_name())
|
||||
let user_id =
|
||||
UserId::parse_with_server_name(username.clone(), globals.server_name())
|
||||
.map_err(|_| {
|
||||
Error::BadRequest(ErrorKind::InvalidParam, "User ID is invalid.")
|
||||
})?;
|
||||
Error::BadRequest(ErrorKind::InvalidParam, "User ID is invalid.")
|
||||
})?;
|
||||
|
||||
let password = auth_parameters
|
||||
.get("password")
|
||||
.ok_or(Error::BadRequest(
|
||||
ErrorKind::MissingParam,
|
||||
"Password is missing.",
|
||||
))?
|
||||
.as_str()
|
||||
.ok_or(Error::BadRequest(
|
||||
ErrorKind::BadJson,
|
||||
"Password is not a string.",
|
||||
))?;
|
||||
// Check if password is correct
|
||||
if let Some(hash) = users.password_hash(&user_id)? {
|
||||
let hash_matches =
|
||||
argon2::verify_encoded(&hash, password.as_bytes()).unwrap_or(false);
|
||||
|
||||
// Check if password is correct
|
||||
if let Some(hash) = users.password_hash(&user_id)? {
|
||||
let hash_matches =
|
||||
argon2::verify_encoded(&hash, password.as_bytes()).unwrap_or(false);
|
||||
|
||||
if !hash_matches {
|
||||
uiaainfo.auth_error = Some(ruma::api::client::error::ErrorBody {
|
||||
kind: ErrorKind::Forbidden,
|
||||
message: "Invalid username or password.".to_owned(),
|
||||
});
|
||||
return Ok((false, uiaainfo));
|
||||
}
|
||||
}
|
||||
|
||||
// Password was correct! Let's add it to `completed`
|
||||
uiaainfo.completed.push("m.login.password".to_owned());
|
||||
}
|
||||
"m.login.dummy" => {
|
||||
uiaainfo.completed.push("m.login.dummy".to_owned());
|
||||
}
|
||||
k => panic!("type not supported: {}", k),
|
||||
}
|
||||
|
||||
// Check if a flow now succeeds
|
||||
let mut completed = false;
|
||||
'flows: for flow in &mut uiaainfo.flows {
|
||||
for stage in &flow.stages {
|
||||
if !uiaainfo.completed.contains(stage) {
|
||||
continue 'flows;
|
||||
if !hash_matches {
|
||||
uiaainfo.auth_error = Some(ruma::api::client::error::ErrorBody {
|
||||
kind: ErrorKind::Forbidden,
|
||||
message: "Invalid username or password.".to_owned(),
|
||||
});
|
||||
return Ok((false, uiaainfo));
|
||||
}
|
||||
}
|
||||
// We didn't break, so this flow succeeded!
|
||||
completed = true;
|
||||
}
|
||||
|
||||
if !completed {
|
||||
self.update_uiaa_session(
|
||||
user_id,
|
||||
device_id,
|
||||
uiaainfo.session.as_ref().expect("session is always set"),
|
||||
Some(&uiaainfo),
|
||||
)?;
|
||||
return Ok((false, uiaainfo));
|
||||
// Password was correct! Let's add it to `completed`
|
||||
uiaainfo.completed.push("m.login.password".to_owned());
|
||||
}
|
||||
IncomingAuthData::Dummy(_) => {
|
||||
uiaainfo.completed.push("m.login.dummy".to_owned());
|
||||
}
|
||||
k => error!("type not supported: {:?}", k),
|
||||
}
|
||||
|
||||
// UIAA was successful! Remove this session and return true
|
||||
// Check if a flow now succeeds
|
||||
let mut completed = false;
|
||||
'flows: for flow in &mut uiaainfo.flows {
|
||||
for stage in &flow.stages {
|
||||
if !uiaainfo.completed.contains(stage) {
|
||||
continue 'flows;
|
||||
}
|
||||
}
|
||||
// We didn't break, so this flow succeeded!
|
||||
completed = true;
|
||||
}
|
||||
|
||||
if !completed {
|
||||
self.update_uiaa_session(
|
||||
user_id,
|
||||
device_id,
|
||||
uiaainfo.session.as_ref().expect("session is always set"),
|
||||
None,
|
||||
Some(&uiaainfo),
|
||||
)?;
|
||||
Ok((true, uiaainfo))
|
||||
} else {
|
||||
panic!("FallbackAcknowledgement is not supported yet");
|
||||
return Ok((false, uiaainfo));
|
||||
}
|
||||
|
||||
// UIAA was successful! Remove this session and return true
|
||||
self.update_uiaa_session(
|
||||
user_id,
|
||||
device_id,
|
||||
uiaainfo.session.as_ref().expect("session is always set"),
|
||||
None,
|
||||
)?;
|
||||
Ok((true, uiaainfo))
|
||||
}
|
||||
|
||||
fn set_uiaa_request(
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue