fix dns recursion and localhost (#3021)

This commit is contained in:
Aiden McClelland
2025-09-11 12:35:12 -06:00
committed by GitHub
parent 723dea100f
commit d1812d875b

View File

@@ -1,19 +1,20 @@
use std::borrow::Borrow; use std::borrow::Borrow;
use std::collections::BTreeMap; use std::collections::BTreeMap;
use std::net::{IpAddr, Ipv4Addr, Ipv6Addr, SocketAddr}; use std::net::{IpAddr, Ipv4Addr, Ipv6Addr, SocketAddr};
use std::str::FromStr;
use std::sync::{Arc, Weak}; use std::sync::{Arc, Weak};
use std::time::Duration; use std::time::Duration;
use clap::Parser; use clap::Parser;
use color_eyre::eyre::eyre; use color_eyre::eyre::eyre;
use futures::future::BoxFuture; use futures::future::BoxFuture;
use futures::{FutureExt, StreamExt, TryStreamExt}; use futures::{FutureExt, StreamExt};
use helpers::NonDetachingJoinHandle; use helpers::NonDetachingJoinHandle;
use hickory_client::client::Client; use hickory_client::client::Client;
use hickory_client::proto::runtime::TokioRuntimeProvider; use hickory_client::proto::runtime::TokioRuntimeProvider;
use hickory_client::proto::tcp::TcpClientStream; use hickory_client::proto::tcp::TcpClientStream;
use hickory_client::proto::udp::UdpClientStream; use hickory_client::proto::udp::UdpClientStream;
use hickory_client::proto::xfer::{DnsExchangeBackground, DnsRequestOptions}; use hickory_client::proto::xfer::DnsRequestOptions;
use hickory_client::proto::DnsHandle; use hickory_client::proto::DnsHandle;
use hickory_server::authority::MessageResponseBuilder; use hickory_server::authority::MessageResponseBuilder;
use hickory_server::proto::op::{Header, ResponseCode}; use hickory_server::proto::op::{Header, ResponseCode};
@@ -22,7 +23,6 @@ use hickory_server::server::{Request, RequestHandler, ResponseHandler, ResponseI
use hickory_server::ServerFuture; use hickory_server::ServerFuture;
use imbl::OrdMap; use imbl::OrdMap;
use imbl_value::InternedString; use imbl_value::InternedString;
use itertools::Itertools;
use models::{GatewayId, OptionExt, PackageId}; use models::{GatewayId, OptionExt, PackageId};
use rpc_toolkit::{ use rpc_toolkit::{
from_fn_async, from_fn_blocking, Context, HandlerArgs, HandlerExt, ParentHandler, from_fn_async, from_fn_blocking, Context, HandlerArgs, HandlerExt, ParentHandler,
@@ -269,6 +269,12 @@ impl DnsClient {
} }
} }
lazy_static::lazy_static! {
static ref LOCALHOST: Name = Name::from_ascii("localhost").unwrap();
static ref STARTOS: Name = Name::from_ascii("startos").unwrap();
static ref EMBASSY: Name = Name::from_ascii("embassy").unwrap();
}
struct Resolver { struct Resolver {
client: DnsClient, client: DnsClient,
net_iface: Watch<OrdMap<GatewayId, NetworkInterfaceInfo>>, net_iface: Watch<OrdMap<GatewayId, NetworkInterfaceInfo>>,
@@ -276,9 +282,12 @@ struct Resolver {
} }
impl Resolver { impl Resolver {
fn resolve(&self, name: &Name, src: IpAddr) -> Option<Vec<IpAddr>> { fn resolve(&self, name: &Name, src: IpAddr) -> Option<Vec<IpAddr>> {
if name.zone_of(&*LOCALHOST) {
return Some(vec![Ipv4Addr::LOCALHOST.into(), Ipv6Addr::LOCALHOST.into()]);
}
self.resolve.peek(|r| { self.resolve.peek(|r| {
if r.private_domains if r.private_domains
.get(&*name.to_lowercase().to_ascii()) .get(&*name.to_lowercase().to_utf8().trim_end_matches('.'))
.map_or(false, |d| d.strong_count() > 0) .map_or(false, |d| d.strong_count() > 0)
{ {
if let Some(res) = self.net_iface.peek(|i| { if let Some(res) = self.net_iface.peek(|i| {
@@ -295,36 +304,30 @@ impl Resolver {
return Some(res); return Some(res);
} }
} }
match name.iter().next_back() { if name.zone_of(&*STARTOS) || name.zone_of(&*EMBASSY) {
Some(b"embassy") | Some(b"startos") => { let Ok(pkg) = name
if let Some(pkg) = name.iter().rev().skip(1).next() { .trim_to(2)
if let Some(ip) = r.services.get(&Some( .iter()
std::str::from_utf8(pkg) .next()
.unwrap_or_default() .map(std::str::from_utf8)
.parse() .transpose()
.unwrap_or_default(), .map_err(|_| ())
)) { .and_then(|s| s.map(PackageId::from_str).transpose().map_err(|_| ()))
Some( else {
ip.iter() return None;
.filter(|(_, rc)| rc.strong_count() > 0) };
.map(|(ip, _)| (*ip).into()) if let Some(ip) = r.services.get(&pkg) {
.collect(), Some(
) ip.iter()
} else { .filter(|(_, rc)| rc.strong_count() > 0)
None .map(|(ip, _)| (*ip).into())
} .collect(),
} else if let Some(ip) = r.services.get(&None) { )
Some( } else {
ip.iter() None
.filter(|(_, rc)| rc.strong_count() > 0)
.map(|(ip, _)| (*ip).into())
.collect(),
)
} else {
None
}
} }
_ => None, } else {
None
} }
}) })
} }
@@ -420,16 +423,22 @@ impl RequestHandler for Resolver {
} }
} else { } else {
let query = query.original().clone(); let query = query.original().clone();
let mut streams = self.client.lookup(query, DnsRequestOptions::default()); let mut opt = DnsRequestOptions::default();
opt.recursion_desired = request.recursion_desired();
let mut streams = self.client.lookup(query, opt);
let mut err = None; let mut err = None;
for stream in streams.iter_mut() { for stream in streams.iter_mut() {
match tokio::time::timeout(Duration::from_secs(5), stream.next()).await { match tokio::time::timeout(Duration::from_secs(5), stream.next()).await {
Ok(Some(Err(e))) => err = Some(e), Ok(Some(Err(e))) => err = Some(e),
Ok(Some(Ok(msg))) => { Ok(Some(Ok(msg))) => {
let mut header = msg.header().clone();
header.set_id(request.id());
header.set_checking_disabled(request.checking_disabled());
header.set_recursion_available(true);
return response_handle return response_handle
.send_response( .send_response(
MessageResponseBuilder::from_message_request(&*request).build( MessageResponseBuilder::from_message_request(&*request).build(
Header::response_from_request(request.header()), header,
msg.answers(), msg.answers(),
msg.name_servers(), msg.name_servers(),
&msg.soa().map(|s| s.to_owned().into_record_of_rdata()), &msg.soa().map(|s| s.to_owned().into_record_of_rdata()),