diff --git a/backend/src/net/dns.rs b/backend/src/net/dns.rs index d73c781eb..60c33fcb0 100644 --- a/backend/src/net/dns.rs +++ b/backend/src/net/dns.rs @@ -1,5 +1,5 @@ use std::borrow::Borrow; -use std::collections::{BTreeMap, BTreeSet}; +use std::collections::BTreeMap; use std::net::{Ipv4Addr, SocketAddr}; use std::sync::Arc; use std::time::Duration; @@ -13,6 +13,7 @@ use tokio::sync::RwLock; use trust_dns_server::authority::MessageResponseBuilder; use trust_dns_server::client::op::{Header, ResponseCode}; use trust_dns_server::client::rr::{Name, Record, RecordType}; +use trust_dns_server::proto::rr::rdata::a; use trust_dns_server::server::{Request, RequestHandler, ResponseHandler, ResponseInfo}; use trust_dns_server::ServerFuture; @@ -20,13 +21,13 @@ use crate::util::Invoke; use crate::{Error, ErrorKind, ResultExt, HOST_IP}; pub struct DnsController { - services: Arc>>>, + services: Arc>>>, #[allow(dead_code)] dns_server: NonDetachingJoinHandle>, } struct Resolver { - services: Arc>>>, + services: Arc>>>, } impl Resolver { async fn resolve(&self, name: &Name) -> Option> { @@ -61,32 +62,51 @@ impl RequestHandler for Resolver { ) -> ResponseInfo { let query = request.request_info().query; if let Some(ip) = self.resolve(query.name().borrow()).await { - if query.query_type() != RecordType::A { - tracing::warn!( - "Non A-Record requested for {}: {:?}", - query.name(), - query.query_type() - ); + match query.query_type() { + RecordType::A => { + response_handle + .send_response( + MessageResponseBuilder::from_message_request(&*request).build( + Header::response_from_request(request.header()), + &ip.into_iter() + .map(|ip| { + Record::from_rdata( + request.request_info().query.name().to_owned().into(), + 0, + trust_dns_server::client::rr::RData::A(ip), + ) + }) + .collect::>(), + [], + [], + [], + ), + ) + .await + } + a => { + if a != RecordType::AAAA { + tracing::warn!( + "Non A-Record requested for {}: {:?}", + query.name(), + query.query_type() + ); + } + let mut res = Header::response_from_request(request.header()); + res.set_response_code(ResponseCode::NXDomain); + response_handle + .send_response( + MessageResponseBuilder::from_message_request(&*request).build( + res.into(), + [], + [], + [], + [], + ), + ) + .await + } } - response_handle - .send_response( - MessageResponseBuilder::from_message_request(&*request).build( - Header::response_from_request(request.header()), - &ip.into_iter() - .map(|ip| { - Record::from_rdata( - request.request_info().query.name().to_owned().into(), - 0, - trust_dns_server::client::rr::RData::A(ip), - ) - }) - .collect::>(), - [], - [], - [], - ), - ) - .await } else { let mut res = Header::response_from_request(request.header()); res.set_response_code(ResponseCode::NXDomain); @@ -150,14 +170,16 @@ impl DnsController { pub async fn add(&self, pkg_id: &PackageId, ip: Ipv4Addr) { let mut writable = self.services.write().await; let mut ips = writable.remove(pkg_id).unwrap_or_default(); - ips.insert(ip); + ips.push(ip); writable.insert(pkg_id.clone(), ips); } pub async fn remove(&self, pkg_id: &PackageId, ip: Ipv4Addr) { let mut writable = self.services.write().await; let mut ips = writable.remove(pkg_id).unwrap_or_default(); - ips.remove(&ip); + if let Some((idx, _)) = ips.iter().copied().enumerate().find(|(_, x)| *x == ip) { + ips.swap_remove(idx); + } if !ips.is_empty() { writable.insert(pkg_id.clone(), ips); }