use a vec instead of set for ip (#2112)

This commit is contained in:
Aiden McClelland
2023-01-12 09:57:08 -07:00
committed by Aiden McClelland
parent 89ca0ca927
commit 274db6f606

View File

@@ -1,5 +1,5 @@
use std::borrow::Borrow;
use std::collections::{BTreeMap, BTreeSet};
use std::collections::BTreeMap;
use std::net::{Ipv4Addr, SocketAddr};
use std::sync::Arc;
use std::time::Duration;
@@ -13,6 +13,7 @@ use tokio::sync::RwLock;
use trust_dns_server::authority::MessageResponseBuilder;
use trust_dns_server::client::op::{Header, ResponseCode};
use trust_dns_server::client::rr::{Name, Record, RecordType};
use trust_dns_server::proto::rr::rdata::a;
use trust_dns_server::server::{Request, RequestHandler, ResponseHandler, ResponseInfo};
use trust_dns_server::ServerFuture;
@@ -20,13 +21,13 @@ use crate::util::Invoke;
use crate::{Error, ErrorKind, ResultExt, HOST_IP};
pub struct DnsController {
services: Arc<RwLock<BTreeMap<PackageId, BTreeSet<Ipv4Addr>>>>,
services: Arc<RwLock<BTreeMap<PackageId, Vec<Ipv4Addr>>>>,
#[allow(dead_code)]
dns_server: NonDetachingJoinHandle<Result<(), Error>>,
}
struct Resolver {
services: Arc<RwLock<BTreeMap<PackageId, BTreeSet<Ipv4Addr>>>>,
services: Arc<RwLock<BTreeMap<PackageId, Vec<Ipv4Addr>>>>,
}
impl Resolver {
async fn resolve(&self, name: &Name) -> Option<Vec<Ipv4Addr>> {
@@ -61,32 +62,51 @@ impl RequestHandler for Resolver {
) -> ResponseInfo {
let query = request.request_info().query;
if let Some(ip) = self.resolve(query.name().borrow()).await {
if query.query_type() != RecordType::A {
tracing::warn!(
"Non A-Record requested for {}: {:?}",
query.name(),
query.query_type()
);
match query.query_type() {
RecordType::A => {
response_handle
.send_response(
MessageResponseBuilder::from_message_request(&*request).build(
Header::response_from_request(request.header()),
&ip.into_iter()
.map(|ip| {
Record::from_rdata(
request.request_info().query.name().to_owned().into(),
0,
trust_dns_server::client::rr::RData::A(ip),
)
})
.collect::<Vec<_>>(),
[],
[],
[],
),
)
.await
}
a => {
if a != RecordType::AAAA {
tracing::warn!(
"Non A-Record requested for {}: {:?}",
query.name(),
query.query_type()
);
}
let mut res = Header::response_from_request(request.header());
res.set_response_code(ResponseCode::NXDomain);
response_handle
.send_response(
MessageResponseBuilder::from_message_request(&*request).build(
res.into(),
[],
[],
[],
[],
),
)
.await
}
}
response_handle
.send_response(
MessageResponseBuilder::from_message_request(&*request).build(
Header::response_from_request(request.header()),
&ip.into_iter()
.map(|ip| {
Record::from_rdata(
request.request_info().query.name().to_owned().into(),
0,
trust_dns_server::client::rr::RData::A(ip),
)
})
.collect::<Vec<_>>(),
[],
[],
[],
),
)
.await
} else {
let mut res = Header::response_from_request(request.header());
res.set_response_code(ResponseCode::NXDomain);
@@ -150,14 +170,16 @@ impl DnsController {
pub async fn add(&self, pkg_id: &PackageId, ip: Ipv4Addr) {
let mut writable = self.services.write().await;
let mut ips = writable.remove(pkg_id).unwrap_or_default();
ips.insert(ip);
ips.push(ip);
writable.insert(pkg_id.clone(), ips);
}
pub async fn remove(&self, pkg_id: &PackageId, ip: Ipv4Addr) {
let mut writable = self.services.write().await;
let mut ips = writable.remove(pkg_id).unwrap_or_default();
ips.remove(&ip);
if let Some((idx, _)) = ips.iter().copied().enumerate().find(|(_, x)| *x == ip) {
ips.swap_remove(idx);
}
if !ips.is_empty() {
writable.insert(pkg_id.clone(), ips);
}