Skip to content

Commit

Permalink
tests: add type checking and fix type bugs
Browse files Browse the repository at this point in the history
  • Loading branch information
joelynch committed Oct 27, 2023
1 parent d7e724b commit f68bd7c
Show file tree
Hide file tree
Showing 28 changed files with 91 additions and 54 deletions.
2 changes: 1 addition & 1 deletion .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,7 @@ repos:
rev: v1.0.0
hooks:
- id: mypy
exclude: ^setup.py|^tests/|^vendor/|^astacus/proto/
exclude: ^setup.py|^vendor/|^astacus/proto/
additional_dependencies:
- types-PyYAML>=6.0.12.2
- types-requests>=2.28.11.5
Expand Down
4 changes: 4 additions & 0 deletions Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -78,6 +78,10 @@ yapf: black
black:
pre-commit run black --all-files

.PHONY: mypy
mypy:
pre-commit run mypy --all-files

.PHONY: reformat
reformat: isort black

Expand Down
4 changes: 2 additions & 2 deletions astacus/common/cassandra/schema.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
from .utils import is_system_keyspace
from astacus.common.utils import AstacusModel
from cassandra import metadata as cm
from typing import Any, Dict, Iterator, List, Set
from typing import Any, Dict, Iterator, List, Mapping, Set

import hashlib
import itertools
Expand Down Expand Up @@ -236,7 +236,7 @@ def _extract_dcs(metadata: cm.KeyspaceMetadata) -> Dict[str, str]:


class CassandraKeyspace(CassandraNamed):
network_topology_strategy_dcs: Dict[str, str] # not [str, int] because of transient replication
network_topology_strategy_dcs: Mapping[str, str] # not [str, int] because of transient replication
durable_writes: bool

aggregates: List[CassandraAggregate]
Expand Down
14 changes: 7 additions & 7 deletions astacus/coordinator/plugins/clickhouse/steps.py
Original file line number Diff line number Diff line change
Expand Up @@ -151,7 +151,7 @@ class RetrieveDatabasesAndTablesStep(Step[DatabasesAndTables]):
node), and relies on that to query only the first server of the cluster.
"""

clients: List[ClickHouseClient]
clients: Sequence[ClickHouseClient]

async def run_step(self, cluster: Cluster, context: StepsContext) -> DatabasesAndTables:
clickhouse_client = self.clients[0]
Expand Down Expand Up @@ -279,7 +279,7 @@ class RemoveFrozenTablesStep(Step[None]):
When the system unfreeze flag is enabled, clears frozen parts from all disks in a single go.
"""

clients: List[ClickHouseClient]
clients: Sequence[ClickHouseClient]
freeze_name: str
unfreeze_timeout: float

Expand All @@ -293,7 +293,7 @@ async def run_step(self, cluster: Cluster, context: StepsContext) -> None:

@dataclasses.dataclass
class FreezeUnfreezeTablesStepBase(Step[None]):
clients: List[ClickHouseClient]
clients: Sequence[ClickHouseClient]
freeze_name: str
freeze_unfreeze_timeout: float

Expand Down Expand Up @@ -429,7 +429,7 @@ class RestoreReplicatedDatabasesStep(Step[None]):
After this step, all tables will be empty.
"""

clients: List[ClickHouseClient]
clients: Sequence[ClickHouseClient]
replicated_databases_zookeeper_path: str
replicated_database_settings: ReplicatedDatabaseSettings
drop_databases_timeout: float
Expand Down Expand Up @@ -615,7 +615,7 @@ class RestoreReplicaStep(Step[None]):
"""

zookeeper_client: ZooKeeperClient
clients: List[ClickHouseClient]
clients: Sequence[ClickHouseClient]
disks: Disks
restart_timeout: float
max_concurrent_restart_per_node: int
Expand Down Expand Up @@ -699,7 +699,7 @@ class AttachMergeTreePartsStep(Step[None]):
details.
"""

clients: List[ClickHouseClient]
clients: Sequence[ClickHouseClient]
disks: Disks
attach_timeout: float
max_concurrent_attach_per_node: int
Expand Down Expand Up @@ -735,7 +735,7 @@ class SyncTableReplicasStep(Step[None]):
are all exchanged between all nodes.
"""

clients: List[ClickHouseClient]
clients: Sequence[ClickHouseClient]
sync_timeout: float
max_concurrent_sync_per_node: int

Expand Down
9 changes: 6 additions & 3 deletions tests/integration/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,7 +68,7 @@ async def run_process_and_wait_for_pattern(
fail_pattern: Optional[str] = None,
env: Mapping[str, str] = MappingProxyType({}),
timeout: float = 10.0,
) -> AsyncIterator[asyncio.subprocess.Process]:
) -> AsyncIterator[subprocess.Popen[bytes]]:
# This stringification is a workaround for a bug in pydev (pydev_monkey.py:111)
str_args = [str(arg) for arg in args]
pattern_found = asyncio.Event()
Expand Down Expand Up @@ -102,7 +102,7 @@ def read_logs() -> None:

@dataclasses.dataclass
class Service:
process: asyncio.subprocess.Process
process: subprocess.Popen[bytes] | asyncio.subprocess.Process
data_dir: Path
port: int
host: str = "localhost"
Expand Down Expand Up @@ -161,6 +161,8 @@ async def create_zookeeper(ports: Ports) -> AsyncIterator[Service]:
java_path = await get_command_path("java")
if java_path is None:
pytest.skip("java installation not found")
# newer versions of mypy should be able to infer that this is unreachable
assert False
port = ports.allocate()
with tempfile.TemporaryDirectory(prefix=f"zookeeper_{port}_") as data_dir_str:
data_dir = Path(data_dir_str)
Expand All @@ -177,6 +179,7 @@ async def create_zookeeper(ports: Ports) -> AsyncIterator[Service]:
command = get_zookeeper_command(java_path=java_path, data_dir=data_dir, port=port)
if command is None:
pytest.skip("zookeeper installation not found")
assert False
async with contextlib.AsyncExitStack() as stack:
max_attempts = 10
for attempt in range(max_attempts):
Expand All @@ -190,9 +193,9 @@ async def create_zookeeper(ports: Ports) -> AsyncIterator[Service]:
timeout=20.0,
)
)
yield Service(process=process, port=port, data_dir=data_dir)
break
except FailPatternFoundError:
if attempt + 1 == max_attempts:
raise
await asyncio.sleep(2.0)
yield Service(process=process, port=port, data_dir=data_dir)
13 changes: 11 additions & 2 deletions tests/integration/coordinator/plugins/clickhouse/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@
import pytest
import rohmu
import secrets
import subprocess
import sys
import tempfile
import urllib.parse
Expand Down Expand Up @@ -126,7 +127,7 @@ class MinioBucket:

@dataclasses.dataclass(frozen=True)
class MinioService:
process: asyncio.subprocess.Process
process: subprocess.Popen[bytes]
data_dir: Path
host: str
server_port: int
Expand Down Expand Up @@ -217,7 +218,15 @@ async def create_minio_service(ports: Ports) -> AsyncIterator[MinioService]:
"MINIO_ROOT_USER": root_user,
"MINIO_ROOT_PASSWORD": root_password,
}
command = ["/usr/bin/minio", "server", data_dir, "--address", server_netloc, "--console-address", console_netloc]
command: List[Union[str, Path]] = [
"/usr/bin/minio",
"server",
data_dir,
"--address",
server_netloc,
"--console-address",
console_netloc,
]
async with run_process_and_wait_for_pattern(
args=command,
cwd=data_dir,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
from .conftest import ClickHouseCommand, create_clickhouse_service, get_clickhouse_client
from astacus.coordinator.plugins.clickhouse.client import ClickHouseClientQueryError
from tests.integration.conftest import Ports, Service
from typing import cast, Sequence

import pytest
import time
Expand All @@ -18,7 +19,7 @@
@pytest.mark.asyncio
async def test_client_execute(clickhouse: Service) -> None:
client = get_clickhouse_client(clickhouse)
response = await client.execute(b"SHOW DATABASES")
response = cast(Sequence[list[str]], await client.execute(b"SHOW DATABASES"))
assert sorted(list(response)) == [["INFORMATION_SCHEMA"], ["default"], ["information_schema"], ["system"]]


Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -278,7 +278,7 @@ async def test_restores_table_with_nested_fields(restored_cluster: List[ClickHou
assert response == [[123, [4], [5]]]


async def check_object_storage_data(cluster: List[ClickHouseClient]) -> None:
async def check_object_storage_data(cluster: Sequence[ClickHouseClient]) -> None:
s1_data = [[123, "foo"], [456, "bar"]]
s2_data = [[789, "baz"]]
cluster_data = [s1_data, s1_data, s2_data]
Expand Down
11 changes: 9 additions & 2 deletions tests/integration/coordinator/plugins/clickhouse/test_steps.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
from astacus.coordinator.plugins.clickhouse.steps import RetrieveDatabasesAndTablesStep
from base64 import b64decode
from tests.integration.conftest import create_zookeeper, Ports
from typing import cast, Sequence
from uuid import UUID

import pytest
Expand Down Expand Up @@ -50,11 +51,17 @@ async def test_retrieve_tables(ports: Ports, clickhouse_command: ClickHouseComma
step = RetrieveDatabasesAndTablesStep(clients=[client])
context = StepsContext()
databases, tables = await step.run_step(Cluster(nodes=[]), context=context)
database_uuid_lines = list(await client.execute(b"SELECT base64Encode(name),uuid FROM system.databases"))
database_uuid_lines = cast(
Sequence[tuple[str, str]],
await client.execute(b"SELECT base64Encode(name),uuid FROM system.databases"),
)
database_uuids = {
b64decode(database_name): UUID(database_uuid) for database_name, database_uuid in database_uuid_lines
}
table_uuid_lines = list(await client.execute("SELECT uuid FROM system.tables where name = 'tablé_1'".encode()))
table_uuid_lines: Sequence[Sequence[str]] = cast(
list[tuple[str]],
list(await client.execute("SELECT uuid FROM system.tables where name = 'tablé_1'".encode())),
)
table_uuid = UUID(table_uuid_lines[0][0])

assert databases == [
Expand Down
5 changes: 3 additions & 2 deletions tests/integration/coordinator/plugins/flink/test_steps.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
from astacus.coordinator.plugins.flink.manifest import FlinkManifest
from astacus.coordinator.plugins.flink.steps import FlinkManifestStep, RestoreDataStep, RetrieveDataStep
from astacus.coordinator.plugins.zookeeper import KazooZooKeeperClient
from unittest.mock import Mock
from uuid import uuid4

import pytest
Expand Down Expand Up @@ -35,6 +36,6 @@ async def test_restore_data(zookeeper_client: KazooZooKeeperClient):
manifest = FlinkManifest(data=data)
context = StepsContext()
context.set_result(FlinkManifestStep, manifest)
await RestoreDataStep(zookeeper_client, ["catalog", "flink"]).run_step(cluster=None, context=context)
res = await RetrieveDataStep(zookeeper_client, ["catalog", "flink"]).run_step(cluster=None, context=None)
await RestoreDataStep(zookeeper_client, ["catalog", "flink"]).run_step(cluster=Mock(), context=context)
res = await RetrieveDataStep(zookeeper_client, ["catalog", "flink"]).run_step(cluster=Mock(), context=Mock())
assert res == data
1 change: 1 addition & 0 deletions tests/system/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -188,4 +188,5 @@ def astacus_run(


def astacus_ls(astacus: TestNode) -> List[str]:
assert astacus.root_path
return sorted(str(x.relative_to(astacus.root_path)) for x in astacus.root_path.glob("**/*"))
1 change: 1 addition & 0 deletions tests/system/test_config_reload.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@ def test_reload_config(tmpdir, rootdir: str, astacus1: TestNode, astacus2: TestN
check_without_status = astacus_run(rootdir, astacus1, "check-reload", check=True, capture_output=True)
assert check_without_status.returncode == 0
assert "Configuration does not need to be reloaded" in check_without_status.stdout.decode()
assert astacus1.root_path
# Write some data to backup, including files that don't match the reloaded glob
(astacus1.root_path / "saved.foo").write_text("dont_care")
(astacus1.root_path / "ignored.bar").write_text("dont_care")
Expand Down
9 changes: 5 additions & 4 deletions tests/unit/common/cassandra/test_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,12 +4,13 @@
"""

from astacus.common.cassandra.config import CassandraClientConfiguration
from pytest_mock import MockFixture

import astacus.common.cassandra.client as client_module
import pytest


def test_cassandra_session(mocker):
def test_cassandra_session(mocker: MockFixture) -> None:
ccluster = mocker.MagicMock()
csession = mocker.MagicMock()
session = client_module.CassandraSession(cluster=ccluster, session=csession)
Expand All @@ -32,7 +33,7 @@ def test_cassandra_session(mocker):
assert len(csession.execute.mock_calls) == 2


def create_client(mocker, ssl=False):
def create_client(mocker: MockFixture, ssl: bool = False) -> client_module.CassandraClient:
mocker.patch.object(client_module, "Cluster")
mocker.patch.object(client_module, "WhiteListRoundRobinPolicy")

Expand All @@ -47,14 +48,14 @@ def create_client(mocker, ssl=False):


@pytest.mark.parametrize("ssl", [False, True])
def test_cassandra_client(mocker, ssl):
def test_cassandra_client(mocker: MockFixture, ssl: bool) -> None:
client: client_module.CassandraClient = create_client(mocker, ssl=ssl)
with client.connect() as session:
assert isinstance(session, client_module.CassandraSession)


@pytest.mark.asyncio
async def test_cassandra_client_run(mocker):
async def test_cassandra_client_run(mocker: MockFixture):
client = create_client(mocker)

def test_fun(cas):
Expand Down
3 changes: 2 additions & 1 deletion tests/unit/common/test_op.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
"""
from astacus.common import op
from astacus.common.exceptions import ExpiredOperationException
from astacus.common.statsd import StatsClient
from starlette.background import BackgroundTasks
from starlette.datastructures import URL
Expand Down Expand Up @@ -35,7 +36,7 @@ def set_status_fail(self):
(None, op.Op.Status.done, None),
# If operation throws ExpiredOperationException, op status
# should stay running as it may point to the next operation
(op.ExpiredOperationException, op.Op.Status.running, None),
(ExpiredOperationException, op.Op.Status.running, None),
# If operation throws 'something else', it should fail the op status
(AssertionError, op.Op.Status.fail, AssertionError),
],
Expand Down
4 changes: 2 additions & 2 deletions tests/unit/common/test_storage.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
from astacus.common import exceptions
from astacus.common.cachingjsonstorage import CachingJsonStorage
from astacus.common.rohmustorage import RohmuConfig, RohmuStorage
from astacus.common.storage import FileStorage, JsonStorage
from astacus.common.storage import FileStorage, Json, JsonStorage
from contextlib import nullcontext as does_not_raise
from pathlib import Path
from rohmu.object_storage import google
Expand All @@ -24,7 +24,7 @@
TEXT_HEXDIGEST_DATA = b"data" * 15

TEST_JSON = "jsonblob"
TEST_JSON_DATA = {"foo": 7, "array": [1, 2, 3], "true": True}
TEST_JSON_DATA: Json = {"foo": 7, "array": [1, 2, 3], "true": True}


def create_storage(*, tmpdir, engine, **kw):
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ class RetrieveTestCase:
duplicate_address: bool = False

# Output
expected_error: Optional[Exception] = None
expected_error: Optional[type[Exception]] = None

def __str__(self):
return self.name
Expand Down
2 changes: 1 addition & 1 deletion tests/unit/coordinator/plugins/cassandra/test_plugin.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
from astacus.coordinator.plugins.base import StepFailedError
from astacus.coordinator.plugins.cassandra import plugin
from astacus.coordinator.plugins.cassandra.model import CassandraConfigurationNode
from tests.unit.node.test_node_cassandra import CassandraTestConfig
from tests.unit.conftest import CassandraTestConfig
from types import SimpleNamespace

import pytest
Expand Down
10 changes: 5 additions & 5 deletions tests/unit/coordinator/plugins/clickhouse/test_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,16 +3,16 @@
See LICENSE for details
"""
from astacus.coordinator.plugins.clickhouse.client import HttpClickHouseClient
from astacus.coordinator.plugins.clickhouse.config import ClickHouseConfiguration, ClickHouseNode
from astacus.coordinator.plugins.clickhouse.plugin import get_clickhouse_clients, get_zookeeper_client
from astacus.coordinator.plugins.zookeeper import KazooZooKeeperClient, KazooZooKeeperConnection
from astacus.coordinator.plugins.clickhouse.config import ClickHouseConfiguration, ClickHouseNode, get_clickhouse_clients
from astacus.coordinator.plugins.zookeeper import KazooZooKeeperClient, KazooZooKeeperConnection, ZooKeeperUser
from astacus.coordinator.plugins.zookeeper_config import (
get_zookeeper_client,
ZooKeeperConfiguration,
ZooKeeperConfigurationUser,
ZooKeeperNode,
ZooKeeperUser,
)
from kazoo.client import KazooClient
from pydantic import SecretStr
from typing import cast, List

import pytest
Expand All @@ -31,7 +31,7 @@ def test_get_zookeeper_client() -> None:
def test_get_authenticated_zookeeper_client() -> None:
configuration = ZooKeeperConfiguration(
nodes=[ZooKeeperNode(host="::1", port=5556)],
user=ZooKeeperConfigurationUser(username="local-user", password="secret"),
user=ZooKeeperConfigurationUser(username="local-user", password=SecretStr("secret")),
)
client = get_zookeeper_client(configuration)
assert client is not None and isinstance(client, KazooZooKeeperClient)
Expand Down
Loading

0 comments on commit f68bd7c

Please sign in to comment.