diff --git a/mtop-client/src/dns/name.rs b/mtop-client/src/dns/name.rs index b3b4738..4a90e73 100644 --- a/mtop-client/src/dns/name.rs +++ b/mtop-client/src/dns/name.rs @@ -226,7 +226,7 @@ impl FromStr for Name { } if s.len() > Self::MAX_LENGTH { - return Err(MtopError::runtime(format!( + return Err(MtopError::configuration(format!( "Name too long; max {} bytes, got {}", Self::MAX_LENGTH, s @@ -240,7 +240,7 @@ impl FromStr for Name { for label in s.trim_end_matches('.').split('.') { let len = label.len(); if len > Self::MAX_LABEL_LENGTH { - return Err(MtopError::runtime(format!( + return Err(MtopError::configuration(format!( "Name label too long; max {} bytes, got {}", Self::MAX_LABEL_LENGTH, label @@ -251,17 +251,17 @@ impl FromStr for Name { for (i, c) in label.char_indices() { if i == 0 && !c.is_ascii_alphanumeric() && c != '_' { - return Err(MtopError::runtime(format!( + return Err(MtopError::configuration(format!( "Name label must begin with ASCII letter, number, or underscore; got {}", label ))); } else if i == len - 1 && !c.is_ascii_alphanumeric() { - return Err(MtopError::runtime(format!( + return Err(MtopError::configuration(format!( "Name label must end with ASCII letter or number; got {}", label ))); } else if !c.is_ascii_alphanumeric() && c != '-' && c != '_' { - return Err(MtopError::runtime(format!( + return Err(MtopError::configuration(format!( "Name label must be ASCII letter, number, hyphen, or underscore; got {}", label ))); diff --git a/mtop/src/bin/dns.rs b/mtop/src/bin/dns.rs index 5b92080..cc672cb 100644 --- a/mtop/src/bin/dns.rs +++ b/mtop/src/bin/dns.rs @@ -5,13 +5,18 @@ use std::fmt::Write; use std::io::Cursor; use std::path::PathBuf; use std::process::ExitCode; -use std::str::FromStr; +use std::sync::atomic::{AtomicBool, Ordering}; +use std::sync::Arc; +use std::time::{Duration, Instant}; use tokio::io::{AsyncReadExt, AsyncWriteExt}; +use tokio::{task, time}; use tracing::{Instrument, Level}; const DEFAULT_LOG_LEVEL: Level = Level::INFO; const DEFAULT_RECORD_TYPE: RecordType = RecordType::A; const DEFAULT_RECORD_CLASS: RecordClass = RecordClass::INET; +const DEFAULT_PING_INTERVAL_SECS: f64 = 1.0; +const MIN_PING_INTERVAL_SECS: f64 = 0.01; /// dns: Make DNS queries or read/write binary format DNS messages #[derive(Debug, Parser)] @@ -33,11 +38,49 @@ struct DnsConfig { #[derive(Debug, Subcommand)] enum Action { + Ping(PingCommand), Query(QueryCommand), Read(ReadCommand), Write(WriteCommand), } +/// Repeatedly perform a DNS query and display the time taken as ping-like text output. +#[derive(Debug, Args)] +struct PingCommand { + /// How often to run queries, in seconds. Fractional seconds are allowed. + #[arg(long, value_parser = parse_interval, default_value_t = DEFAULT_PING_INTERVAL_SECS)] + interval_secs: f64, + + /// Stop after performing `count` queries. Default is to run until interrupted. + #[arg(long, default_value_t = 0)] + count: u64, + + /// Path to resolv.conf file for loading DNS configuration information. If this file + /// can't be loaded, default values for DNS configuration are used instead. + #[arg(long, default_value = default_resolv_conf().into_os_string(), value_hint = ValueHint::FilePath)] + resolv_conf: PathBuf, + + /// Type of record to request. Supported: A, AAAA, CNAME, NS, SOA, SRV, TXT. + #[arg(long, default_value_t = DEFAULT_RECORD_TYPE)] + rtype: RecordType, + + /// Class of record to request. Supported: INET, CHAOS, HESIOD, NONE, ANY. + #[arg(long, default_value_t = DEFAULT_RECORD_CLASS)] + rclass: RecordClass, + + /// Domain name to lookup. + #[arg(required = true)] + name: Name, +} + +fn parse_interval(s: &str) -> Result { + match s.parse() { + Ok(v) if v >= MIN_PING_INTERVAL_SECS => Ok(v), + Ok(_) => Err(format!("must be at least {}", MIN_PING_INTERVAL_SECS)), + Err(e) => Err(e.to_string()), + } +} + /// Perform a DNS query and display the result as dig-like text output. #[derive(Debug, Args)] struct QueryCommand { @@ -62,7 +105,7 @@ struct QueryCommand { /// Domain name to lookup. #[arg(required = true)] - name: String, + name: Name, } fn default_resolv_conf() -> PathBuf { @@ -86,7 +129,7 @@ struct WriteCommand { /// Domain name to lookup. #[arg(required = true)] - name: String, + name: Name, } #[tokio::main] @@ -99,6 +142,7 @@ async fn main() -> ExitCode { let profiling = profile::Writer::default(); let code = match &opts.mode { + Action::Ping(cmd) => run_ping(cmd).await, Action::Query(cmd) => run_query(cmd).await, Action::Read(cmd) => run_read(cmd).await, Action::Write(cmd) => run_write(cmd).await, @@ -111,26 +155,78 @@ async fn main() -> ExitCode { code } -async fn run_query(cmd: &QueryCommand) -> ExitCode { +async fn run_ping(cmd: &PingCommand) -> ExitCode { let client = mtop::dns::new_client(&cmd.resolv_conf) .instrument(tracing::span!(Level::INFO, "dns.new_client")) .await; - let name = match Name::from_str(&cmd.name) { - Ok(n) => n, - Err(e) => { - tracing::error!(message = "invalid name supplied", name = cmd.name, err = %e); - return ExitCode::FAILURE; + + // This command runs until interrupted, so we need to handle SIGINT + // to stop gracefully. + let run = Arc::new(AtomicBool::new(true)); + let run_ref = run.clone(); + task::spawn(async move { + tokio::select! { + _ = tokio::signal::ctrl_c() => { + run_ref.store(false, Ordering::Release); + } } - }; + }); + + let mut interval = time::interval(Duration::from_secs_f64(cmd.interval_secs)); + let mut iterations = 0; + + while run.load(Ordering::Acquire) && (cmd.count == 0 || iterations < cmd.count) { + let _ = interval.tick().await; + // Create our own Instant to measure the time taken to perform the query since + // the one emitted by the interval isn't _immediately_ when the future resolves + // and so skews the measurement of queries. + let start = Instant::now(); + + match client + .resolve(cmd.name.clone(), cmd.rtype, cmd.rclass) + .instrument(tracing::span!(Level::INFO, "client.resolve")) + .await + { + Ok(r) => { + tracing::info!( + id = %r.id(), + name = %cmd.name, + response_code = ?r.flags().get_response_code(), + num_questions = r.questions().len(), + num_answers = r.answers().len(), + num_authority = r.authority().len(), + num_extra = r.extra().len(), + elapsed = ?start.elapsed(), + ); + } + Err(e) => { + tracing::error!(message = "failed to resolve", name = %cmd.name, err = %e); + } + } + + iterations += 1; + } + + if !run.load(Ordering::Acquire) { + tracing::info!("stopping on SIGINT"); + } + + ExitCode::SUCCESS +} + +async fn run_query(cmd: &QueryCommand) -> ExitCode { + let client = mtop::dns::new_client(&cmd.resolv_conf) + .instrument(tracing::span!(Level::INFO, "dns.new_client")) + .await; let response = match client - .resolve(name, cmd.rtype, cmd.rclass) + .resolve(cmd.name.clone(), cmd.rtype, cmd.rclass) .instrument(tracing::span!(Level::INFO, "client.resolve")) .await { Ok(r) => r, Err(e) => { - tracing::error!(message = "unable to perform DNS query", err = %e); + tracing::error!(message = "unable to perform DNS query", name = %cmd.name, err = %e); return ExitCode::FAILURE; } }; @@ -168,17 +264,9 @@ async fn run_read(_: &ReadCommand) -> ExitCode { } async fn run_write(cmd: &WriteCommand) -> ExitCode { - let name = match Name::from_str(&cmd.name) { - Ok(n) => n, - Err(e) => { - tracing::error!(message = "invalid name supplied", name = cmd.name, err = %e); - return ExitCode::FAILURE; - } - }; - let id = MessageId::random(); let msg = Message::new(id, Flags::default().set_query().set_recursion_desired()) - .add_question(Question::new(name, cmd.rtype).set_qclass(cmd.rclass)); + .add_question(Question::new(cmd.name.clone(), cmd.rtype).set_qclass(cmd.rclass)); write_binary_message(&msg).await }