improvement: look at SRV record when sending requests
This commit is contained in:
parent
267c721616
commit
e08dfd982b
4 changed files with 253 additions and 41 deletions
|
@ -605,7 +605,9 @@ pub async fn sync_events_route(
|
|||
changed: device_list_updates.into_iter().collect(),
|
||||
left: device_list_left.into_iter().collect(),
|
||||
},
|
||||
device_one_time_keys_count: if db.users.last_one_time_keys_update(sender_id)? > since || since == 0 {
|
||||
device_one_time_keys_count: if db.users.last_one_time_keys_update(sender_id)? > since
|
||||
|| since == 0
|
||||
{
|
||||
db.users.count_one_time_keys(sender_id, device_id)?
|
||||
} else {
|
||||
BTreeMap::new()
|
||||
|
|
|
@ -1,5 +1,5 @@
|
|||
use crate::{client_server, ConduitResult, Database, Error, PduEvent, Result, Ruma};
|
||||
use http::header::{HeaderValue, AUTHORIZATION};
|
||||
use http::header::{HeaderValue, AUTHORIZATION, HOST};
|
||||
use log::warn;
|
||||
use rocket::{get, post, put, response::content::Json, State};
|
||||
use ruma::{
|
||||
|
@ -23,6 +23,7 @@ use std::{
|
|||
fmt::Debug,
|
||||
time::{Duration, SystemTime},
|
||||
};
|
||||
use trust_dns_resolver::AsyncResolver;
|
||||
|
||||
pub async fn request_well_known(
|
||||
globals: &crate::database::globals::Globals,
|
||||
|
@ -54,16 +55,36 @@ pub async fn send_request<T: OutgoingRequest>(
|
|||
where
|
||||
T: Debug,
|
||||
{
|
||||
let resolver = AsyncResolver::tokio_from_system_conf()
|
||||
.await
|
||||
.map_err(|_| Error::BadConfig("Failed to set up trust dns resolver with system config."))?;
|
||||
|
||||
let mut host = None;
|
||||
|
||||
let actual_destination = "https://".to_owned()
|
||||
+ &request_well_known(globals, &destination.as_str())
|
||||
.await
|
||||
.unwrap_or_else(|| {
|
||||
let mut destination = destination.as_str().to_owned();
|
||||
if destination.find(':').is_none() {
|
||||
destination += ":8448";
|
||||
+ &if let Some(mut delegated_hostname) =
|
||||
request_well_known(globals, &destination.as_str()).await
|
||||
{
|
||||
if let Ok(Some(srv)) = resolver
|
||||
.srv_lookup(format!("_matrix._tcp.{}", delegated_hostname))
|
||||
.await
|
||||
.map(|srv| srv.iter().next().map(|result| result.target().to_string()))
|
||||
{
|
||||
host = Some(delegated_hostname);
|
||||
srv.trim_end_matches('.').to_owned()
|
||||
} else {
|
||||
if delegated_hostname.find(':').is_none() {
|
||||
delegated_hostname += ":8448";
|
||||
}
|
||||
destination
|
||||
});
|
||||
delegated_hostname
|
||||
}
|
||||
} else {
|
||||
let mut destination = destination.as_str().to_owned();
|
||||
if destination.find(':').is_none() {
|
||||
destination += ":8448";
|
||||
}
|
||||
destination
|
||||
};
|
||||
|
||||
let mut http_request = request
|
||||
.try_into_http_request(&actual_destination, Some(""))
|
||||
|
@ -129,9 +150,17 @@ where
|
|||
}
|
||||
}
|
||||
|
||||
let reqwest_request = reqwest::Request::try_from(http_request)
|
||||
if let Some(host) = host {
|
||||
http_request
|
||||
.headers_mut()
|
||||
.insert(HOST, HeaderValue::from_str(&host).unwrap());
|
||||
}
|
||||
|
||||
let mut reqwest_request = reqwest::Request::try_from(http_request)
|
||||
.expect("all http requests are valid reqwest requests");
|
||||
|
||||
*reqwest_request.timeout_mut() = Some(Duration::from_secs(30));
|
||||
|
||||
let reqwest_response = globals.reqwest_client().execute(reqwest_request).await;
|
||||
|
||||
// Because reqwest::Response -> http::Response is complicated:
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue