diff --git a/backend/src/middleware/cors.rs b/backend/src/middleware/cors.rs index 132a2385f..5f33bc08d 100644 --- a/backend/src/middleware/cors.rs +++ b/backend/src/middleware/cors.rs @@ -1,4 +1,6 @@ use futures::FutureExt; +use http::HeaderValue; +use hyper::header::HeaderMap; use rpc_toolkit::hyper::http::Error as HttpError; use rpc_toolkit::hyper::{Body, Method, Request, Response}; use rpc_toolkit::rpc_server_helpers::{ @@ -6,24 +8,35 @@ use rpc_toolkit::rpc_server_helpers::{ }; use rpc_toolkit::Metadata; +fn get_cors_headers(req: &Request) -> HeaderMap { + let mut res = HeaderMap::new(); + if let Some(origin) = req.headers().get("Origin") { + res.insert("Access-Control-Allow-Origin", origin.clone()); + } + if let Some(method) = req.headers().get("Access-Control-Request-Method") { + res.insert("Access-Control-Allow-Methods", method.clone()); + } + if let Some(headers) = req.headers().get("Access-Control-Request-Headers") { + res.insert("Access-Control-Allow-Headers", headers.clone()); + } + res.insert( + "Access-Control-Allow-Credentials", + HeaderValue::from_static("true"), + ); + res +} + pub async fn cors( req: &mut Request, _metadata: M, ) -> Result>, HttpError> { + let headers = get_cors_headers(req); if req.method() == Method::OPTIONS { - Ok(Err(Response::builder() - .header( - "Access-Control-Allow-Origin", - if let Some(origin) = req.headers().get("origin").and_then(|s| s.to_str().ok()) { - origin - } else { - "*" - }, - ) - .header("Access-Control-Allow-Methods", "*") - .header("Access-Control-Allow-Headers", "*") - .header("Access-Control-Allow-Credentials", "true") - .body(Body::empty())?)) + Ok(Err({ + let mut res = Response::new(Body::empty()); + res.headers_mut().extend(headers.into_iter()); + res + })) } else { Ok(Ok(Box::new(|_, _| { async move { @@ -31,8 +44,7 @@ pub async fn cors( async move { let res: DynMiddlewareStage4 = Box::new(|res| { async move { - res.headers_mut() - .insert("Access-Control-Allow-Origin", "*".parse()?); + res.headers_mut().extend(headers.into_iter()); Ok::<_, HttpError>(()) } .boxed() diff --git a/backend/src/net/ssl.rs b/backend/src/net/ssl.rs index b3cf0ec81..bbce1debd 100644 --- a/backend/src/net/ssl.rs +++ b/backend/src/net/ssl.rs @@ -18,14 +18,12 @@ use rpc_toolkit::command; use tokio::sync::{Mutex, RwLock}; use tracing::instrument; +use crate::account::AccountInfo; +use crate::context::{self, RpcContext}; use crate::hostname::Hostname; use crate::net::dhcp::ips; use crate::net::keys::{Key, KeyInfo}; use crate::prelude::*; -use crate::{ - account::AccountInfo, - context::{self, RpcContext}, -}; use crate::{Error, ErrorKind, ResultExt}; static CERTIFICATE_VERSION: i32 = 2; // X509 version 3 is actually encoded as '2' in the cert because fuck you.