forked from folbricht/routedns
-
Notifications
You must be signed in to change notification settings - Fork 0
/
fastest-tcp.go
208 lines (188 loc) · 5.6 KB
/
fastest-tcp.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
package rdns
import (
"context"
"errors"
"fmt"
"net"
"strconv"
"time"
"github.com/miekg/dns"
"github.com/sirupsen/logrus"
)
// FastestTCP first resolves the query with the upstream resolver, then
// performs TCP connection tests with the response IPs to determine which
// IP responds the fastest. This IP is then returned in the response as first
// A/AAAA record. This should be used in combination with a Cache to avoid
// the TCP connection overhead on every query.
type FastestTCP struct {
	id       string             // identifier used for logging and returned by String
	resolver Resolver           // upstream resolver that answers the query before probing
	opt      FastestTCPOptions  // options as passed to NewFastestTCP
	port     string             // TCP port probed on every address; "443" when opt.Port is 0
}

// Compile-time check that FastestTCP implements the Resolver interface.
var _ Resolver = &FastestTCP{}
// FastestTCPOptions contain settings for a resolver that reorders A/AAAA
// responses based on TCP connection probes.
type FastestTCPOptions struct {
	// Port number to use for TCP probes. Zero selects the default, 443.
	Port int

	// Wait for all connection probes and sort the address records by
	// connection time (fastest first). This is generally slower than just
	// waiting for the fastest, since the response time is determined by the
	// slowest probe. If any probe fails, or not all probes finish before the
	// probe timeout, the upstream response is returned unsorted.
	WaitAll bool

	// Minimum TTL applied to the probed address records when TCP probing was
	// successful: TTLs below this value are raised to it, higher TTLs are
	// kept, and zero leaves TTLs unchanged. Can be used to ensure these are
	// kept for longer in a cache and improve performance.
	SuccessTTLMin uint32
}
// NewFastestTCP builds a resolver that forwards queries to resolver and
// reorders the A/AAAA records of its answers by TCP connect time. The probe
// port is taken from opt.Port; a zero value falls back to port 443.
func NewFastestTCP(id string, resolver Resolver, opt FastestTCPOptions) *FastestTCP {
	probePort := "443"
	if opt.Port != 0 {
		probePort = strconv.Itoa(opt.Port)
	}
	return &FastestTCP{
		id:       id,
		resolver: resolver,
		opt:      opt,
		port:     probePort,
	}
}
// Resolve forwards the query to the upstream resolver and, for A and AAAA
// queries, reorders the address records of the answer by how quickly a TCP
// connection to each address (on the configured port) could be established.
//
// Upstream errors are returned as-is. The upstream response is returned
// unchanged when there is nothing to probe (no response message, no question,
// a non-address query type, or fewer than two address records of the queried
// type) and when probing fails or times out. On success the probed records'
// TTLs are raised to at least SuccessTTLMin.
func (r *FastestTCP) Resolve(q *dns.Msg, ci ClientInfo) (*dns.Msg, error) {
	log := logger(r.id, q, ci)
	a, err := r.resolver.Resolve(q, ci)
	if err != nil {
		return a, err
	}
	// If the upstream returns no message without an error, or the query has
	// no question to match records against, pass the result through instead
	// of dereferencing / indexing it.
	if a == nil || len(q.Question) == 0 {
		return a, nil
	}
	question := q.Question[0]

	// Don't need to do anything if the query wasn't for an IP
	if question.Qtype != dns.TypeA && question.Qtype != dns.TypeAAAA {
		return a, nil
	}

	// Extract the address records of the queried type. The answer may also
	// hold CNAMEs and other types; those are not probed.
	var ipRRs []dns.RR
	for _, rr := range a.Answer {
		if rr.Header().Rrtype == question.Qtype {
			ipRRs = append(ipRRs, rr)
		}
	}

	// With fewer than two addresses there is nothing to choose between.
	if len(ipRRs) < 2 {
		return a, nil
	}

	// Send TCP probes to all, if anything returns an error, just return
	// the original response rather than trying to be clever and pick one.
	log = log.WithField("port", r.port)
	var sorted []dns.RR
	if r.opt.WaitAll {
		sorted, err = r.probeAll(log, ipRRs)
	} else {
		sorted, err = r.probeFastest(log, ipRRs)
	}
	if err != nil {
		log.WithError(err).Debug("tcp probe failed")
		return a, nil
	}
	r.setTTL(sorted...)

	// Write the sorted records back into the slots the address records
	// originally occupied, so records of other types keep their positions.
	// Both probe functions return exactly one record per probed address.
	next := 0
	for i, rr := range a.Answer {
		if rr.Header().Rrtype == question.Qtype {
			a.Answer[i] = sorted[next]
			next++
		}
	}
	return a, nil
}

// setTTL raises the TTL of each given record to at least
// r.opt.SuccessTTLMin. Records that already have a higher TTL are left as
// they are, so a zero SuccessTTLMin makes this a no-op.
func (r *FastestTCP) setTTL(rrs ...dns.RR) {
	for _, rr := range rrs {
		h := rr.Header()
		if h.Ttl < r.opt.SuccessTTLMin {
			h.Ttl = r.opt.SuccessTTLMin
		}
	}
}

// String returns the identifier of this resolver.
func (r *FastestTCP) String() string {
	return r.id
}

// probeFastest probes all addresses concurrently and waits only for the first
// result. If that result is a successful connection, the winning record is
// moved to the front of rrs (which is reordered in place; the remaining
// records keep their relative order) and rrs is returned. Returns an error if
// the first result is a failed probe, or if no result arrives within 2s.
func (r *FastestTCP) probeFastest(log logrus.FieldLogger, rrs []dns.RR) ([]dns.RR, error) {
	ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second)
	defer cancel()
	resultCh := r.probe(ctx, log, rrs)
	select {
	case res := <-resultCh:
		if res.err != nil {
			return nil, res.err
		}
		for i, rr := range rrs {
			if rr == res.rr {
				// Shift rrs[0:i] up by one slot (copy handles the overlap)
				// and place the winner at the front.
				copy(rrs[1:i+1], rrs[:i])
				rrs[0] = res.rr
				break
			}
		}
		return rrs, nil
	case <-ctx.Done():
		return nil, ctx.Err()
	}
}

// probeAll probes all addresses and returns the records in the order their
// connections succeeded, fastest first. Returns an error as soon as any probe
// fails, or if not all probes have finished within 2s.
func (r *FastestTCP) probeAll(log logrus.FieldLogger, rrs []dns.RR) ([]dns.RR, error) {
	ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second)
	defer cancel()
	resultCh := r.probe(ctx, log, rrs)
	results := make([]dns.RR, 0, len(rrs))
	for i := 0; i < len(rrs); i++ {
		select {
		case res := <-resultCh:
			if res.err != nil {
				return nil, res.err
			}
			results = append(results, res.rr)
		case <-ctx.Done():
			return nil, ctx.Err()
		}
	}
	return results, nil
}

// tcpProbeResult is the outcome of probing one record: rr is set when the TCP
// connection succeeded, err is set otherwise.
type tcpProbeResult struct {
	rr  dns.RR
	err error
}

// errUnexpectedRRType is reported by probe for records that are neither
// *dns.A nor *dns.AAAA.
var errUnexpectedRRType = errors.New("unexpected resource type")

// probe starts one TCP connection attempt per record, bounded by ctx, and
// returns a channel that receives exactly one result per record in the order
// the attempts complete. The channel is buffered for all results so every
// probe goroutine can deliver its result and exit even when the caller stops
// reading early (after the first result, an error, or a timeout); with an
// unbuffered channel those goroutines would block forever.
func (r *FastestTCP) probe(ctx context.Context, log logrus.FieldLogger, rrs []dns.RR) <-chan tcpProbeResult {
	resultCh := make(chan tcpProbeResult, len(rrs))
	for _, rr := range rrs {
		go func(rr dns.RR) {
			var network, ip string
			switch record := rr.(type) {
			case *dns.A:
				network, ip = "tcp4", record.A.String()
			case *dns.AAAA:
				network, ip = "tcp6", record.AAAA.String()
			default:
				resultCh <- tcpProbeResult{err: fmt.Errorf("%w: %T", errUnexpectedRRType, rr)}
				return
			}
			ipLog := log.WithField("ip", ip)
			start := time.Now()
			ipLog.Debug("sending tcp probe")
			var d net.Dialer
			c, err := d.DialContext(ctx, network, net.JoinHostPort(ip, r.port))
			if err != nil {
				resultCh <- tcpProbeResult{err: err}
				return
			}
			ipLog.WithField("response-time", time.Since(start)).Debug("tcp probe finished")
			// Only the connection setup time matters; release the socket right
			// away. A failed Close has no bearing on the measurement.
			_ = c.Close()
			resultCh <- tcpProbeResult{rr: rr}
		}(rr)
	}
	return resultCh
}