diff --git a/core/startos/src/net/net_controller.rs b/core/startos/src/net/net_controller.rs index 521888665..b7a8022b4 100644 --- a/core/startos/src/net/net_controller.rs +++ b/core/startos/src/net/net_controller.rs @@ -29,6 +29,7 @@ pub struct PreInitNetController { tor: TorController, vhost: VHostController, os_bindings: Vec>, + server_hostnames: Vec>, } impl PreInitNetController { #[instrument(skip_all)] @@ -44,6 +45,7 @@ impl PreInitNetController { tor: TorController::new(tor_control, tor_socks), vhost: VHostController::new(db), os_bindings: Vec::new(), + server_hostnames: Vec::new(), }; res.add_os_bindings(hostname, os_tor_key).await?; Ok(res) @@ -59,64 +61,26 @@ impl PreInitNetController { MaybeUtf8String("h2".into()), ])); - // Internal DNS - self.vhost - .add( - Some("embassy".into()), - 443, - ([127, 0, 0, 1], 80).into(), - alpn.clone(), - ) - .await?; - self.vhost - .add( - Some("startos".into()), - 443, - ([127, 0, 0, 1], 80).into(), - alpn.clone(), - ) - .await?; + self.server_hostnames = vec![ + // LAN IP + None, + // Internal DNS + Some("embassy".into()), + Some("startos".into()), + // localhost + Some("localhost".into()), + Some(hostname.no_dot_host_name()), + // LAN mDNS + Some(hostname.local_domain_name()), + ]; - // LAN IP - self.os_bindings.push( - self.vhost - .add(None, 443, ([127, 0, 0, 1], 80).into(), alpn.clone()) - .await?, - ); - - // localhost - self.os_bindings.push( - self.vhost - .add( - Some("localhost".into()), - 443, - ([127, 0, 0, 1], 80).into(), - alpn.clone(), - ) - .await?, - ); - self.os_bindings.push( - self.vhost - .add( - Some(hostname.no_dot_host_name()), - 443, - ([127, 0, 0, 1], 80).into(), - alpn.clone(), - ) - .await?, - ); - - // LAN mDNS - self.os_bindings.push( - self.vhost - .add( - Some(hostname.local_domain_name()), - 443, - ([127, 0, 0, 1], 80).into(), - alpn.clone(), - ) - .await?, - ); + for hostname in self.server_hostnames.iter().cloned() { + self.os_bindings.push( + self.vhost + .add(hostname, 443, ([127, 0, 0, 1], 80).into(), alpn.clone()) + .await?, + ); + } // Tor self.os_bindings.push( @@ -154,6 +118,7 @@ pub struct NetController { pub(super) dns: DnsController, pub(super) forward: LanPortForwardController, pub(super) os_bindings: Vec>, + pub(super) server_hostnames: Vec>, } impl NetController { @@ -163,6 +128,7 @@ impl NetController { tor, vhost, os_bindings, + server_hostnames, }: PreInitNetController, dns_bind: &[SocketAddr], ) -> Result { @@ -173,6 +139,7 @@ impl NetController { dns: DnsController::init(dns_bind).await?, forward: LanPortForwardController::new(), os_bindings, + server_hostnames, }; res.os_bindings .push(res.dns.add(None, HOST_IP.into()).await?); @@ -258,10 +225,15 @@ impl NetService { let ctrl = self.net_controller()?; let mut errors = ErrorCollection::new(); for (_, binds) in std::mem::take(&mut self.binds) { - for (_, (lan, _, _, rc)) in binds.lan { + for (_, (lan, _, hostnames, rc)) in binds.lan { drop(rc); if let Some(external) = lan.assigned_ssl_port { - ctrl.vhost.gc(None, external).await?; + for hostname in ctrl.server_hostnames.iter().cloned() { + ctrl.vhost.gc(hostname, external).await?; + } + for hostname in hostnames { + ctrl.vhost.gc(Some(hostname), external).await?; + } } if let Some(external) = lan.assigned_port { ctrl.forward.gc(external).await?; @@ -317,11 +289,13 @@ impl NetService { Err(AlpnInfo::Reflect) } }; - rcs.push( - ctrl.vhost - .add(None, external, target, connect_ssl.clone()) - .await?, - ); + for hostname in ctrl.server_hostnames.iter().cloned() { + rcs.push( + ctrl.vhost + .add(hostname, external, target, connect_ssl.clone()) + .await?, + ); + } for address in host.addresses() { match address { HostAddress::Onion { address } => { @@ -407,7 +381,9 @@ impl NetService { } if let Some((lan, _, hostnames, _)) = old_lan_bind { if let Some(external) = lan.assigned_ssl_port { - ctrl.vhost.gc(None, external).await?; + for hostname in ctrl.server_hostnames.iter().cloned() { + ctrl.vhost.gc(hostname, external).await?; + } for hostname in hostnames { ctrl.vhost.gc(Some(hostname), external).await?; } @@ -429,7 +405,9 @@ impl NetService { }); for (lan, hostnames) in removed { if let Some(external) = lan.assigned_ssl_port { - ctrl.vhost.gc(None, external).await?; + for hostname in ctrl.server_hostnames.iter().cloned() { + ctrl.vhost.gc(hostname, external).await?; + } for hostname in hostnames { ctrl.vhost.gc(Some(hostname), external).await?; } @@ -533,7 +511,9 @@ impl NetService { pub async fn remove_all(mut self) -> Result<(), Error> { self.shutdown = true; if let Some(ctrl) = Weak::upgrade(&self.controller) { - self.clear_bindings().await + self.clear_bindings().await?; + drop(ctrl); + Ok(()) } else { tracing::warn!("NetService dropped after NetController is shutdown"); Err(Error::new(