fix: refactor dns to handle tcp connections: (#3083)

* fix: refactor dns to handle tcp connections:
- do not use long-lived tcp connections to upstream dns servers
- when incoming request is over tcp, force a tcp lookup instead of udp

this solves cases where large dns records were not being resolved due to udp->tcp switch-over.

* use forwarding resolver for fallback

---------

Co-authored-by: Aiden McClelland <me@drbonez.dev>
This commit is contained in:
Remco Ros
2025-12-20 07:26:29 +01:00
committed by GitHub
parent 5446c89bc0
commit 7c12b58bb5
5 changed files with 258 additions and 309 deletions

View File

@@ -1,5 +1,5 @@
use std::borrow::Borrow;
use std::collections::{BTreeMap, VecDeque};
use std::collections::BTreeMap;
use std::net::{IpAddr, Ipv4Addr, Ipv6Addr, SocketAddr};
use std::str::FromStr;
use std::sync::{Arc, Weak};
@@ -7,27 +7,23 @@ use std::time::Duration;
use clap::Parser;
use color_eyre::eyre::eyre;
use futures::future::BoxFuture;
use futures::{FutureExt, StreamExt};
use hickory_client::client::Client;
use hickory_client::proto::DnsHandle;
use hickory_client::proto::runtime::TokioRuntimeProvider;
use hickory_client::proto::tcp::TcpClientStream;
use hickory_client::proto::udp::UdpClientStream;
use hickory_client::proto::xfer::DnsRequestOptions;
use hickory_server::ServerFuture;
use hickory_server::authority::MessageResponseBuilder;
use futures::{FutureExt, StreamExt, TryStreamExt};
use hickory_server::authority::{AuthorityObject, Catalog, MessageResponseBuilder};
use hickory_server::proto::op::{Header, ResponseCode};
use hickory_server::proto::rr::{Name, Record, RecordType};
use hickory_server::proto::rr::{LowerName, Name, Record, RecordType};
use hickory_server::resolver::config::{ResolverConfig, ResolverOpts};
use hickory_server::server::{Request, RequestHandler, ResponseHandler, ResponseInfo};
use hickory_server::store::forwarder::{ForwardAuthority, ForwardConfig};
use hickory_server::{ServerFuture, resolver as hickory_resolver};
use imbl::OrdMap;
use imbl_value::InternedString;
use patch_db::json_ptr::JsonPointer;
use itertools::Itertools;
use rpc_toolkit::{
Context, HandlerArgs, HandlerExt, ParentHandler, from_fn_async, from_fn_blocking,
};
use serde::{Deserialize, Serialize};
use tokio::net::{TcpListener, UdpSocket};
use tokio::sync::RwLock;
use tracing::instrument;
use crate::context::{CliContext, RpcContext};
@@ -35,7 +31,6 @@ use crate::db::model::Database;
use crate::db::model::public::NetworkInterfaceInfo;
use crate::net::gateway::NetworkInterfaceWatcher;
use crate::prelude::*;
use crate::util::actor::background::BackgroundJobQueue;
use crate::util::future::NonDetachingJoinHandle;
use crate::util::io::file_string_stream;
use crate::util::serde::{HandlerExtSerde, display_serializable};
@@ -214,173 +209,6 @@ pub struct DnsController {
dns_server: NonDetachingJoinHandle<()>,
}
struct DnsClient {
client: Arc<SyncRwLock<Vec<(SocketAddr, hickory_client::client::Client)>>>,
_thread: NonDetachingJoinHandle<()>,
}
impl DnsClient {
pub fn new(db: TypedPatchDb<Database>) -> Self {
let client = Arc::new(SyncRwLock::new(Vec::new()));
Self {
client: client.clone(),
_thread: tokio::spawn(async move {
let (bg, mut runner) = BackgroundJobQueue::new();
runner
.run_while(async move {
let dhcp_ns_db = db.clone();
bg.add_job(async move {
loop {
if let Err(e) = async {
let mut stream =
file_string_stream("/run/systemd/resolve/resolv.conf")
.filter_map(|a| futures::future::ready(a.transpose()))
.boxed();
while let Some(conf) = stream.next().await {
let conf: String = conf?;
let mut nameservers = conf
.lines()
.map(|l| l.trim())
.filter_map(|l| l.strip_prefix("nameserver "))
.map(|n| {
n.parse::<SocketAddr>().or_else(|_| {
n.parse::<IpAddr>().map(|a| (a, 53).into())
})
})
.collect::<Result<VecDeque<_>, _>>()?;
if nameservers
.front()
.map_or(false, |addr| addr.ip().is_loopback())
{
nameservers.pop_front();
}
if nameservers.front().map_or(false, |addr| {
addr.ip() == IpAddr::from([1, 1, 1, 1])
}) {
nameservers.pop_front();
}
dhcp_ns_db
.mutate(|db| {
let dns = db
.as_public_mut()
.as_server_info_mut()
.as_network_mut()
.as_dns_mut();
dns.as_dhcp_servers_mut().ser(&nameservers)
})
.await
.result?
}
Ok::<_, Error>(())
}
.await
{
tracing::error!("{e}");
tracing::debug!("{e:?}");
tokio::time::sleep(Duration::from_secs(1)).await;
}
}
});
loop {
if let Err::<(), Error>(e) = async {
let mut dns_changed = db
.subscribe(
"/public/serverInfo/network/dns"
.parse::<JsonPointer>()
.with_kind(ErrorKind::Database)?,
)
.await;
let mut prev_nameservers = VecDeque::new();
let mut bg = BTreeMap::<SocketAddr, BoxFuture<_>>::new();
loop {
let dns = db
.peek()
.await
.into_public()
.into_server_info()
.into_network()
.into_dns();
let nameservers = dns
.as_static_servers()
.transpose_ref()
.unwrap_or_else(|| dns.as_dhcp_servers())
.de()?;
if nameservers != prev_nameservers {
let mut existing: BTreeMap<_, _> =
client.peek(|c| c.iter().cloned().collect());
let mut new = Vec::with_capacity(nameservers.len());
for addr in &nameservers {
if let Some(existing) = existing.remove(addr) {
new.push((*addr, existing));
} else {
let client = if let Ok((client, bg_thread)) =
Client::connect(
UdpClientStream::builder(
*addr,
TokioRuntimeProvider::new(),
)
.build(),
)
.await
{
bg.insert(*addr, bg_thread.boxed());
client
} else {
let (stream, sender) = TcpClientStream::new(
*addr,
None,
Some(Duration::from_secs(30)),
TokioRuntimeProvider::new(),
);
let (client, bg_thread) =
Client::new(stream, sender, None)
.await
.with_kind(ErrorKind::Network)?;
bg.insert(*addr, bg_thread.fuse().boxed());
client
};
new.push((*addr, client));
}
}
bg.retain(|n, _| nameservers.iter().any(|a| a == n));
prev_nameservers = nameservers;
client.replace(new);
}
futures::future::select(
dns_changed.recv().boxed(),
futures::future::join(
futures::future::join_all(bg.values_mut()),
futures::future::pending::<()>(),
),
)
.await;
}
}
.await
{
tracing::error!("{e}");
tracing::debug!("{e:?}");
}
tokio::time::sleep(Duration::from_secs(1)).await;
}
})
.await;
})
.into(),
}
}
fn lookup(
&self,
query: hickory_client::proto::op::Query,
options: DnsRequestOptions,
) -> Vec<hickory_client::proto::xfer::DnsExchangeSend> {
self.client.peek(|c| {
c.iter()
.map(|(_, c)| c.lookup(query.clone(), options.clone()))
.collect()
})
}
}
lazy_static::lazy_static! {
static ref LOCALHOST: Name = Name::from_ascii("localhost").unwrap();
static ref STARTOS: Name = Name::from_ascii("startos").unwrap();
@@ -388,11 +216,106 @@ lazy_static::lazy_static! {
}
struct Resolver {
client: DnsClient,
net_iface: Watch<OrdMap<GatewayId, NetworkInterfaceInfo>>,
catalog: Arc<RwLock<Catalog>>,
resolve: Arc<SyncRwLock<ResolveMap>>,
net_iface: Watch<OrdMap<GatewayId, NetworkInterfaceInfo>>,
_thread: NonDetachingJoinHandle<()>,
}
impl Resolver {
fn new(
db: TypedPatchDb<Database>,
net_iface: Watch<OrdMap<GatewayId, NetworkInterfaceInfo>>,
) -> Self {
let catalog = Arc::new(RwLock::new(Catalog::new()));
Self {
catalog: catalog.clone(),
resolve: Arc::new(SyncRwLock::new(ResolveMap::default())),
net_iface,
_thread: tokio::spawn(async move {
let mut prev = crate::util::serde::hash_serializable::<sha2::Sha256, _>(&(
ResolverConfig::new(),
ResolverOpts::default(),
))
.unwrap_or_default();
loop {
if let Err(e) = async {
let mut stream = file_string_stream("/run/systemd/resolve/resolv.conf")
.filter_map(|a| futures::future::ready(a.transpose()))
.boxed();
while let Some(conf) = stream.try_next().await? {
let (config, mut opts) =
hickory_resolver::system_conf::parse_resolv_conf(conf)
.with_kind(ErrorKind::ParseSysInfo)?;
opts.timeout = Duration::from_secs(30);
let hash = crate::util::serde::hash_serializable::<sha2::Sha256, _>(
&(&config, &opts),
)?;
if hash != prev {
db.mutate(|db| {
db.as_public_mut()
.as_server_info_mut()
.as_network_mut()
.as_dns_mut()
.as_dhcp_servers_mut()
.ser(
&config
.name_servers()
.into_iter()
.map(|n| n.socket_addr)
.dedup()
.skip(2)
.collect(),
)
})
.await
.result?;
let auth: Vec<Arc<dyn AuthorityObject>> = vec![Arc::new(
ForwardAuthority::builder_tokio(ForwardConfig {
name_servers: from_value(Value::Array(
config
.name_servers()
.into_iter()
.skip(4)
.map(to_value)
.collect::<Result<_, Error>>()?,
))?,
options: Some(opts),
})
.build()
.map_err(|e| Error::new(eyre!("{e}"), ErrorKind::Network))?,
)];
{
let mut guard = tokio::time::timeout(
Duration::from_secs(10),
catalog.write(),
)
.await
.map_err(|_| {
Error::new(
eyre!("timed out waiting to update dns catalog"),
ErrorKind::Timeout,
)
})?;
guard.upsert(Name::root().into(), auth);
drop(guard);
}
}
prev = hash;
}
Ok::<_, Error>(())
}
.await
{
tracing::error!("{e}");
tracing::debug!("{e:?}");
tokio::time::sleep(Duration::from_secs(1)).await;
}
}
})
.into(),
}
}
fn resolve(&self, name: &Name, mut src: IpAddr) -> Option<Vec<IpAddr>> {
if name.zone_of(&*LOCALHOST) {
return Some(vec![Ipv4Addr::LOCALHOST.into(), Ipv6Addr::LOCALHOST.into()]);
@@ -495,6 +418,7 @@ impl RequestHandler for Resolver {
),
)
.await
.map(Some)
}
RecordType::AAAA => {
let mut header = Header::response_from_request(request.header());
@@ -521,6 +445,7 @@ impl RequestHandler for Resolver {
),
)
.await
.map(Some)
}
_ => {
let mut header = Header::response_from_request(request.header());
@@ -536,70 +461,27 @@ impl RequestHandler for Resolver {
),
)
.await
.map(Some)
}
}
} else {
let query = query.original().clone();
let mut opt = DnsRequestOptions::default();
opt.recursion_desired = request.recursion_desired();
let mut streams = self.client.lookup(query, opt);
let mut err = None;
for stream in streams.iter_mut() {
match tokio::time::timeout(Duration::from_secs(5), stream.next()).await {
Ok(Some(Err(e))) => err = Some(e),
Ok(Some(Ok(msg))) => {
let mut header = msg.header().clone();
header.set_id(request.id());
header.set_checking_disabled(request.checking_disabled());
header.set_recursion_available(true);
return response_handle
.send_response(
MessageResponseBuilder::from_message_request(&*request).build(
header,
msg.answers(),
msg.name_servers(),
&msg.soa().map(|s| s.to_owned().into_record_of_rdata()),
msg.additionals(),
),
)
.await;
}
_ => (),
}
}
if let Some(e) = err {
tracing::error!("{e}");
tracing::debug!("{e:?}");
}
let mut header = Header::response_from_request(request.header());
header.set_recursion_available(true);
header.set_response_code(ResponseCode::ServFail);
response_handle
.send_response(
MessageResponseBuilder::from_message_request(&*request).build(
header,
[],
[],
[],
[],
),
)
.await
Ok(None)
}
}
.await
{
Ok(a) => a,
Ok(Some(a)) => return a,
Ok(None) => (),
Err(e) => {
tracing::error!("{}", e);
tracing::debug!("{:?}", e);
tracing::error!("Error resolving internal DNS: {e}");
tracing::debug!("{e:?}");
let mut header = Header::response_from_request(request.header());
header.set_recursion_available(true);
header.set_response_code(ResponseCode::ServFail);
response_handle
return response_handle
.send_response(
MessageResponseBuilder::from_message_request(&*request).build(
header,
header.into(),
[],
[],
[],
@@ -607,9 +489,14 @@ impl RequestHandler for Resolver {
),
)
.await
.unwrap_or(header.into())
.unwrap_or_else(|_| header.into());
}
}
self.catalog
.read()
.await
.handle_request(request, response_handle)
.await
}
}
@@ -619,13 +506,9 @@ impl DnsController {
db: TypedPatchDb<Database>,
watcher: &NetworkInterfaceWatcher,
) -> Result<Self, Error> {
let resolve = Arc::new(SyncRwLock::new(ResolveMap::default()));
let mut server = ServerFuture::new(Resolver {
client: DnsClient::new(db),
net_iface: watcher.subscribe(),
resolve: resolve.clone(),
});
let resolver = Resolver::new(db, watcher.subscribe());
let resolve = Arc::downgrade(&resolver.resolve);
let mut server = ServerFuture::new(resolver);
let dns_server = tokio::spawn(
async move {
@@ -653,7 +536,7 @@ impl DnsController {
.into();
Ok(Self {
resolve: Arc::downgrade(&resolve),
resolve,
dns_server,
})
}