diff --git a/core/startos/src/net/dns.rs b/core/startos/src/net/dns.rs index 9538fda55..930f903a9 100644 --- a/core/startos/src/net/dns.rs +++ b/core/startos/src/net/dns.rs @@ -1,19 +1,20 @@ use std::borrow::Borrow; use std::collections::BTreeMap; use std::net::{IpAddr, Ipv4Addr, Ipv6Addr, SocketAddr}; +use std::str::FromStr; use std::sync::{Arc, Weak}; use std::time::Duration; use clap::Parser; use color_eyre::eyre::eyre; use futures::future::BoxFuture; -use futures::{FutureExt, StreamExt, TryStreamExt}; +use futures::{FutureExt, StreamExt}; use helpers::NonDetachingJoinHandle; use hickory_client::client::Client; use hickory_client::proto::runtime::TokioRuntimeProvider; use hickory_client::proto::tcp::TcpClientStream; use hickory_client::proto::udp::UdpClientStream; -use hickory_client::proto::xfer::{DnsExchangeBackground, DnsRequestOptions}; +use hickory_client::proto::xfer::DnsRequestOptions; use hickory_client::proto::DnsHandle; use hickory_server::authority::MessageResponseBuilder; use hickory_server::proto::op::{Header, ResponseCode}; @@ -22,7 +23,6 @@ use hickory_server::server::{Request, RequestHandler, ResponseHandler, ResponseI use hickory_server::ServerFuture; use imbl::OrdMap; use imbl_value::InternedString; -use itertools::Itertools; use models::{GatewayId, OptionExt, PackageId}; use rpc_toolkit::{ from_fn_async, from_fn_blocking, Context, HandlerArgs, HandlerExt, ParentHandler, @@ -269,6 +269,12 @@ impl DnsClient { } } +lazy_static::lazy_static! { + static ref LOCALHOST: Name = Name::from_ascii("localhost").unwrap(); + static ref STARTOS: Name = Name::from_ascii("startos").unwrap(); + static ref EMBASSY: Name = Name::from_ascii("embassy").unwrap(); +} + struct Resolver { client: DnsClient, net_iface: Watch>, @@ -276,9 +282,12 @@ struct Resolver { } impl Resolver { fn resolve(&self, name: &Name, src: IpAddr) -> Option> { + if name.zone_of(&*LOCALHOST) { + return Some(vec![Ipv4Addr::LOCALHOST.into(), Ipv6Addr::LOCALHOST.into()]); + } self.resolve.peek(|r| { if r.private_domains - .get(&*name.to_lowercase().to_ascii()) + .get(&*name.to_lowercase().to_utf8().trim_end_matches('.')) .map_or(false, |d| d.strong_count() > 0) { if let Some(res) = self.net_iface.peek(|i| { @@ -295,36 +304,30 @@ impl Resolver { return Some(res); } } - match name.iter().next_back() { - Some(b"embassy") | Some(b"startos") => { - if let Some(pkg) = name.iter().rev().skip(1).next() { - if let Some(ip) = r.services.get(&Some( - std::str::from_utf8(pkg) - .unwrap_or_default() - .parse() - .unwrap_or_default(), - )) { - Some( - ip.iter() - .filter(|(_, rc)| rc.strong_count() > 0) - .map(|(ip, _)| (*ip).into()) - .collect(), - ) - } else { - None - } - } else if let Some(ip) = r.services.get(&None) { - Some( - ip.iter() - .filter(|(_, rc)| rc.strong_count() > 0) - .map(|(ip, _)| (*ip).into()) - .collect(), - ) - } else { - None - } + if name.zone_of(&*STARTOS) || name.zone_of(&*EMBASSY) { + let Ok(pkg) = name + .trim_to(2) + .iter() + .next() + .map(std::str::from_utf8) + .transpose() + .map_err(|_| ()) + .and_then(|s| s.map(PackageId::from_str).transpose().map_err(|_| ())) + else { + return None; + }; + if let Some(ip) = r.services.get(&pkg) { + Some( + ip.iter() + .filter(|(_, rc)| rc.strong_count() > 0) + .map(|(ip, _)| (*ip).into()) + .collect(), + ) + } else { + None } - _ => None, + } else { + None } }) } @@ -420,16 +423,22 @@ impl RequestHandler for Resolver { } } else { let query = query.original().clone(); - let mut streams = self.client.lookup(query, DnsRequestOptions::default()); + let mut opt = DnsRequestOptions::default(); + opt.recursion_desired = request.recursion_desired(); + let mut streams = self.client.lookup(query, opt); let mut err = None; for stream in streams.iter_mut() { match tokio::time::timeout(Duration::from_secs(5), stream.next()).await { Ok(Some(Err(e))) => err = Some(e), Ok(Some(Ok(msg))) => { + let mut header = msg.header().clone(); + header.set_id(request.id()); + header.set_checking_disabled(request.checking_disabled()); + header.set_recursion_available(true); return response_handle .send_response( MessageResponseBuilder::from_message_request(&*request).build( - Header::response_from_request(request.header()), + header, msg.answers(), msg.name_servers(), &msg.soa().map(|s| s.to_owned().into_record_of_rdata()),