auth middleware

This commit is contained in:
Aiden McClelland
2021-07-29 18:37:31 -06:00
parent ce9495c6cc
commit 711766a8a4
9 changed files with 825 additions and 48 deletions

View File

@@ -44,6 +44,7 @@ fn inner_main() -> Result<(), Error> {
Some(a) => eprintln!("{}: {}", e.message, a),
None => eprintln!("{}", e.message),
}
std::process::exit(e.code);
}
);

View File

@@ -3,6 +3,7 @@ use std::time::Duration;
use embassy::context::{EitherContext, RpcContext};
use embassy::db::model::Database;
use embassy::middleware::auth::auth;
use embassy::middleware::cors::cors;
use embassy::status::{check_all, synchronize_all};
use embassy::util::daemon;
@@ -11,6 +12,7 @@ use futures::TryFutureExt;
use patch_db::json_ptr::JsonPointer;
use rpc_toolkit::hyper::StatusCode;
use rpc_toolkit::rpc_server;
use rpc_toolkit::rpc_server_helpers::DynMiddleware;
fn status_fn(_: i32) -> StatusCode {
StatusCode::OK
@@ -24,13 +26,15 @@ async fn inner_main(cfg_path: Option<&str>) -> Result<(), Error> {
.put(&<JsonPointer>::default(), &Database::init(), None)
.await?;
}
let auth = auth(rpc_ctx.clone());
let ctx = EitherContext::Rpc(rpc_ctx.clone());
let server = rpc_server!({
command: embassy::main_api,
context: ctx,
status: status_fn,
middleware: [
cors
cors,
auth,
]
});
let status_ctx = rpc_ctx.clone();

View File

@@ -1,12 +1,14 @@
use std::fs::File;
use std::io::Read;
use std::net::IpAddr;
use std::io::{BufReader, Read};
use std::net::{IpAddr, Ipv4Addr};
use std::path::{Path, PathBuf};
use std::sync::Arc;
use anyhow::anyhow;
use clap::ArgMatches;
use cookie_store::CookieStore;
use reqwest::Proxy;
use reqwest_cookie_store::CookieStoreMutex;
use rpc_toolkit::reqwest::{Client, Url};
use rpc_toolkit::url::Host;
use rpc_toolkit::Context;
@@ -21,20 +23,36 @@ pub struct CliContextConfig {
#[serde(deserialize_with = "deserialize_host")]
pub host: Option<Host>,
pub port: Option<u16>,
pub url: Option<Url>,
#[serde(deserialize_with = "crate::util::deserialize_from_str_opt")]
pub proxy: Option<Url>,
pub developer_key_path: Option<PathBuf>,
pub cookie_path: Option<PathBuf>,
#[serde(flatten)]
pub server_config: RpcContextConfig,
}
#[derive(Debug)]
pub struct CliContextSeed {
pub host: Host,
pub port: u16,
pub url: Url,
pub client: Client,
pub cookie_store: Arc<CookieStoreMutex>,
pub cookie_path: PathBuf,
pub developer_key_path: PathBuf,
}
impl Drop for CliContextSeed {
fn drop(&mut self) {
let tmp = format!("{}.tmp", self.cookie_path.display());
let mut writer = File::create(&tmp).unwrap();
let store = self.cookie_store.lock().unwrap();
store.save_json(&mut writer).unwrap();
writer.sync_all().unwrap();
std::fs::rename(tmp, &self.cookie_path).unwrap();
}
}
const DEFAULT_HOST: Host<&'static str> = Host::Ipv4(Ipv4Addr::new(127, 0, 0, 1));
const DEFAULT_PORT: u16 = 5959;
#[derive(Debug, Clone)]
pub struct CliContext(Arc<CliContextSeed>);
@@ -62,26 +80,55 @@ impl CliContext {
base.port = Some(bind.port())
}
}
if let Some(host) = matches.value_of("host") {
base.host = Some(Host::parse(host).with_kind(crate::ErrorKind::ParseUrl)?);
}
if let Some(port) = matches.value_of("port") {
base.port = Some(port.parse()?);
}
if let Some(proxy) = matches.value_of("proxy") {
base.proxy = Some(proxy.parse()?);
}
let host = if let Some(host) = matches.value_of("host") {
Some(Host::parse(host).with_kind(crate::ErrorKind::ParseUrl)?)
} else {
base.host
};
let port = if let Some(port) = matches.value_of("port") {
Some(port.parse()?)
} else {
base.port
};
let proxy = if let Some(proxy) = matches.value_of("proxy") {
Some(proxy.parse()?)
} else {
base.proxy
};
let cookie_path = base.cookie_path.unwrap_or_else(|| {
cfg_path
.parent()
.unwrap_or(Path::new("/"))
.join(".cookies.json")
});
let cookie_store = Arc::new(CookieStoreMutex::new(if cookie_path.exists() {
CookieStore::load_json(BufReader::new(File::open(&cookie_path)?))
.map_err(|e| anyhow!("{}", e))
.with_kind(crate::ErrorKind::Deserialization)?
} else {
CookieStore::default()
}));
Ok(CliContext(Arc::new(CliContextSeed {
host: base.host.unwrap_or(Host::Ipv4([127, 0, 0, 1].into())),
port: base.port.unwrap_or(5959),
client: if let Some(proxy) = base.proxy {
Client::builder()
.proxy(Proxy::all(proxy).with_kind(crate::ErrorKind::ParseUrl)?)
.build()
.expect("cannot fail")
} else {
Client::new()
url: base.url.unwrap_or_else(|| {
format!(
"http://{}:{}",
host.unwrap_or_else(|| DEFAULT_HOST.to_owned()),
port.unwrap_or(DEFAULT_PORT)
)
.parse()
.unwrap()
}),
client: {
let mut builder = Client::builder().cookie_provider(cookie_store.clone());
if let Some(proxy) = proxy {
builder =
builder.proxy(Proxy::all(proxy).with_kind(crate::ErrorKind::ParseUrl)?)
}
builder.build().expect("cannot fail")
},
cookie_store,
cookie_path,
developer_key_path: base.developer_key_path.unwrap_or_else(|| {
cfg_path
.parent()
@@ -107,15 +154,20 @@ impl std::ops::Deref for CliContext {
}
}
impl Context for CliContext {
fn protocol(&self) -> &str {
self.0.url.scheme()
}
fn host(&self) -> Host<&str> {
match &self.0.host {
Host::Domain(a) => Host::Domain(a.as_str()),
Host::Ipv4(a) => Host::Ipv4(*a),
Host::Ipv6(a) => Host::Ipv6(*a),
}
self.0.url.host().unwrap_or(DEFAULT_HOST)
}
fn port(&self) -> u16 {
self.0.port
self.0.url.port().unwrap_or(DEFAULT_PORT)
}
fn path(&self) -> &str {
self.0.url.path()
}
fn url(&self) -> Url {
self.0.url.clone()
}
fn client(&self) -> &Client {
&self.0.client

View File

@@ -60,6 +60,7 @@ pub fn echo(#[context] _ctx: EitherContext, #[arg] message: String) -> Result<St
developer::init,
inspect::inspect,
package,
auth::auth,
))]
pub fn main_api(#[context] ctx: EitherContext) -> Result<EitherContext, RpcError> {
Ok(ctx)

View File

@@ -4,12 +4,13 @@ use chrono::Utc;
use digest::Digest;
use futures::future::BoxFuture;
use futures::FutureExt;
use http::StatusCode;
use rpc_toolkit::command_helpers::prelude::RequestParts;
use rpc_toolkit::hyper::header::COOKIE;
use rpc_toolkit::hyper::http::Error as HttpError;
use rpc_toolkit::hyper::{Body, Request, Response};
use rpc_toolkit::rpc_server_helpers::{
noop3, noop4, DynMiddleware, DynMiddlewareStage2, DynMiddlewareStage3,
noop3, noop4, to_response, DynMiddleware, DynMiddlewareStage2, DynMiddlewareStage3,
};
use rpc_toolkit::yajrc::RpcMethod;
use rpc_toolkit::Metadata;
@@ -60,11 +61,12 @@ async fn is_authed(ctx: &RpcContext, req: &RequestParts) -> Result<(), Error> {
}
}
pub async fn auth<M: Metadata>(ctx: RpcContext) -> DynMiddleware<M> {
pub fn auth<M: Metadata>(ctx: RpcContext) -> DynMiddleware<M> {
Box::new(
|req: &mut Request<Body>,
metadata: M|
-> BoxFuture<Result<Result<DynMiddlewareStage2, Response<Body>>, HttpError>> {
move |req: &mut Request<Body>,
metadata: M|
-> BoxFuture<Result<Result<DynMiddlewareStage2, Response<Body>>, HttpError>> {
let ctx = ctx.clone();
async move {
let mut header_stub = Request::new(Body::empty());
*header_stub.headers_mut() = req.headers().clone();
@@ -75,14 +77,13 @@ pub async fn auth<M: Metadata>(ctx: RpcContext) -> DynMiddleware<M> {
.unwrap_or(true)
{
if let Err(e) = is_authed(&ctx, req).await {
let m3: DynMiddlewareStage3 = Box::new(|_, rpc_res| {
async move {
*rpc_res = Err(e.into());
Ok(Ok(noop4()))
}
.boxed()
});
return Ok(Ok(m3));
let (res_parts, _) = Response::new(()).into_parts();
return Ok(Err(to_response(
&req.headers,
res_parts,
Err(e.into()),
|_| StatusCode::OK,
)?));
}
}
Ok(Ok(noop3()))