improvement: upgrade dependencies, fix timeline reload bug
This commit is contained in:
parent
164b1633d8
commit
45086b54b3
8 changed files with 202 additions and 148 deletions
|
@ -13,12 +13,39 @@ use crate::{Error, Result};
|
|||
use directories::ProjectDirs;
|
||||
use futures::StreamExt;
|
||||
use log::info;
|
||||
use rocket::{
|
||||
futures::{self, channel::mpsc},
|
||||
Config,
|
||||
};
|
||||
use ruma::{DeviceId, UserId};
|
||||
use std::{convert::TryFrom, fs::remove_dir_all};
|
||||
use rocket::futures::{self, channel::mpsc};
|
||||
use ruma::{DeviceId, ServerName, UserId};
|
||||
use serde::Deserialize;
|
||||
use std::{convert::TryInto, fs::remove_dir_all};
|
||||
|
||||
#[derive(Clone, Deserialize)]
|
||||
pub struct Config {
|
||||
#[serde(default = "default_server_name")]
|
||||
server_name: Box<ServerName>,
|
||||
database_path: Option<String>,
|
||||
#[serde(default = "default_cache_capacity")]
|
||||
cache_capacity: u64,
|
||||
#[serde(default = "default_max_request_size")]
|
||||
max_request_size: u32,
|
||||
#[serde(default)]
|
||||
registration_disabled: bool,
|
||||
#[serde(default)]
|
||||
encryption_disabled: bool,
|
||||
#[serde(default)]
|
||||
federation_enabled: bool,
|
||||
}
|
||||
|
||||
fn default_server_name() -> Box<ServerName> {
|
||||
"localhost".try_into().expect("")
|
||||
}
|
||||
|
||||
fn default_cache_capacity() -> u64 {
|
||||
1024 * 1024 * 1024
|
||||
}
|
||||
|
||||
fn default_max_request_size() -> u32 {
|
||||
20 * 1024 * 1024 // Default to 20 MB
|
||||
}
|
||||
|
||||
#[derive(Clone)]
|
||||
pub struct Database {
|
||||
|
@ -49,19 +76,18 @@ impl Database {
|
|||
}
|
||||
|
||||
/// Load an existing database or create a new one.
|
||||
pub fn load_or_create(config: &Config) -> Result<Self> {
|
||||
let server_name = config.get_str("server_name").unwrap_or("localhost");
|
||||
|
||||
pub fn load_or_create(config: Config) -> Result<Self> {
|
||||
let path = config
|
||||
.get_str("database_path")
|
||||
.map(|x| Ok::<_, Error>(x.to_owned()))
|
||||
.unwrap_or_else(|_| {
|
||||
.database_path
|
||||
.clone()
|
||||
.map(Ok::<_, Error>)
|
||||
.unwrap_or_else(|| {
|
||||
let path = ProjectDirs::from("xyz", "koesters", "conduit")
|
||||
.ok_or_else(|| {
|
||||
Error::bad_config("The OS didn't return a valid home directory path.")
|
||||
})?
|
||||
.data_dir()
|
||||
.join(server_name);
|
||||
.join(config.server_name.as_str());
|
||||
|
||||
Ok(path
|
||||
.to_str()
|
||||
|
@ -71,15 +97,8 @@ impl Database {
|
|||
|
||||
let db = sled::Config::default()
|
||||
.path(&path)
|
||||
.cache_capacity(
|
||||
u64::try_from(
|
||||
config
|
||||
.get_int("cache_capacity")
|
||||
.unwrap_or(1024 * 1024 * 1024),
|
||||
)
|
||||
.map_err(|_| Error::bad_config("Cache capacity needs to be a u64."))?,
|
||||
)
|
||||
.print_profile_on_drop(false)
|
||||
.cache_capacity(config.cache_capacity)
|
||||
.print_profile_on_drop(true)
|
||||
.open()?;
|
||||
|
||||
info!("Opened sled database at {}", path);
|
||||
|
|
|
@ -49,8 +49,6 @@ impl Admin {
|
|||
Some(event) = receiver.next() => {
|
||||
match event {
|
||||
AdminCommand::SendTextMessage(message) => {
|
||||
println!("{:?}", message);
|
||||
|
||||
if let Some(conduit_room) = &conduit_room {
|
||||
db.rooms.build_and_append_pdu(
|
||||
PduBuilder {
|
||||
|
|
|
@ -1,7 +1,7 @@
|
|||
use crate::{utils, Error, Result};
|
||||
use crate::{database::Config, utils, Error, Result};
|
||||
use log::error;
|
||||
use ruma::ServerName;
|
||||
use std::{convert::TryInto, sync::Arc};
|
||||
use std::sync::Arc;
|
||||
|
||||
pub const COUNTER: &str = "c";
|
||||
|
||||
|
@ -10,15 +10,11 @@ pub struct Globals {
|
|||
pub(super) globals: sled::Tree,
|
||||
keypair: Arc<ruma::signatures::Ed25519KeyPair>,
|
||||
reqwest_client: reqwest::Client,
|
||||
server_name: Box<ServerName>,
|
||||
max_request_size: u32,
|
||||
registration_disabled: bool,
|
||||
encryption_disabled: bool,
|
||||
federation_enabled: bool,
|
||||
config: Config,
|
||||
}
|
||||
|
||||
impl Globals {
|
||||
pub fn load(globals: sled::Tree, config: &rocket::Config) -> Result<Self> {
|
||||
pub fn load(globals: sled::Tree, config: Config) -> Result<Self> {
|
||||
let bytes = &*globals
|
||||
.update_and_fetch("keypair", utils::generate_keypair)?
|
||||
.expect("utils::generate_keypair always returns Some");
|
||||
|
@ -57,20 +53,7 @@ impl Globals {
|
|||
globals,
|
||||
keypair: Arc::new(keypair),
|
||||
reqwest_client: reqwest::Client::new(),
|
||||
server_name: config
|
||||
.get_str("server_name")
|
||||
.unwrap_or("localhost")
|
||||
.to_string()
|
||||
.try_into()
|
||||
.map_err(|_| Error::bad_config("Invalid server_name."))?,
|
||||
max_request_size: config
|
||||
.get_int("max_request_size")
|
||||
.unwrap_or(20 * 1024 * 1024) // Default to 20 MB
|
||||
.try_into()
|
||||
.map_err(|_| Error::bad_config("Invalid max_request_size."))?,
|
||||
registration_disabled: config.get_bool("registration_disabled").unwrap_or(false),
|
||||
encryption_disabled: config.get_bool("encryption_disabled").unwrap_or(false),
|
||||
federation_enabled: config.get_bool("federation_enabled").unwrap_or(false),
|
||||
config,
|
||||
})
|
||||
}
|
||||
|
||||
|
@ -102,22 +85,22 @@ impl Globals {
|
|||
}
|
||||
|
||||
pub fn server_name(&self) -> &ServerName {
|
||||
self.server_name.as_ref()
|
||||
self.config.server_name.as_ref()
|
||||
}
|
||||
|
||||
pub fn max_request_size(&self) -> u32 {
|
||||
self.max_request_size
|
||||
self.config.max_request_size
|
||||
}
|
||||
|
||||
pub fn registration_disabled(&self) -> bool {
|
||||
self.registration_disabled
|
||||
self.config.registration_disabled
|
||||
}
|
||||
|
||||
pub fn encryption_disabled(&self) -> bool {
|
||||
self.encryption_disabled
|
||||
self.config.encryption_disabled
|
||||
}
|
||||
|
||||
pub fn federation_enabled(&self) -> bool {
|
||||
self.federation_enabled
|
||||
self.config.federation_enabled
|
||||
}
|
||||
}
|
||||
|
|
14
src/main.rs
14
src/main.rs
|
@ -21,7 +21,7 @@ use rocket::{fairing::AdHoc, routes};
|
|||
|
||||
fn setup_rocket() -> rocket::Rocket {
|
||||
// Force log level off, so we can use our own logger
|
||||
std::env::set_var("ROCKET_LOG", "off");
|
||||
std::env::set_var("ROCKET_LOG_LEVEL", "off");
|
||||
|
||||
rocket::ignite()
|
||||
.mount(
|
||||
|
@ -123,9 +123,9 @@ fn setup_rocket() -> rocket::Rocket {
|
|||
client_server::get_pushers_route,
|
||||
client_server::set_pushers_route,
|
||||
client_server::upgrade_room_route,
|
||||
server_server::get_server_version,
|
||||
server_server::get_server_keys,
|
||||
server_server::get_server_keys_deprecated,
|
||||
server_server::get_server_version_route,
|
||||
server_server::get_server_keys_route,
|
||||
server_server::get_server_keys_deprecated_route,
|
||||
server_server::get_public_rooms_route,
|
||||
server_server::get_public_rooms_filtered_route,
|
||||
server_server::send_transaction_message_route,
|
||||
|
@ -133,8 +133,10 @@ fn setup_rocket() -> rocket::Rocket {
|
|||
server_server::get_profile_information_route,
|
||||
],
|
||||
)
|
||||
.attach(AdHoc::on_attach("Config", |mut rocket| async {
|
||||
let data = Database::load_or_create(rocket.config().await).expect("valid config");
|
||||
.attach(AdHoc::on_attach("Config", |rocket| async {
|
||||
let data =
|
||||
Database::load_or_create(rocket.figment().extract().expect("config is valid"))
|
||||
.expect("config is valid");
|
||||
|
||||
data.sending.start_handler(&data.globals, &data.rooms);
|
||||
log::set_boxed_logger(Box::new(ConduitLogger {
|
||||
|
|
|
@ -15,7 +15,8 @@ use {
|
|||
log::warn,
|
||||
rocket::{
|
||||
data::{
|
||||
Data, FromDataFuture, FromTransformedData, Transform, TransformFuture, Transformed,
|
||||
ByteUnit, Data, FromDataFuture, FromTransformedData, Transform, TransformFuture,
|
||||
Transformed,
|
||||
},
|
||||
http::Status,
|
||||
outcome::Outcome::*,
|
||||
|
@ -97,7 +98,7 @@ where
|
|||
}
|
||||
|
||||
let limit = db.globals.max_request_size();
|
||||
let mut handle = data.open().take(limit.into());
|
||||
let mut handle = data.open(ByteUnit::Byte(limit.into()));
|
||||
let mut body = Vec::new();
|
||||
handle.read_to_end(&mut body).await.unwrap();
|
||||
|
||||
|
|
|
@ -193,6 +193,8 @@ where
|
|||
}
|
||||
}
|
||||
|
||||
let status = reqwest_response.status();
|
||||
|
||||
let body = reqwest_response
|
||||
.bytes()
|
||||
.await
|
||||
|
@ -201,17 +203,27 @@ where
|
|||
Vec::new().into()
|
||||
}) // TODO: handle timeout
|
||||
.into_iter()
|
||||
.collect();
|
||||
.collect::<Vec<_>>();
|
||||
|
||||
if status != 200 {
|
||||
warn!(
|
||||
"Server returned bad response {} ({}): {} {:?}",
|
||||
destination,
|
||||
url,
|
||||
status,
|
||||
utils::string_from_bytes(&body)
|
||||
);
|
||||
}
|
||||
|
||||
let response = T::IncomingResponse::try_from(
|
||||
http_response
|
||||
.body(body)
|
||||
.expect("reqwest body is valid http body"),
|
||||
);
|
||||
response.map_err(|e| {
|
||||
response.map_err(|_| {
|
||||
warn!(
|
||||
"Server returned bad response {} ({}): {:?}",
|
||||
destination, url, e
|
||||
"Server returned invalid response bytes {} ({})",
|
||||
destination, url
|
||||
);
|
||||
Error::BadServerResponse("Server returned bad response.")
|
||||
})
|
||||
|
@ -221,7 +233,9 @@ where
|
|||
}
|
||||
|
||||
#[cfg_attr(feature = "conduit_bin", get("/_matrix/federation/v1/version"))]
|
||||
pub fn get_server_version(db: State<'_, Database>) -> ConduitResult<get_server_version::Response> {
|
||||
pub fn get_server_version_route(
|
||||
db: State<'_, Database>,
|
||||
) -> ConduitResult<get_server_version::Response> {
|
||||
if !db.globals.federation_enabled() {
|
||||
return Err(Error::bad_config("Federation is disabled."));
|
||||
}
|
||||
|
@ -236,7 +250,7 @@ pub fn get_server_version(db: State<'_, Database>) -> ConduitResult<get_server_v
|
|||
}
|
||||
|
||||
#[cfg_attr(feature = "conduit_bin", get("/_matrix/key/v2/server"))]
|
||||
pub fn get_server_keys(db: State<'_, Database>) -> Json<String> {
|
||||
pub fn get_server_keys_route(db: State<'_, Database>) -> Json<String> {
|
||||
if !db.globals.federation_enabled() {
|
||||
// TODO: Use proper types
|
||||
return Json("Federation is disabled.".to_owned());
|
||||
|
@ -278,8 +292,8 @@ pub fn get_server_keys(db: State<'_, Database>) -> Json<String> {
|
|||
}
|
||||
|
||||
#[cfg_attr(feature = "conduit_bin", get("/_matrix/key/v2/server/<_>"))]
|
||||
pub fn get_server_keys_deprecated(db: State<'_, Database>) -> Json<String> {
|
||||
get_server_keys(db)
|
||||
pub fn get_server_keys_deprecated_route(db: State<'_, Database>) -> Json<String> {
|
||||
get_server_keys_route(db)
|
||||
}
|
||||
|
||||
#[cfg_attr(
|
||||
|
@ -464,6 +478,9 @@ pub async fn send_transaction_message_route<'a>(
|
|||
let mut pdu_id = room_id.as_bytes().to_vec();
|
||||
pdu_id.push(0xff);
|
||||
pdu_id.extend_from_slice(&count.to_be_bytes());
|
||||
|
||||
db.rooms.append_to_state(&pdu_id, &pdu)?;
|
||||
|
||||
db.rooms.append_pdu(
|
||||
&pdu,
|
||||
&value,
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue