Skip to content

Commit

Permalink
add a retry wrapper for test timeouts
Browse files Browse the repository at this point in the history
  • Loading branch information
rthalley committed Oct 8, 2024
1 parent 8ce4fcf commit e9da624
Show file tree
Hide file tree
Showing 2 changed files with 31 additions and 0 deletions.
11 changes: 11 additions & 0 deletions tests/test_doh.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,8 @@
import socket
import unittest

import dns.exception

try:
import ssl

Expand Down Expand Up @@ -88,6 +90,7 @@ def setUp(self):
def tearDown(self):
self.session.close()

@tests.util.retry_on_timeout
def test_get_request(self):
nameserver_url = random.choice(KNOWN_ANYCAST_DOH_RESOLVER_URLS)
q = dns.message.make_query("example.com.", dns.rdatatype.A)
Expand All @@ -101,6 +104,7 @@ def test_get_request(self):
)
self.assertTrue(q.is_response(r))

@tests.util.retry_on_timeout
def test_post_request(self):
nameserver_url = random.choice(KNOWN_ANYCAST_DOH_RESOLVER_URLS)
q = dns.message.make_query("example.com.", dns.rdatatype.A)
Expand All @@ -114,6 +118,7 @@ def test_post_request(self):
)
self.assertTrue(q.is_response(r))

@tests.util.retry_on_timeout
def test_build_url_from_ip(self):
self.assertTrue(resolver_v4_addresses or resolver_v6_addresses)
if resolver_v4_addresses:
Expand Down Expand Up @@ -159,12 +164,14 @@ def test_build_url_from_ip(self):
# )
# self.assertTrue(q.is_response(r))

@tests.util.retry_on_timeout
def test_new_session(self):
nameserver_url = random.choice(KNOWN_ANYCAST_DOH_RESOLVER_URLS)
q = dns.message.make_query("example.com.", dns.rdatatype.A)
r = dns.query.https(q, nameserver_url, timeout=4)
self.assertTrue(q.is_response(r))

@tests.util.retry_on_timeout
def test_resolver(self):
res = dns.resolver.Resolver(configure=False)
res.nameservers = ["https://dns.google/dns-query"]
Expand All @@ -173,6 +180,7 @@ def test_resolver(self):
self.assertTrue("8.8.8.8" in seen)
self.assertTrue("8.8.4.4" in seen)

@tests.util.retry_on_timeout
def test_padded_get(self):
nameserver_url = random.choice(KNOWN_PAD_AWARE_DOH_RESOLVER_URLS)
q = dns.message.make_query("example.com.", dns.rdatatype.A, use_edns=0, pad=128)
Expand All @@ -194,6 +202,7 @@ def test_padded_get(self):
"Aioquic cannot be imported; no DNS over HTTP3 (DOH3)",
)
class DNSOverHTTP3TestCase(unittest.TestCase):
@tests.util.retry_on_timeout
def testDoH3GetRequest(self):
nameserver_url = random.choice(KNOWN_ANYCAST_DOH3_RESOLVER_URLS)
q = dns.message.make_query("dns.google.", dns.rdatatype.A)
Expand All @@ -207,6 +216,7 @@ def testDoH3GetRequest(self):
)
self.assertTrue(q.is_response(r))

@tests.util.retry_on_timeout
def testDoH3PostRequest(self):
nameserver_url = random.choice(KNOWN_ANYCAST_DOH3_RESOLVER_URLS)
q = dns.message.make_query("dns.google.", dns.rdatatype.A)
Expand All @@ -220,6 +230,7 @@ def testDoH3PostRequest(self):
)
self.assertTrue(q.is_response(r))

@tests.util.retry_on_timeout
def test_build_url_from_ip(self):
self.assertTrue(resolver_v4_addresses or resolver_v6_addresses)
if resolver_v4_addresses:
Expand Down
20 changes: 20 additions & 0 deletions tests/util.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,9 +16,11 @@
# OF OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE.

import enum
import functools
import inspect
import os

import dns.exception
import dns.message
import dns.name
import dns.query
Expand Down Expand Up @@ -131,3 +133,21 @@ def is_docker() -> bool:
return os.path.isfile("/.dockerenv")
except Exception:
return False


def retry_on_timeout(f):
@functools.wraps(f)
def wrapper(*args, **kwargs):
bad = True
for i in range(3):
try:
print("TRY", i, "for", f.__name__)
f(*args, **kwargs)
bad = False
break
except dns.exception.Timeout:
pass
if bad:
raise dns.exception.Timeout

return wrapper

0 comments on commit e9da624

Please sign in to comment.