Skip to content

Commit

Permalink
Code quality improvement
Browse files Browse the repository at this point in the history
  • Loading branch information
valeriansaliou committed Jul 22, 2024
1 parent bf80c3f commit 022096b
Show file tree
Hide file tree
Showing 2 changed files with 153 additions and 83 deletions.
230 changes: 150 additions & 80 deletions src/dns/handler.rs
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,7 @@ impl RequestHandler for DNSHandler {
request: &Request,
responder: R,
) -> ResponseInfo {
match self.handle(responder, request).await {
match self.handle_request(responder, request).await {
Ok(info) => {
debug!("success handling dns request");

Expand All @@ -69,11 +69,15 @@ impl DNSHandler {
}
}

pub fn upsert(&mut self, name: LowerName, authority: DNSAuthority) {
pub fn add_authority(&mut self, name: LowerName, authority: DNSAuthority) {
self.authorities.insert(name, authority);
}

async fn handle<R: ResponseHandler>(&self, responder: R, request: &Request) -> DNSResponse {
async fn handle_request<R: ResponseHandler>(
&self,
responder: R,
request: &Request,
) -> DNSResponse {
trace!("request: {:?} from: {}", request, request.src().ip());

match request.message_type() {
Expand All @@ -86,13 +90,13 @@ impl DNSHandler {
code @ _ => {
error!("unimplemented opcode: {:?}", code);

self.not_impl()
self.not_implemented()
}
},
MessageType::Response => {
warn!("got a response as a request from id: {}", request.id());

self.not_impl()
self.not_implemented()
}
}
}
Expand All @@ -109,7 +113,7 @@ impl DNSHandler {
// Notice: if zone cannot be found, then reject straight away.
// Notice: since we checked the status of the unwrapped query variable, this is panic-safe.
let query = request.query();
let authority_lookup = self.find_auth_recurse(query.name());
let authority_lookup = self.find_authority_recurse(query.name());

if authority_lookup.is_none() == true {
return self
Expand Down Expand Up @@ -161,59 +165,40 @@ impl DNSHandler {
Ok(records_remote) => {
// Serve response data?
if let Some(records_remote_inner) = records_remote {
debug!(
"found {} records for query from remote store: {:?}",
records_remote_inner.len(),
query
);

let records_remote_vec = records_remote_inner.iter().collect();

// Dispatch request from this block, as we cannot escape generated \
// record values lifetimes out of this context.
Self::serve_response_records(
self.lookup_remote_some(
responder,
request,
header,
&zone_name,
records_remote_vec,
query,
zone_name,
soa_records_vec,
records_remote_inner,
)
.await
} else {
// Serve error code
debug!("did not find records for query: {:?}", query);

let response_error = match records_local {
AuthLookup::Empty => {
debug!("domain not found for query: {:?}", query);

ResponseCode::NXDomain
}
AuthLookup::SOA { .. } => {
debug!("domain found for query: {:?}", query);

ResponseCode::NoError
}
AuthLookup::Records { .. } | AuthLookup::AXFR { .. } => {
// This code path is unexpected and should never be reached
panic!("error, should return noerror")
}
};

Self::stamp_header(request, &mut header, response_error, &zone_name);

// Dispatch empty records response
Self::dispatch(responder, request, header, None, Some(soa_records_vec)).await
self.lookup_remote_none(
responder,
request,
header,
query,
zone_name,
soa_records_vec,
records_local,
)
.await
}
}
Err(err) => {
debug!("query refused for: {:?} because: {}", query, err);

Self::stamp_header(request, &mut header, err, &zone_name);

// Dispatch error response
Self::dispatch(responder, request, header, None, Some(soa_records_vec)).await
self.lookup_remote_fail(
responder,
request,
header,
query,
zone_name,
soa_records_vec,
err,
)
.await
}
};
}
Expand All @@ -234,7 +219,7 @@ impl DNSHandler {
header.set_response_code(ResponseCode::Refused);

// Authority not found response dispatch
Self::dispatch(responder, request, header, None, None).await
Self::dispatch_response(responder, request, header, None, None).await
}

async fn lookup_local<'a, R: ResponseHandler>(
Expand All @@ -260,15 +245,100 @@ impl DNSHandler {
.await
}

fn not_impl(&self) -> DNSResponse {
async fn lookup_remote_some<'a, R: ResponseHandler>(
&self,
responder: R,
request: &MessageRequest,
header: Header,
query: &LowerQuery,
zone_name: Option<ZoneName>,
soa_records: Vec<&'a Record>,
records_remote: Vec<Record>,
) -> Result<ResponseInfo, Error> {
debug!(
"found {} records for query from remote store: {:?}",
records_remote.len(),
query
);

let records_remote_vec = records_remote.iter().collect();

// Dispatch request from this block, as we cannot escape generated \
// record values lifetimes out of this context.
Self::serve_response_records(
responder,
request,
header,
&zone_name,
records_remote_vec,
soa_records,
)
.await
}

async fn lookup_remote_none<'a, R: ResponseHandler>(
&self,
responder: R,
request: &MessageRequest,
mut header: Header,
query: &LowerQuery,
zone_name: Option<ZoneName>,
soa_records: Vec<&'a Record>,
records_local: AuthLookup,
) -> Result<ResponseInfo, Error> {
// Serve error code
debug!("did not find records for query: {:?}", query);

let response_error = match records_local {
AuthLookup::Empty => {
debug!("domain not found for query: {:?}", query);

ResponseCode::NXDomain
}
AuthLookup::SOA { .. } => {
debug!("domain found for query: {:?}", query);

ResponseCode::NoError
}
AuthLookup::Records { .. } | AuthLookup::AXFR { .. } => {
// This code path is unexpected and should never be reached
panic!("error, should return noerror")
}
};

Self::stamp_header(request, &mut header, response_error, &zone_name);

// Dispatch empty records response
Self::dispatch_response(responder, request, header, None, Some(soa_records)).await
}

async fn lookup_remote_fail<'a, R: ResponseHandler>(
&self,
responder: R,
request: &MessageRequest,
mut header: Header,
query: &LowerQuery,
zone_name: Option<ZoneName>,
soa_records: Vec<&'a Record>,
code: ResponseCode,
) -> Result<ResponseInfo, Error> {
debug!("query refused for: {:?} because: {}", query, code);

Self::stamp_header(request, &mut header, code, &zone_name);

// Dispatch error response
Self::dispatch_response(responder, request, header, None, Some(soa_records)).await
}

fn not_implemented(&self) -> DNSResponse {
let mut header = Header::new();

header.set_response_code(ResponseCode::NotImp);

Ok(header.into())
}

async fn dispatch<'a, R: ResponseHandler>(
async fn dispatch_response<'a, R: ResponseHandler>(
mut responder: R,
request: &MessageRequest,
header: Header,
Expand Down Expand Up @@ -308,6 +378,31 @@ impl DNSHandler {
responder.send_response(response_message).await
}

fn stamp_header<'a, 'b>(
request: &MessageRequest,
header: &mut Header,
code: ResponseCode,
zone_name: &Option<ZoneName>,
) {
// Stack answer code to metrics?
if let Some(ref zone_name) = zone_name {
let code_name = CodeName::from_hickory(&code);

METRICS_STORE.stack(zone_name, MetricsValue::AnswerCode(&code_name));
}

// Stamp with response code
header.set_response_code(code);

// Stamp response with 'AA' flag (we are authoritative on served zone)
header.set_authoritative(true);

// Stamp response with 'RD' flag? (if requested by client)
if request.recursion_desired() == true {
header.set_recursion_desired(true);
}
}

async fn serve_response_records<'a, 'b, R: ResponseHandler>(
responder: R,
request: &MessageRequest,
Expand All @@ -318,10 +413,10 @@ impl DNSHandler {
) -> DNSResponse {
Self::stamp_header(request, &mut header, ResponseCode::NoError, zone_name);

Self::dispatch(responder, request, header, Some(records), Some(soa_records)).await
Self::dispatch_response(responder, request, header, Some(records), Some(soa_records)).await
}

fn find_auth_recurse(&self, name: &LowerName) -> Option<&DNSAuthority> {
fn find_authority_recurse(&self, name: &LowerName) -> Option<&DNSAuthority> {
let authority = self.authorities.get(name);

if authority.is_some() {
Expand All @@ -330,7 +425,7 @@ impl DNSHandler {
let name = name.base_name();

if !name.is_root() {
return self.find_auth_recurse(&name);
return self.find_authority_recurse(&name);
}
}

Expand Down Expand Up @@ -758,31 +853,6 @@ impl DNSHandler {
}
}

fn stamp_header<'a, 'b>(
request: &MessageRequest,
header: &mut Header,
code: ResponseCode,
zone_name: &Option<ZoneName>,
) {
// Stack answer code to metrics?
if let Some(ref zone_name) = zone_name {
let code_name = CodeName::from_hickory(&code);

METRICS_STORE.stack(zone_name, MetricsValue::AnswerCode(&code_name));
}

// Stamp with response code
header.set_response_code(code);

// Stamp response with 'AA' flag (we are authoritative on served zone)
header.set_authoritative(true);

// Stamp response with 'RD' flag? (if requested by client)
if request.recursion_desired() == true {
header.set_recursion_desired(true);
}
}

async fn check_name_exists(
zone_name: &ZoneName,
record_name: &RecordName,
Expand Down
6 changes: 3 additions & 3 deletions src/dns/listen.rs
Original file line number Diff line number Diff line change
Expand Up @@ -43,8 +43,8 @@ impl DNSListen {
let mut handler: DNSHandler = DNSHandler::new();

for (zone_name, _) in &APP_CONF.dns.zone {
match Self::map_authority(&zone_name) {
Ok((name, authority)) => handler.upsert(LowerName::new(&name), authority),
match Self::zone_authority(&zone_name) {
Ok((name, authority)) => handler.add_authority(LowerName::new(&name), authority),
Err(_) => error!("could not load zone {}", zone_name),
}
}
Expand Down Expand Up @@ -76,7 +76,7 @@ impl DNSListen {
}
}

fn map_authority(zone_name: &str) -> Result<(Name, DNSAuthority), ()> {
fn zone_authority(zone_name: &str) -> Result<(Name, DNSAuthority), ()> {
if let Ok(name) = Name::parse(zone_name, Some(&Name::new())) {
let mut records = BTreeMap::new();

Expand Down

0 comments on commit 022096b

Please sign in to comment.