Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Enhancement: support querying multiple servers at once #73

Closed
CosmicToast opened this issue Oct 31, 2023 · 1 comment
Closed

Enhancement: support querying multiple servers at once #73

CosmicToast opened this issue Oct 31, 2023 · 1 comment

Comments

@CosmicToast
Copy link

Support querying multiple DNS servers at once, as dog/doggo do.
This is particularly useful when testing a DNS server's conformance against a known-good server, or when debugging issues.
Example in doggo (doggo's output makes the per-server results more obvious):
[image: example doggo output]

Since q already supports querying multiple RR types at once, this seems like a logical extension, and it would bring q closer to feature parity with those tools.

@CosmicToast
Copy link
Author

I made an attempt at implementing this. It's relatively straightforward until we get to the output: we don't store the originating server in the replies, and we don't print the replying server by default.
Changing this is possible, but awkward due to #72.
Here's the diff in case someone wants to pick up where I left off; all but the final section should be relevant:

diff --git a/cli/flags.go b/cli/flags.go
index 5e396ae..536b5eb 100644
--- a/cli/flags.go
+++ b/cli/flags.go
@@ -4,7 +4,7 @@ import "time"
 
 type Flags struct {
 	Name         string        `short:"q" long:"qname" description:"Query name"`
-	Server       string        `short:"s" long:"server" description:"DNS server"`
+	Server       []string      `short:"s" long:"server" description:"DNS server"`
 	Types        []string      `short:"t" long:"type" description:"RR type (e.g. A, AAAA, MX, etc.) or type integer"`
 	Reverse      bool          `short:"x" long:"reverse" description:"Reverse lookup"`
 	DNSSEC       bool          `short:"d" long:"dnssec" description:"Set the DO (DNSSEC OK) bit in the OPT record"`
diff --git a/main.go b/main.go
index 3076d9e..0071d22 100644
--- a/main.go
+++ b/main.go
@@ -228,7 +228,7 @@ func parseServer(s string) (string, transport.Type, error) {
 // driver is the "main" function for this program that accepts a flag slice for testing
 func driver(args []string, out io.Writer) error {
 	parser := flags.NewParser(&opts, flags.Default)
-	parser.Usage = `[OPTIONS] [@server] [type...] [name]
+	parser.Usage = `[OPTIONS] [@server...] [type...] [name]
 
 All long form (--) flags can be toggled with the dig-standard +[no]flag notation.`
 	_, err := parser.ParseArgs(args)
@@ -279,9 +279,9 @@ All long form (--) flags can be toggled with the dig-standard +[no]flag notation
 
 	// Add non-flag RR types
 	for _, arg := range args {
-		// Find a server by @ symbol if it isn't set by flag
-		if opts.Server == "" && strings.HasPrefix(arg, "@") {
-			opts.Server = strings.TrimPrefix(arg, "@")
+		// @ servers added to server list
+		if strings.HasPrefix(arg, "@") {
+			opts.Server = append(opts.Server, strings.TrimPrefix(arg, "@"))
 		}
 
 		// Parse chaos class
@@ -335,23 +335,24 @@ All long form (--) flags can be toggled with the dig-standard +[no]flag notation
 		log.Debugf("RR types: %+v", rrTypeStrings)
 	}
 
-	// Set default DNS server
-	if opts.Server == "" {
+	// Set default DNS server if none were set explicitly
+	if len(opts.Server) == 0 {
+		opts.Server = make([]string, 1)
 		if os.Getenv(defaultServerVar) != "" {
-			opts.Server = os.Getenv(defaultServerVar)
+			opts.Server[0] = os.Getenv(defaultServerVar)
 			log.Debugf("Using %s from %s environment variable", opts.Server, defaultServerVar)
 		} else {
 			log.Debugf("No server specified or %s set, using /etc/resolv.conf", defaultServerVar)
 			conf, err := dns.ClientConfigFromFile("/etc/resolv.conf")
 			if err != nil {
-				opts.Server = "https://cloudflare-dns.com/dns-query"
+				opts.Server[0] = "https://cloudflare-dns.com/dns-query"
 				log.Debugf("no server set, using %s", opts.Server)
 			} else {
 				if len(conf.Servers) == 0 {
-					opts.Server = "https://cloudflare-dns.com/dns-query"
+					opts.Server[0] = "https://cloudflare-dns.com/dns-query"
 					log.Debugf("no server set, using %s", opts.Server)
 				} else {
-					opts.Server = conf.Servers[0]
+					opts.Server[0] = conf.Servers[0]
 					log.Debugf("found server %s from /etc/resolv.conf", opts.Server)
 				}
 			}
@@ -363,8 +364,10 @@ All long form (--) flags can be toggled with the dig-standard +[no]flag notation
 		if !strings.HasPrefix(opts.ODoHProxy, "https://") {
 			return fmt.Errorf("ODoH proxy must use HTTPS")
 		}
-		if !strings.HasPrefix(opts.Server, "https://") {
-			return fmt.Errorf("ODoH target must use HTTPS")
+		for _, v := range opts.Server {
+			if !strings.HasPrefix(v, "https://") {
+				return fmt.Errorf("ODoH target must use HTTPS")
+			}
 		}
 	}
 
@@ -420,17 +423,20 @@ All long form (--) flags can be toggled with the dig-standard +[no]flag notation
 	)
 
 	// Parse server address and transport type
-	server, transportType, err := parseServer(opts.Server)
-	if err != nil {
-		return err
+	type tserver struct {
+		server string
+		ttype transport.Type
+		trans *transport.Transport
 	}
-	log.Debugf("Using server %s with transport %s", server, transportType)
-
-	// QUIC specific overrides
-	if transportType == transport.TypeQUIC {
-		tlsConfig.NextProtos = opts.QUICALPNTokens
-		// Skip ID check if QUIC (https://datatracker.ietf.org/doc/html/rfc9250#section-4.2.1)
-		opts.NoIDCheck = true
+	servers := make([]*tserver, 0, len(opts.Server))
+	for _, v := range opts.Server {
+		server, transportType, err := parseServer(v)
+		if err != nil {
+			log.Debugf("Skipping server %s due to error %s", v, err)
+			continue
+		}
+		log.Debugf("Adding server %s with transport %s", server, transportType)
+		servers = append(servers, &tserver{server, transportType, nil})
 	}
 
 	// Recursive zone transfer
@@ -438,28 +444,63 @@ All long form (--) flags can be toggled with the dig-standard +[no]flag notation
 		if opts.Name == "" {
 			return fmt.Errorf("no name specified for AXFR")
 		}
-		_ = RecAXFR(opts.Name, server, out)
+		for _, v := range servers {
+			_ = RecAXFR(opts.Name, v.server, out)
+		}
 		return nil
 	}
 
-	// Create transport
-	txp, err := newTransport(server, transportType, tlsConfig)
-	if err != nil {
-		return err
+	// Create transports
+	for _, v := range servers {
+		tlsConfig := tlsConfig.Clone()
+
+		// QUIC specific overrides
+		if v.ttype == transport.TypeQUIC {
+			tlsConfig.NextProtos = opts.QUICALPNTokens
+			// Skip ID check if QUIC (https://datatracker.ietf.org/doc/html/rfc9250#section-4.2.1)
+			opts.NoIDCheck = true // TODO: per-server overrides?
+		}
+
+		// Create transport
+		txp, err := newTransport(v.server, v.ttype, tlsConfig)
+		if err != nil {
+			log.Debugf("Skipping server %s due to error %s", v.server, err)
+			continue
+		}
+		v.trans = txp
 	}
 
-	startTime := time.Now()
-	var replies []*dns.Msg
-	for _, msg := range msgs {
-		reply, err := (*txp).Exchange(&msg)
-		if err != nil {
-			return err
+	// Filter failed servers, so we can preallocate responses
+	{
+		n := 0
+		for _, v := range servers {
+			if v.trans == nil {
+				continue
+			}
+			servers[n] = v
+			n++
 		}
+		servers = servers[:n]
+	}
 
-		if !opts.NoIDCheck && reply.Id != msg.Id {
-			return fmt.Errorf("ID mismatch: expected %d, got %d", msg.Id, reply.Id)
+	// preallocate replies storage
+	replies := make([]*dns.Msg, 0, len(msgs) * len(servers))
+	startTime := time.Now()
+	for _, v := range servers {
+		for _, msg := range msgs {
+			reply, err := (*v.trans).Exchange(&msg)
+			if err != nil {
+				log.Debugf("Skipping message %s with servers %s due to error %s",
+					&msg, v.server, err)
+				replies = append(replies, nil) // append nil so we can keep server sizes stable
+				continue
+			}
+
+			if !opts.NoIDCheck && reply.Id != msg.Id {
+				return fmt.Errorf("ID mismatch: expected %d, got %d", msg.Id, reply.Id)
+			}
+			replies = append(replies, reply)
 		}
-		replies = append(replies, reply)
 	}
 	queryTime := time.Since(startTime)
 
@@ -474,27 +515,35 @@ All long form (--) flags can be toggled with the dig-standard +[no]flag notation
 		output.PrettyPrintNSID(replies, out)
 	}
 
-	printer := output.Printer{
-		Server:     server,
-		Out:        out,
-		Opts:       &opts,
-		QueryTime:  queryTime,
-		NumReplies: len(replies),
-		Transport:  txp,
-	}
-	if opts.Format == "column" {
-		printer.PrintColumn(replies)
-	} else {
-		for i, reply := range replies {
-			switch opts.Format {
-			case "pretty":
-				printer.PrintPretty(i, reply)
-			case "raw":
-				printer.PrintRaw(i, reply)
-			case "json", "yml", "yaml":
-				printer.PrintStructured(i, reply)
-			default:
-				return fmt.Errorf("invalid output format")
+	for i, v := range servers {
+		fmt.Println(v.server)
+		printer := output.Printer{
+			Server:     v.server,
+			Out:        out,
+			Opts:       &opts,
+			QueryTime:  queryTime,
+			NumReplies: len(msgs), // TODO: filter nils
+			Transport:  v.trans,
+		}
+		if opts.Format == "column" {
+			printer.PrintColumn(replies[i * len(msgs) : (i+1) * len(msgs)])
+		} else {
+			for i, reply := range replies[i * len(msgs) : (i+1) * len(msgs)] {
+				if reply == nil {
+					continue
+				}
+				switch opts.Format {
+				case "pretty":
+					printer.PrintPretty(i, reply)
+				case "raw":
+					printer.PrintRaw(i, reply)
+				// TODO: jq and co can handle multipe separate json objects on stdin
+				// however, it would be nice to potentially return a [] instead
+				case "json", "yml", "yaml":
+					printer.PrintStructured(i, reply)
+				default:
+					return fmt.Errorf("invalid output format")
+				}
 			}
 		}
 	}

natesales added a commit that referenced this issue Nov 11, 2023
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

No branches or pull requests

2 participants