Feature/rate limiting (#786)

* rate limiting

* 10s rate limit
This commit is contained in:
Aiden McClelland
2021-11-12 15:00:57 -07:00
parent c723ee6a15
commit 9e1e3e167b

View File

@@ -1,4 +1,6 @@
use std::borrow::Borrow;
use std::sync::Arc;
use std::time::{Duration, Instant};
use basic_cookies::Cookie;
use color_eyre::eyre::eyre;
@@ -15,6 +17,7 @@ use rpc_toolkit::yajrc::RpcMethod;
use rpc_toolkit::Metadata;
use serde::{Deserialize, Serialize};
use sha2::Sha256;
use tokio::sync::Mutex;
use crate::context::RpcContext;
use crate::{Error, ResultExt};
@@ -178,21 +181,23 @@ impl Borrow<String> for HashSessionToken {
}
pub fn auth<M: Metadata>(ctx: RpcContext) -> DynMiddleware<M> {
let rate_limiter = Arc::new(Mutex::new(Instant::now()));
Box::new(
move |req: &mut Request<Body>,
metadata: M|
-> BoxFuture<Result<Result<DynMiddlewareStage2, Response<Body>>, HttpError>> {
let ctx = ctx.clone();
let rate_limiter = rate_limiter.clone();
async move {
let mut header_stub = Request::new(Body::empty());
*header_stub.headers_mut() = req.headers().clone();
let m2: DynMiddlewareStage2 = Box::new(move |req, rpc_req| {
async move {
if metadata
.get(rpc_req.method.as_str(), "authenticated")
.unwrap_or(true)
{
if let Err(e) = HasValidSession::from_request_parts(req, &ctx).await {
if let Err(e) = HasValidSession::from_request_parts(req, &ctx).await {
if metadata
.get(rpc_req.method.as_str(), "authenticated")
.unwrap_or(true)
{
let (res_parts, _) = Response::new(()).into_parts();
return Ok(Err(to_response(
&req.headers,
@@ -200,6 +205,24 @@ pub fn auth<M: Metadata>(ctx: RpcContext) -> DynMiddleware<M> {
Err(e.into()),
|_| StatusCode::OK,
)?));
} else {
let mut guard = rate_limiter.lock().await;
if guard.elapsed() < Duration::from_secs(10) {
let (res_parts, _) = Response::new(()).into_parts();
return Ok(Err(to_response(
&req.headers,
res_parts,
Err(Error::new(
eyre!(
"Please limit login attempts to 1 per 10 seconds."
),
crate::ErrorKind::RateLimited,
)
.into()),
|_| StatusCode::OK,
)?));
}
*guard = Instant::now();
}
}
Ok(Ok(noop3()))