mod data; use std::collections::HashSet; pub use data::Data; use ruma::{RoomId, events::{room::{member::MembershipState, create::RoomCreateEventContent}, AnyStrippedStateEvent, StateEventType}, UserId, EventId, serde::Raw, RoomVersionId}; use serde::Deserialize; use tracing::warn; use crate::{service::*, SERVICE, PduEvent, Error, utils::calculate_hash}; pub struct Service { db: D, } impl Service<_> { /// Set the room to the given statehash and update caches. pub fn force_state( &self, room_id: &RoomId, shortstatehash: u64, statediffnew: HashSet, statediffremoved: HashSet, ) -> Result<()> { for event_id in statediffnew.into_iter().filter_map(|new| { SERVICE.rooms.state_compressor.parse_compressed_state_event(new) .ok() .map(|(_, id)| id) }) { let pdu = match SERVICE.rooms.timeline.get_pdu_json(&event_id)? { Some(pdu) => pdu, None => continue, }; if pdu.get("type").and_then(|val| val.as_str()) != Some("m.room.member") { continue; } let pdu: PduEvent = match serde_json::from_str( &serde_json::to_string(&pdu).expect("CanonicalJsonObj can be serialized to JSON"), ) { Ok(pdu) => pdu, Err(_) => continue, }; #[derive(Deserialize)] struct ExtractMembership { membership: MembershipState, } let membership = match serde_json::from_str::(pdu.content.get()) { Ok(e) => e.membership, Err(_) => continue, }; let state_key = match pdu.state_key { Some(k) => k, None => continue, }; let user_id = match UserId::parse(state_key) { Ok(id) => id, Err(_) => continue, }; SERVICE.room.state_cache.update_membership(room_id, &user_id, membership, &pdu.sender, None, false)?; } SERVICE.room.state_cache.update_joined_count(room_id)?; self.db.set_room_state(room_id, shortstatehash); Ok(()) } /// Generates a new StateHash and associates it with the incoming event. /// /// This adds all current state events (not including the incoming event) /// to `stateid_pduid` and adds the incoming event to `eventid_statehash`. #[tracing::instrument(skip(self, state_ids_compressed))] pub fn set_event_state( &self, event_id: &EventId, room_id: &RoomId, state_ids_compressed: HashSet, ) -> Result<()> { let shorteventid = SERVICE.short.get_or_create_shorteventid(event_id)?; let previous_shortstatehash = self.db.get_room_shortstatehash(room_id)?; let state_hash = calculate_hash( &state_ids_compressed .iter() .map(|s| &s[..]) .collect::>(), ); let (shortstatehash, already_existed) = SERVICE.short.get_or_create_shortstatehash(&state_hash)?; if !already_existed { let states_parents = previous_shortstatehash .map_or_else(|| Ok(Vec::new()), |p| SERVICE.room.state_compressor.load_shortstatehash_info(p))?; let (statediffnew, statediffremoved) = if let Some(parent_stateinfo) = states_parents.last() { let statediffnew: HashSet<_> = state_ids_compressed .difference(&parent_stateinfo.1) .copied() .collect(); let statediffremoved: HashSet<_> = parent_stateinfo .1 .difference(&state_ids_compressed) .copied() .collect(); (statediffnew, statediffremoved) } else { (state_ids_compressed, HashSet::new()) }; SERVICE.room.state_compressor.save_state_from_diff( shortstatehash, statediffnew, statediffremoved, 1_000_000, // high number because no state will be based on this one states_parents, )?; } self.db.set_event_state(&shorteventid.to_be_bytes(), &shortstatehash.to_be_bytes())?; Ok(()) } /// Generates a new StateHash and associates it with the incoming event. /// /// This adds all current state events (not including the incoming event) /// to `stateid_pduid` and adds the incoming event to `eventid_statehash`. #[tracing::instrument(skip(self, new_pdu))] pub fn append_to_state( &self, new_pdu: &PduEvent, ) -> Result { let shorteventid = self.get_or_create_shorteventid(&new_pdu.event_id)?; let previous_shortstatehash = self.get_room_shortstatehash(&new_pdu.room_id)?; if let Some(p) = previous_shortstatehash { self.shorteventid_shortstatehash .insert(&shorteventid.to_be_bytes(), &p.to_be_bytes())?; } if let Some(state_key) = &new_pdu.state_key { let states_parents = previous_shortstatehash .map_or_else(|| Ok(Vec::new()), |p| self.load_shortstatehash_info(p))?; let shortstatekey = self.get_or_create_shortstatekey( &new_pdu.kind.to_string().into(), state_key, )?; let new = self.compress_state_event(shortstatekey, &new_pdu.event_id)?; let replaces = states_parents .last() .map(|info| { info.1 .iter() .find(|bytes| bytes.starts_with(&shortstatekey.to_be_bytes())) }) .unwrap_or_default(); if Some(&new) == replaces { return Ok(previous_shortstatehash.expect("must exist")); } // TODO: statehash with deterministic inputs let shortstatehash = SERVICE.globals.next_count()?; let mut statediffnew = HashSet::new(); statediffnew.insert(new); let mut statediffremoved = HashSet::new(); if let Some(replaces) = replaces { statediffremoved.insert(*replaces); } self.save_state_from_diff( shortstatehash, statediffnew, statediffremoved, 2, states_parents, )?; Ok(shortstatehash) } else { Ok(previous_shortstatehash.expect("first event in room must be a state event")) } } #[tracing::instrument(skip(self, invite_event))] pub fn calculate_invite_state( &self, invite_event: &PduEvent, ) -> Result>> { let mut state = Vec::new(); // Add recommended events if let Some(e) = self.room_state_get(&invite_event.room_id, &StateEventType::RoomCreate, "")? { state.push(e.to_stripped_state_event()); } if let Some(e) = self.room_state_get(&invite_event.room_id, &StateEventType::RoomJoinRules, "")? { state.push(e.to_stripped_state_event()); } if let Some(e) = self.room_state_get( &invite_event.room_id, &StateEventType::RoomCanonicalAlias, "", )? { state.push(e.to_stripped_state_event()); } if let Some(e) = self.room_state_get(&invite_event.room_id, &StateEventType::RoomAvatar, "")? { state.push(e.to_stripped_state_event()); } if let Some(e) = self.room_state_get(&invite_event.room_id, &StateEventType::RoomName, "")? { state.push(e.to_stripped_state_event()); } if let Some(e) = self.room_state_get( &invite_event.room_id, &StateEventType::RoomMember, invite_event.sender.as_str(), )? { state.push(e.to_stripped_state_event()); } state.push(invite_event.to_stripped_state_event()); Ok(state) } #[tracing::instrument(skip(self))] pub fn set_room_state(&self, room_id: &RoomId, shortstatehash: u64) -> Result<()> { self.roomid_shortstatehash .insert(room_id.as_bytes(), &shortstatehash.to_be_bytes())?; Ok(()) } /// Returns the room's version. #[tracing::instrument(skip(self))] pub fn get_room_version(&self, room_id: &RoomId) -> Result { let create_event = self.room_state_get(room_id, &StateEventType::RoomCreate, "")?; let create_event_content: Option = create_event .as_ref() .map(|create_event| { serde_json::from_str(create_event.content.get()).map_err(|e| { warn!("Invalid create event: {}", e); Error::bad_database("Invalid create event in db.") }) }) .transpose()?; let room_version = create_event_content .map(|create_event| create_event.room_version) .ok_or_else(|| Error::BadDatabase("Invalid room version"))?; Ok(room_version) } }