diff --git a/core/startos/src/db/model/public.rs b/core/startos/src/db/model/public.rs index 7cc1ea619..ea3abc7ec 100644 --- a/core/startos/src/db/model/public.rs +++ b/core/startos/src/db/model/public.rs @@ -254,7 +254,7 @@ impl NetworkInterfaceInfo { self.secure.unwrap_or_else(|| { self.ip_info.as_ref().map_or(false, |ip_info| { ip_info.device_type == Some(NetworkInterfaceType::Wireguard) - }) + }) && !self.public() }) } } diff --git a/core/startos/src/net/tls.rs b/core/startos/src/net/tls.rs index a29a63b5e..0a58d7085 100644 --- a/core/startos/src/net/tls.rs +++ b/core/startos/src/net/tls.rs @@ -1,8 +1,9 @@ use std::sync::Arc; -use std::task::Poll; +use std::task::{Poll, ready}; -use futures::FutureExt; use futures::future::BoxFuture; +use futures::stream::FuturesUnordered; +use futures::{FutureExt, StreamExt}; use imbl_value::InternedString; use openssl::x509::X509Ref; use tokio::io::AsyncWriteExt; @@ -117,7 +118,7 @@ pub struct TlsListener TlsHandler<'a, A>> { pub accept: A, pub tls_handler: H, in_progress: SyncMutex< - Vec< + FuturesUnordered< BoxFuture< 'static, ( @@ -133,7 +134,7 @@ impl TlsHandler<'a, A>> TlsListener { Self { accept, tls_handler: cert_handler, - in_progress: SyncMutex::new(Vec::new()), + in_progress: SyncMutex::new(FuturesUnordered::new()), } } } @@ -150,100 +151,97 @@ where ) -> Poll> { self.in_progress.mutate(|in_progress| { loop { - if let Some((idx, (handler, res))) = - in_progress.iter_mut().enumerate().find_map(|(idx, fut)| { - match fut.poll_unpin(cx) { - Poll::Ready(a) => Some((idx, a)), - Poll::Pending => None, + if !in_progress.is_empty() { + if let Poll::Ready(Some((handler, res))) = in_progress.poll_next_unpin(cx) { + if let Some(res) = res.transpose() { + self.tls_handler = handler; + return Poll::Ready(res); } - }) - { - drop(in_progress.swap_remove(idx)); - if let Some(res) = res.transpose() { - self.tls_handler = handler; - return Poll::Ready(res); + continue; } - continue; } - if let Poll::Ready((metadata, stream)) = self.accept.poll_accept(cx)? { - let mut tls_handler = self.tls_handler.clone(); - in_progress.push( - async move { - let res = async { - let mut acceptor = LazyConfigAcceptor::new( - Acceptor::default(), - BackTrackingIO::new(stream), - ); - let mut mid: tokio_rustls::StartHandshake< - BackTrackingIO, - > = match (&mut acceptor).await { - Ok(a) => a, - Err(e) => { - let mut stream = - acceptor.take_io().or_not_found("acceptor io")?; - let (_, buf) = stream.rewind(); - if std::str::from_utf8(buf) - .ok() - .and_then(|buf| { - buf.lines() - .map(|l| l.trim()) - .filter(|l| !l.is_empty()) - .next() - }) - .map_or(false, |buf| { - regex::Regex::new("[A-Z]+ (.+) HTTP/1") - .unwrap() - .is_match(buf) - }) - { - handle_http_on_https(stream).await.log_err(); + let (metadata, stream) = ready!(self.accept.poll_accept(cx)?); + let mut tls_handler = self.tls_handler.clone(); + let mut fut = async move { + let res = async { + let mut acceptor = LazyConfigAcceptor::new( + Acceptor::default(), + BackTrackingIO::new(stream), + ); + let mut mid: tokio_rustls::StartHandshake> = + match (&mut acceptor).await { + Ok(a) => a, + Err(e) => { + let mut stream = + acceptor.take_io().or_not_found("acceptor io")?; + let (_, buf) = stream.rewind(); + if std::str::from_utf8(buf) + .ok() + .and_then(|buf| { + buf.lines() + .map(|l| l.trim()) + .filter(|l| !l.is_empty()) + .next() + }) + .map_or(false, |buf| { + regex::Regex::new("[A-Z]+ (.+) HTTP/1") + .unwrap() + .is_match(buf) + }) + { + handle_http_on_https(stream).await.log_err(); - return Ok(None); - } else { - return Err(e).with_kind(ErrorKind::Network); - } + return Ok(None); + } else { + return Err(e).with_kind(ErrorKind::Network); } - }; - let hello = mid.client_hello(); - if let Some(cfg) = tls_handler.get_config(&hello, &metadata).await { - let metadata = TlsMetadata { - inner: metadata, - tls_info: TlsHandshakeInfo { - sni: hello.server_name().map(InternedString::intern), - alpn: hello - .alpn() - .into_iter() - .flatten() - .map(|a| MaybeUtf8String(a.to_vec())) - .collect(), - }, - }; - let buffered = mid.io.stop_buffering(); - mid.io - .write_all(&buffered) - .await - .with_kind(ErrorKind::Network)?; - return Ok(Some(( - metadata, - Box::pin(mid.into_stream(Arc::new(cfg)).await?) - as AcceptStream, - ))); } - - Ok(None) - } - .await; - (tls_handler, res) + }; + let hello = mid.client_hello(); + if let Some(cfg) = tls_handler.get_config(&hello, &metadata).await { + let metadata = TlsMetadata { + inner: metadata, + tls_info: TlsHandshakeInfo { + sni: hello.server_name().map(InternedString::intern), + alpn: hello + .alpn() + .into_iter() + .flatten() + .map(|a| MaybeUtf8String(a.to_vec())) + .collect(), + }, + }; + let buffered = mid.io.stop_buffering(); + mid.io + .write_all(&buffered) + .await + .with_kind(ErrorKind::Network)?; + return Ok(Some(( + metadata, + Box::pin(mid.into_stream(Arc::new(cfg)).await?) as AcceptStream, + ))); } - .boxed(), - ); - continue; - } - break; - } - Poll::Pending + Ok(None) + } + .await; + (tls_handler, res) + } + .boxed(); + match fut.poll_unpin(cx) { + Poll::Pending => { + in_progress.push(fut); + return Poll::Pending; + } + Poll::Ready((handler, res)) => { + if let Some(res) = res.transpose() { + self.tls_handler = handler; + return Poll::Ready(res); + } + } + }; + } }) } } diff --git a/core/startos/src/net/web_server.rs b/core/startos/src/net/web_server.rs index 01af76a7d..a03365417 100644 --- a/core/startos/src/net/web_server.rs +++ b/core/startos/src/net/web_server.rs @@ -295,7 +295,7 @@ impl Acceptor { &mut self, cx: &mut std::task::Context<'_>, ) -> Poll> { - let _ = self.poll_changed(cx); + while self.poll_changed(cx).is_ready() {} self.acceptor.peek_mut(|a| a.poll_accept(cx)) } diff --git a/core/startos/src/tunnel/db.rs b/core/startos/src/tunnel/db.rs index 216f71007..1db6e105b 100644 --- a/core/startos/src/tunnel/db.rs +++ b/core/startos/src/tunnel/db.rs @@ -29,7 +29,7 @@ use crate::tunnel::context::TunnelContext; use crate::tunnel::web::WebserverInfo; use crate::tunnel::wg::WgServer; use crate::util::net::WebSocketExt; -use crate::util::serde::{HandlerExtSerde, apply_expr, deserialize_from_str, serialize_display}; +use crate::util::serde::{HandlerExtSerde, apply_expr}; #[derive(Default, Deserialize, Serialize, HasModel)] #[serde(rename_all = "camelCase")] diff --git a/web/projects/start-tunnel/src/app/routes/home/routes/port-forwards/index.ts b/web/projects/start-tunnel/src/app/routes/home/routes/port-forwards/index.ts index 8367b4cfe..4d483c7e4 100644 --- a/web/projects/start-tunnel/src/app/routes/home/routes/port-forwards/index.ts +++ b/web/projects/start-tunnel/src/app/routes/home/routes/port-forwards/index.ts @@ -181,7 +181,7 @@ export default class PortForwards { }) protected readonly forwards = toSignal( - combineLatest([this.devices$, this.patch.watch$('port_forwards')]).pipe( + combineLatest([this.devices$, this.patch.watch$('portForwards')]).pipe( map(([devices, forwards]) => Object.entries(forwards).map(([source, target]) => { const sourceSplit = source.split(':') diff --git a/web/projects/start-tunnel/src/app/services/api/api.service.ts b/web/projects/start-tunnel/src/app/services/api/api.service.ts index b65a59d75..3bd67ba16 100644 --- a/web/projects/start-tunnel/src/app/services/api/api.service.ts +++ b/web/projects/start-tunnel/src/app/services/api/api.service.ts @@ -23,8 +23,8 @@ export abstract class ApiService { abstract deleteDevice(params: DeleteDeviceReq): Promise // device.remove abstract showDeviceConfig(params: DeleteDeviceReq): Promise // device.show-config // forwards - abstract addForward(params: AddForwardReq): Promise // forward.add - abstract deleteForward(params: DeleteForwardReq): Promise // forward.remove + abstract addForward(params: AddForwardReq): Promise // port-forward.add + abstract deleteForward(params: DeleteForwardReq): Promise // port-forward.remove } export type SubscribeRes = { diff --git a/web/projects/start-tunnel/src/app/services/api/live-api.service.ts b/web/projects/start-tunnel/src/app/services/api/live-api.service.ts index 45c224ffa..5f44cd165 100644 --- a/web/projects/start-tunnel/src/app/services/api/live-api.service.ts +++ b/web/projects/start-tunnel/src/app/services/api/live-api.service.ts @@ -94,11 +94,11 @@ export class LiveApiService extends ApiService { // forwards async addForward(params: AddForwardReq): Promise { - return this.rpcRequest({ method: 'forward.add', params }) + return this.rpcRequest({ method: 'port-forward.add', params }) } async deleteForward(params: DeleteForwardReq): Promise { - return this.rpcRequest({ method: 'forward.remove', params }) + return this.rpcRequest({ method: 'port-forward.remove', params }) } // private diff --git a/web/projects/start-tunnel/src/app/services/patch-db/data-model.ts b/web/projects/start-tunnel/src/app/services/patch-db/data-model.ts index 2e50834b4..5d227924f 100644 --- a/web/projects/start-tunnel/src/app/services/patch-db/data-model.ts +++ b/web/projects/start-tunnel/src/app/services/patch-db/data-model.ts @@ -1,6 +1,6 @@ export type TunnelData = { wg: WgServer - port_forwards: Record + portForwards: Record } export type WgServer = { @@ -35,7 +35,7 @@ export const mockTunnelData: TunnelData = { }, }, }, - port_forwards: { + portForwards: { '69.1.1.42:443': '10.59.0.2:5443', '69.1.1.42:3000': '10.59.0.2:3000', },