Files
start-os/core/src/middleware/cors.rs
Aiden McClelland 96ae532879 Refactor/project structure (#3085)
* refactor project structure

* environment-based default registry

* fix tests

* update build container

* use docker platform for iso build emulation

* simplify compat

* Fix docker platform spec in run-compat.sh

* handle riscv compat

* fix bug with dep error exists attr

* undo removal of sorting

* use qemu for iso stage

---------

Co-authored-by: Mariusz Kogen <k0gen@pm.me>
Co-authored-by: Matt Hill <mattnine@protonmail.com>
2025-12-22 13:39:38 -07:00

71 lines
2.2 KiB
Rust

use axum::body::Body;
use axum::extract::Request;
use axum::response::Response;
use http::{HeaderMap, HeaderValue, Method};
use rpc_toolkit::{Empty, Middleware};
/// CORS middleware that echoes the caller's `Origin`, requested method, and
/// requested headers back in the matching `Access-Control-Allow-*` response
/// headers, and always sends `Access-Control-Allow-Credentials: true`.
///
/// `headers` is a staging area: [`Cors::get_cors_headers`] fills it from the
/// incoming request, and the `Middleware` impl moves it onto the outgoing
/// response with `std::mem::take`, which leaves it empty afterwards.
#[derive(Clone)]
pub struct Cors {
    headers: HeaderMap,
}

impl Default for Cors {
    fn default() -> Self {
        Self::new()
    }
}

impl Cors {
    /// Creates the middleware with `Access-Control-Allow-Credentials: true` staged.
    pub fn new() -> Self {
        let mut headers = HeaderMap::new();
        headers.insert(
            "Access-Control-Allow-Credentials",
            HeaderValue::from_static("true"),
        );
        Self { headers }
    }

    /// Stages the CORS response headers for `req` into `self.headers`.
    ///
    /// Each `Access-Control-Allow-*` header copies the value of its request
    /// counterpart when present and falls back to `*` otherwise. Existing
    /// staged values for these names are replaced.
    fn get_cors_headers(&mut self, req: &Request) {
        // Re-stage the credentials flag: a previous `std::mem::take` of
        // `self.headers` would otherwise have removed the copy set in `new()`.
        self.headers.insert(
            "Access-Control-Allow-Credentials",
            HeaderValue::from_static("true"),
        );
        for (request_header, response_header) in [
            ("Origin", "Access-Control-Allow-Origin"),
            ("Access-Control-Request-Method", "Access-Control-Allow-Methods"),
            ("Access-Control-Request-Headers", "Access-Control-Allow-Headers"),
        ] {
            let value = req
                .headers()
                .get(request_header)
                .cloned()
                .unwrap_or_else(|| HeaderValue::from_static("*"));
            self.headers.insert(response_header, value);
        }
    }
}
impl<Context: Send + Sync + 'static> Middleware<Context> for Cors {
type Metadata = Empty;
async fn process_http_request(
&mut self,
_: &Context,
request: &mut Request,
) -> Result<(), Response> {
self.get_cors_headers(request);
if request.method() == Method::OPTIONS {
let mut response = Response::new(Body::empty());
response
.headers_mut()
.extend(std::mem::take(&mut self.headers));
return Err(response);
}
Ok(())
}
async fn process_http_response(&mut self, _: &Context, response: &mut Response) {
response
.headers_mut()
.extend(std::mem::take(&mut self.headers))
}
}