fix dns recursion and localhost (#3021)

This commit is contained in:
Aiden McClelland
2025-09-11 12:35:12 -06:00
committed by GitHub
parent 723dea100f
commit d1812d875b

View File

@@ -1,19 +1,20 @@
use std::borrow::Borrow;
use std::collections::BTreeMap;
use std::net::{IpAddr, Ipv4Addr, Ipv6Addr, SocketAddr};
use std::str::FromStr;
use std::sync::{Arc, Weak};
use std::time::Duration;
use clap::Parser;
use color_eyre::eyre::eyre;
use futures::future::BoxFuture;
use futures::{FutureExt, StreamExt, TryStreamExt};
use futures::{FutureExt, StreamExt};
use helpers::NonDetachingJoinHandle;
use hickory_client::client::Client;
use hickory_client::proto::runtime::TokioRuntimeProvider;
use hickory_client::proto::tcp::TcpClientStream;
use hickory_client::proto::udp::UdpClientStream;
use hickory_client::proto::xfer::{DnsExchangeBackground, DnsRequestOptions};
use hickory_client::proto::xfer::DnsRequestOptions;
use hickory_client::proto::DnsHandle;
use hickory_server::authority::MessageResponseBuilder;
use hickory_server::proto::op::{Header, ResponseCode};
@@ -22,7 +23,6 @@ use hickory_server::server::{Request, RequestHandler, ResponseHandler, ResponseI
use hickory_server::ServerFuture;
use imbl::OrdMap;
use imbl_value::InternedString;
use itertools::Itertools;
use models::{GatewayId, OptionExt, PackageId};
use rpc_toolkit::{
from_fn_async, from_fn_blocking, Context, HandlerArgs, HandlerExt, ParentHandler,
@@ -269,6 +269,12 @@ impl DnsClient {
}
}
lazy_static::lazy_static! {
static ref LOCALHOST: Name = Name::from_ascii("localhost").unwrap();
static ref STARTOS: Name = Name::from_ascii("startos").unwrap();
static ref EMBASSY: Name = Name::from_ascii("embassy").unwrap();
}
struct Resolver {
client: DnsClient,
net_iface: Watch<OrdMap<GatewayId, NetworkInterfaceInfo>>,
@@ -276,9 +282,12 @@ struct Resolver {
}
impl Resolver {
fn resolve(&self, name: &Name, src: IpAddr) -> Option<Vec<IpAddr>> {
if name.zone_of(&*LOCALHOST) {
return Some(vec![Ipv4Addr::LOCALHOST.into(), Ipv6Addr::LOCALHOST.into()]);
}
self.resolve.peek(|r| {
if r.private_domains
.get(&*name.to_lowercase().to_ascii())
.get(&*name.to_lowercase().to_utf8().trim_end_matches('.'))
.map_or(false, |d| d.strong_count() > 0)
{
if let Some(res) = self.net_iface.peek(|i| {
@@ -295,36 +304,30 @@ impl Resolver {
return Some(res);
}
}
match name.iter().next_back() {
Some(b"embassy") | Some(b"startos") => {
if let Some(pkg) = name.iter().rev().skip(1).next() {
if let Some(ip) = r.services.get(&Some(
std::str::from_utf8(pkg)
.unwrap_or_default()
.parse()
.unwrap_or_default(),
)) {
Some(
ip.iter()
.filter(|(_, rc)| rc.strong_count() > 0)
.map(|(ip, _)| (*ip).into())
.collect(),
)
} else {
None
}
} else if let Some(ip) = r.services.get(&None) {
Some(
ip.iter()
.filter(|(_, rc)| rc.strong_count() > 0)
.map(|(ip, _)| (*ip).into())
.collect(),
)
} else {
None
}
if name.zone_of(&*STARTOS) || name.zone_of(&*EMBASSY) {
let Ok(pkg) = name
.trim_to(2)
.iter()
.next()
.map(std::str::from_utf8)
.transpose()
.map_err(|_| ())
.and_then(|s| s.map(PackageId::from_str).transpose().map_err(|_| ()))
else {
return None;
};
if let Some(ip) = r.services.get(&pkg) {
Some(
ip.iter()
.filter(|(_, rc)| rc.strong_count() > 0)
.map(|(ip, _)| (*ip).into())
.collect(),
)
} else {
None
}
_ => None,
} else {
None
}
})
}
@@ -420,16 +423,22 @@ impl RequestHandler for Resolver {
}
} else {
let query = query.original().clone();
let mut streams = self.client.lookup(query, DnsRequestOptions::default());
let mut opt = DnsRequestOptions::default();
opt.recursion_desired = request.recursion_desired();
let mut streams = self.client.lookup(query, opt);
let mut err = None;
for stream in streams.iter_mut() {
match tokio::time::timeout(Duration::from_secs(5), stream.next()).await {
Ok(Some(Err(e))) => err = Some(e),
Ok(Some(Ok(msg))) => {
let mut header = msg.header().clone();
header.set_id(request.id());
header.set_checking_disabled(request.checking_disabled());
header.set_recursion_available(true);
return response_handle
.send_response(
MessageResponseBuilder::from_message_request(&*request).build(
Header::response_from_request(request.header()),
header,
msg.answers(),
msg.name_servers(),
&msg.soa().map(|s| s.to_owned().into_record_of_rdata()),