mirror of
https://github.com/Start9Labs/start-os.git
synced 2026-04-01 21:13:09 +00:00
232 lines
8.0 KiB
Rust
232 lines
8.0 KiB
Rust
use std::borrow::Borrow;
|
|
use std::collections::BTreeMap;
|
|
use std::future::Future;
|
|
use std::net::Ipv4Addr;
|
|
use std::sync::{Arc, Weak};
|
|
use std::time::Duration;
|
|
|
|
use color_eyre::eyre::eyre;
|
|
use helpers::NonDetachingJoinHandle;
|
|
use models::PackageId;
|
|
use tokio::net::{TcpListener, UdpSocket};
|
|
use tokio::process::Command;
|
|
use tokio::sync::RwLock;
|
|
use tracing::instrument;
|
|
use trust_dns_server::authority::MessageResponseBuilder;
|
|
use trust_dns_server::proto::op::{Header, ResponseCode};
|
|
use trust_dns_server::proto::rr::{Name, Record, RecordType};
|
|
use trust_dns_server::server::{Request, RequestHandler, ResponseHandler, ResponseInfo};
|
|
use trust_dns_server::ServerFuture;
|
|
|
|
use crate::net::forward::START9_BRIDGE_IFACE;
|
|
use crate::util::Invoke;
|
|
use crate::{Error, ErrorKind, ResultExt};
|
|
|
|
pub struct DnsController {
|
|
services: Weak<RwLock<BTreeMap<Option<PackageId>, BTreeMap<Ipv4Addr, Weak<()>>>>>,
|
|
#[allow(dead_code)]
|
|
dns_server: NonDetachingJoinHandle<Result<(), Error>>,
|
|
}
|
|
|
|
struct Resolver {
|
|
services: Arc<RwLock<BTreeMap<Option<PackageId>, BTreeMap<Ipv4Addr, Weak<()>>>>>,
|
|
}
|
|
impl Resolver {
|
|
async fn resolve(&self, name: &Name) -> Option<Vec<Ipv4Addr>> {
|
|
match name.iter().next_back() {
|
|
Some(b"embassy") | Some(b"startos") => {
|
|
if let Some(pkg) = name.iter().rev().skip(1).next() {
|
|
if let Some(ip) = self.services.read().await.get(&Some(
|
|
std::str::from_utf8(pkg)
|
|
.unwrap_or_default()
|
|
.parse()
|
|
.unwrap_or_default(),
|
|
)) {
|
|
Some(
|
|
ip.iter()
|
|
.filter(|(_, rc)| rc.strong_count() > 0)
|
|
.map(|(ip, _)| *ip)
|
|
.collect(),
|
|
)
|
|
} else {
|
|
None
|
|
}
|
|
} else if let Some(ip) = self.services.read().await.get(&None) {
|
|
Some(
|
|
ip.iter()
|
|
.filter(|(_, rc)| rc.strong_count() > 0)
|
|
.map(|(ip, _)| *ip)
|
|
.collect(),
|
|
)
|
|
} else {
|
|
None
|
|
}
|
|
}
|
|
_ => None,
|
|
}
|
|
}
|
|
}
|
|
|
|
#[async_trait::async_trait]
|
|
impl RequestHandler for Resolver {
|
|
async fn handle_request<R: ResponseHandler>(
|
|
&self,
|
|
request: &Request,
|
|
mut response_handle: R,
|
|
) -> ResponseInfo {
|
|
let query = request.request_info().query;
|
|
if let Some(ip) = self.resolve(query.name().borrow()).await {
|
|
match query.query_type() {
|
|
RecordType::A => {
|
|
response_handle
|
|
.send_response(
|
|
MessageResponseBuilder::from_message_request(&*request).build(
|
|
Header::response_from_request(request.header()),
|
|
&ip.into_iter()
|
|
.map(|ip| {
|
|
Record::from_rdata(
|
|
request.request_info().query.name().to_owned().into(),
|
|
0,
|
|
trust_dns_server::proto::rr::RData::A(ip.into()),
|
|
)
|
|
})
|
|
.collect::<Vec<_>>(),
|
|
[],
|
|
[],
|
|
[],
|
|
),
|
|
)
|
|
.await
|
|
}
|
|
_ => {
|
|
let res = Header::response_from_request(request.header());
|
|
response_handle
|
|
.send_response(
|
|
MessageResponseBuilder::from_message_request(&*request).build(
|
|
res.into(),
|
|
[],
|
|
[],
|
|
[],
|
|
[],
|
|
),
|
|
)
|
|
.await
|
|
}
|
|
}
|
|
} else {
|
|
let mut res = Header::response_from_request(request.header());
|
|
res.set_response_code(ResponseCode::NXDomain);
|
|
response_handle
|
|
.send_response(
|
|
MessageResponseBuilder::from_message_request(&*request).build(
|
|
res.into(),
|
|
[],
|
|
[],
|
|
[],
|
|
[],
|
|
),
|
|
)
|
|
.await
|
|
}
|
|
.unwrap_or_else(|e| {
|
|
tracing::error!("{}", e);
|
|
tracing::debug!("{:?}", e);
|
|
let mut res = Header::response_from_request(request.header());
|
|
res.set_response_code(ResponseCode::ServFail);
|
|
res.into()
|
|
})
|
|
}
|
|
}
|
|
|
|
impl DnsController {
|
|
#[instrument(skip_all)]
|
|
pub async fn init(
|
|
bridge_activated: impl Future<Output = ()> + Send + Sync + 'static,
|
|
) -> Result<Self, Error> {
|
|
let services = Arc::new(RwLock::new(BTreeMap::new()));
|
|
|
|
let mut server = ServerFuture::new(Resolver {
|
|
services: services.clone(),
|
|
});
|
|
|
|
let dns_server = tokio::spawn(async move {
|
|
server.register_listener(
|
|
TcpListener::bind((Ipv4Addr::LOCALHOST, 53))
|
|
.await
|
|
.with_kind(ErrorKind::Network)?,
|
|
Duration::from_secs(30),
|
|
);
|
|
server.register_socket(
|
|
UdpSocket::bind((Ipv4Addr::LOCALHOST, 53))
|
|
.await
|
|
.with_kind(ErrorKind::Network)?,
|
|
);
|
|
|
|
bridge_activated.await;
|
|
|
|
Command::new("resolvectl")
|
|
.arg("dns")
|
|
.arg(START9_BRIDGE_IFACE)
|
|
.arg("127.0.0.1")
|
|
.invoke(ErrorKind::Network)
|
|
.await?;
|
|
Command::new("resolvectl")
|
|
.arg("domain")
|
|
.arg(START9_BRIDGE_IFACE)
|
|
.arg("embassy")
|
|
.invoke(ErrorKind::Network)
|
|
.await?;
|
|
|
|
server
|
|
.block_until_done()
|
|
.await
|
|
.map_err(|e| Error::new(e, ErrorKind::Network))
|
|
})
|
|
.into();
|
|
|
|
Ok(Self {
|
|
services: Arc::downgrade(&services),
|
|
dns_server,
|
|
})
|
|
}
|
|
|
|
pub async fn add(&self, pkg_id: Option<PackageId>, ip: Ipv4Addr) -> Result<Arc<()>, Error> {
|
|
if let Some(services) = Weak::upgrade(&self.services) {
|
|
let mut writable = services.write().await;
|
|
let mut ips = writable.remove(&pkg_id).unwrap_or_default();
|
|
let rc = if let Some(rc) = Weak::upgrade(&ips.remove(&ip).unwrap_or_default()) {
|
|
rc
|
|
} else {
|
|
Arc::new(())
|
|
};
|
|
ips.insert(ip, Arc::downgrade(&rc));
|
|
writable.insert(pkg_id, ips);
|
|
Ok(rc)
|
|
} else {
|
|
Err(Error::new(
|
|
eyre!("DNS Server Thread has exited"),
|
|
crate::ErrorKind::Network,
|
|
))
|
|
}
|
|
}
|
|
|
|
pub async fn gc(&self, pkg_id: Option<PackageId>, ip: Ipv4Addr) -> Result<(), Error> {
|
|
if let Some(services) = Weak::upgrade(&self.services) {
|
|
let mut writable = services.write().await;
|
|
let mut ips = writable.remove(&pkg_id).unwrap_or_default();
|
|
if let Some(rc) = Weak::upgrade(&ips.remove(&ip).unwrap_or_default()) {
|
|
ips.insert(ip, Arc::downgrade(&rc));
|
|
}
|
|
if !ips.is_empty() {
|
|
writable.insert(pkg_id, ips);
|
|
}
|
|
Ok(())
|
|
} else {
|
|
Err(Error::new(
|
|
eyre!("DNS Server Thread has exited"),
|
|
crate::ErrorKind::Network,
|
|
))
|
|
}
|
|
}
|
|
}
|