Port from Rocket to axum

This commit is contained in:
Jonas Platte 2022-01-20 11:51:31 +01:00
parent 8709c3ae7b
commit 1f7b3fa4ac
No known key found for this signature in database
GPG key ID: 7D261D771D915378
52 changed files with 1064 additions and 1885 deletions

View file

@ -13,16 +13,12 @@ pub mod transaction_ids;
pub mod uiaa;
pub mod users;
use self::admin::create_admin_room;
use crate::{utils, Config, Error, Result};
use abstraction::DatabaseEngine;
use directories::ProjectDirs;
use futures_util::{stream::FuturesUnordered, StreamExt};
use lru_cache::LruCache;
use rocket::{
futures::{channel::mpsc, stream::FuturesUnordered, StreamExt},
outcome::{try_outcome, IntoOutcome},
request::{FromRequest, Request},
Shutdown, State,
};
use ruma::{DeviceId, EventId, RoomId, UserId};
use std::{
collections::{BTreeMap, HashMap, HashSet},
@ -33,11 +29,9 @@ use std::{
path::Path,
sync::{Arc, Mutex, RwLock},
};
use tokio::sync::{OwnedRwLockReadGuard, RwLock as TokioRwLock, Semaphore};
use tokio::sync::{mpsc, OwnedRwLockReadGuard, RwLock as TokioRwLock, Semaphore};
use tracing::{debug, error, info, warn};
use self::admin::create_admin_room;
pub struct Database {
_db: Arc<dyn DatabaseEngine>,
pub globals: globals::Globals,
@ -151,8 +145,8 @@ impl Database {
eprintln!("ERROR: Max request size is less than 1KB. Please increase it.");
}
let (admin_sender, admin_receiver) = mpsc::unbounded();
let (sending_sender, sending_receiver) = mpsc::unbounded();
let (admin_sender, admin_receiver) = mpsc::unbounded_channel();
let (sending_sender, sending_receiver) = mpsc::unbounded_channel();
let db = Arc::new(TokioRwLock::from(Self {
_db: builder.clone(),
@ -764,14 +758,9 @@ impl Database {
}
#[cfg(feature = "conduit_bin")]
pub async fn start_on_shutdown_tasks(db: Arc<TokioRwLock<Self>>, shutdown: Shutdown) {
tokio::spawn(async move {
shutdown.await;
info!(target: "shutdown-sync", "Received shutdown notification, notifying sync helpers...");
db.read().await.globals.rotate.fire();
});
pub async fn on_shutdown(db: Arc<TokioRwLock<Self>>) {
info!(target: "shutdown-sync", "Received shutdown notification, notifying sync helpers...");
db.read().await.globals.rotate.fire();
}
pub async fn watch(&self, user_id: &UserId, device_id: &DeviceId) {
@ -948,14 +937,23 @@ impl Deref for DatabaseGuard {
}
}
#[rocket::async_trait]
impl<'r> FromRequest<'r> for DatabaseGuard {
type Error = ();
#[cfg(feature = "conduit_bin")]
#[axum::async_trait]
impl<B> axum::extract::FromRequest<B> for DatabaseGuard
where
B: Send,
{
type Rejection = axum::extract::rejection::ExtensionRejection;
async fn from_request(req: &'r Request<'_>) -> rocket::request::Outcome<Self, ()> {
let db = try_outcome!(req.guard::<&State<Arc<TokioRwLock<Database>>>>().await);
async fn from_request(
req: &mut axum::extract::RequestParts<B>,
) -> Result<Self, Self::Rejection> {
use axum::extract::Extension;
Ok(DatabaseGuard(Arc::clone(db).read_owned().await)).or_forward(())
let Extension(db): Extension<Arc<TokioRwLock<Database>>> =
Extension::from_request(req).await?;
Ok(DatabaseGuard(db.read_owned().await))
}
}