Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: initial py-test setup & update Dockerfile #404

Merged
merged 8 commits into from
Dec 20, 2024
Merged
4 changes: 4 additions & 0 deletions src/backend/Dockerfile
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,10 @@ RUN --mount=type=cache,target=/root/.cache/uv \
RUN --mount=type=cache,target=/root/.cache/uv \
uv sync

# Install the test dependencies using uv
RUN --mount=type=cache,target=/root/.cache/uv \
uv sync --group test

# Run stage (final stage)
FROM python:$PYTHON_BASE AS service

Expand Down
2 changes: 1 addition & 1 deletion src/backend/app/models/enums.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ class IntEnum(int, Enum):
pass


class FinalOutput(Enum):
class FinalOutput(str, Enum):
ORTHOPHOTO_2D = "ORTHOPHOTO_2D"
ORTHOPHOTO_3D = "ORTHOPHOTO_3D"
DIGITAL_TERRAIN_MODEL = "DIGITAL_TERRAIN_MODEL"
Expand Down
5 changes: 2 additions & 3 deletions src/backend/app/projects/project_schemas.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,6 @@
computed_field,
Field,
model_validator,
root_validator,
EmailStr,
)
from pydantic.functional_validators import AfterValidator
Expand Down Expand Up @@ -110,7 +109,7 @@ class ProjectIn(BaseModel):
)
final_output: List[FinalOutput] = Field(
...,
example=[
json_schema_extra=[
"ORTHOPHOTO_2D",
"ORTHOPHOTO_3D",
"DIGITAL_TERRAIN_MODEL",
Expand Down Expand Up @@ -538,7 +537,7 @@ class Pagination(BaseModel):
per_page: int
total: int

@root_validator(pre=True)
@model_validator(mode="before")
def calculate_pagination(cls, values):
page = values.get("page", 1)
total = values.get("total", 1)
Expand Down
21 changes: 12 additions & 9 deletions src/backend/app/tasks/task_logic.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,19 +42,22 @@ async def get_task_stats(db: Connection, user_data: AuthUser):
WHERE
(
%(role)s = 'DRONE_PILOT'
AND te.user_id = %(user_id)s
AND te.user_id = %(user_id)s AND te.state NOT IN ('UNLOCKED_TO_MAP')
)
OR
(
%(role)s = 'PROJECT_CREATOR'
AND (
te.project_id IN (
SELECT p.id
FROM projects p
WHERE p.author_id = %(user_id)s
%(role)s = 'PROJECT_CREATOR'
AND (
te.user_id = %(user_id)s AND te.state NOT IN ('REQUEST_FOR_MAPPING')
OR
te.project_id IN (
SELECT p.id
FROM projects p
WHERE
p.author_id = %(user_id)s
)
)
OR te.user_id = %(user_id)s -- Grant permissions equivalent to DRONE_PILOT
))
)
ORDER BY te.task_id, te.created_at DESC
) AS te;
"""
Expand Down
23 changes: 15 additions & 8 deletions src/backend/app/tasks/task_schemas.py
Original file line number Diff line number Diff line change
Expand Up @@ -224,19 +224,26 @@ async def get_tasks_by_user(
WHERE
(
%(role)s = 'DRONE_PILOT'
AND task_events.user_id = %(user_id)s
AND task_events.user_id = %(user_id)s AND task_events.state NOT IN ('UNLOCKED_TO_MAP')
)
OR
(
%(role)s = 'PROJECT_CREATOR' AND (
task_events.project_id IN (
SELECT p.id
FROM projects p
WHERE p.author_id = %(user_id)s
%(role)s = 'PROJECT_CREATOR'
AND (
(
task_events.user_id = %(user_id)s AND task_events.state NOT IN ('REQUEST_FOR_MAPPING')
)
OR
(
task_events.project_id IN (
SELECT p.id
FROM projects p
WHERE
p.author_id = %(user_id)s
)
)
)
OR task_events.user_id = %(user_id)s
)
)
ORDER BY
tasks.id, task_events.created_at DESC
OFFSET %(skip)s
Expand Down
1 change: 1 addition & 0 deletions src/backend/pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@ dependencies = [
"bcrypt>=4.2.1",
"drone-flightplan>=0.3.2",
"Scrapy==2.12.0",
"asgi-lifespan>=2.1.0",
]
requires-python = ">=3.11"
license = {text = "GPL-3.0-only"}
Expand Down
1 change: 1 addition & 0 deletions src/backend/tests/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
"""Backend tests using PyTest."""
140 changes: 140 additions & 0 deletions src/backend/tests/conftest.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,140 @@
from typing import AsyncGenerator, Any
from app.db.database import get_db
from app.users.user_deps import login_required
from app.models.enums import UserRole
from fastapi import FastAPI
from app.main import get_application
from app.users.user_schemas import AuthUser
import pytest_asyncio
from app.config import settings
from asgi_lifespan import LifespanManager
from httpx import ASGITransport, AsyncClient
from psycopg import AsyncConnection
from app.users.user_schemas import DbUser
import pytest
from app.projects.project_schemas import ProjectIn, DbProject


@pytest_asyncio.fixture(scope="function")
async def db() -> AsyncConnection:
"""The psycopg async database connection using psycopg3."""
db_conn = await AsyncConnection.connect(
conninfo=settings.DTM_DB_URL.unicode_string(),
)
try:
yield db_conn
finally:
await db_conn.close()


@pytest_asyncio.fixture(scope="function")
async def user(db) -> AuthUser:
"""Create a test user."""
db_user = await DbUser.get_or_create_user(
db,
AuthUser(
id="101039844375937810000",
email="[email protected]",
name="admin",
profile_img="",
role=UserRole.PROJECT_CREATOR,
),
)
return db_user


@pytest_asyncio.fixture(scope="function")
async def project_info(db, user):
"""
Fixture to create project metadata for testing.

"""
print(
f"User passed to project_info fixture: {user}, ID: {getattr(user, 'id', 'No ID')}"
)

project_metadata = ProjectIn(
name="TEST 98982849249278787878778",
description="",
outline={
"type": "FeatureCollection",
"features": [
{
"id": "d10fbd780ecd3ff7851cb222467616a0",
"type": "Feature",
"properties": {},
"geometry": {
"coordinates": [
[
[-69.49779538720068, 18.629654277305633],
[-69.48497355306813, 18.616997544638636],
[-69.54053483430786, 18.608390428368665],
[-69.5410690773959, 18.614466085056165],
[-69.49779538720068, 18.629654277305633],
]
],
"type": "Polygon",
},
}
],
},
no_fly_zones=None,
gsd_cm_px=1,
task_split_dimension=400,
is_terrain_follow=False,
per_task_instructions="",
deadline_at=None,
visibility=0,
requires_approval_from_manager_for_locking=False,
requires_approval_from_regulator=False,
front_overlap=1,
side_overlap=1,
final_output=["ORTHOPHOTO_2D"],
)

try:
await DbProject.create(db, project_metadata, getattr(user, "id", ""))
return project_metadata
except Exception as e:
pytest.fail(f"Fixture setup failed with exception: {str(e)}")


@pytest_asyncio.fixture(autouse=True)
async def app() -> AsyncGenerator[FastAPI, Any]:
"""Get the FastAPI test server."""
yield get_application()


@pytest_asyncio.fixture(scope="function")
def drone_info():
"""Test drone information."""
return {
"model": "DJI Mavic-12344",
"manufacturer": "DJI",
"camera_model": "DJI Camera 1",
"sensor_width": 13.2,
"sensor_height": 8.9,
"max_battery_health": 0.85,
"focal_length": 24.0,
"image_width": 400,
"image_height": 300,
"max_altitude": 500.0,
"max_speed": 72.0,
"weight": 1.5,
}


@pytest_asyncio.fixture(scope="function")
async def client(app: FastAPI, db: AsyncConnection):
"""The FastAPI test server."""
# Override server db connection
app.dependency_overrides[get_db] = lambda: db
app.dependency_overrides[login_required] = lambda: user

async with LifespanManager(app) as manager:
async with AsyncClient(
transport=ASGITransport(app=manager.app),
base_url="http://test",
follow_redirects=True,
) as ac:
yield ac
30 changes: 30 additions & 0 deletions src/backend/tests/test_drones_routes.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,30 @@
from app.models.enums import HTTPStatus
import pytest


@pytest.mark.asyncio
async def test_create_drone(client, drone_info):
"""Create a new project."""

response = await client.post("/api/drones/create-drone", json=drone_info)
assert response.status_code == HTTPStatus.OK

return response.json()


@pytest.mark.asyncio
async def test_read_drone(client, drone_info):
"""Test retrieving a drone record."""

response = await client.post("/api/drones/create-drone", json=drone_info)
assert response.status_code == HTTPStatus.OK
drone_id = response.json().get("drone_id")
response = await client.get(f"/api/drones/{drone_id}")
assert response.status_code == HTTPStatus.OK
drone_data = response.json()
assert drone_data.get("model") == drone_info["model"]


if __name__ == "__main__":
"""Main func if file invoked directly."""
pytest.main()
27 changes: 27 additions & 0 deletions src/backend/tests/test_projects_routes.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,27 @@
# import pytest
# import json


# @pytest.mark.asyncio
# async def test_create_project_with_files(client, project_info,):
# """
# Test to verify the project creation API with file upload (image as binary data).
# """
# project_info_json = json.dumps(project_info.model_dump())
# files = {
# "project_info": (None, project_info_json, "application/json"),
# "dem": None,
# "image": None
# }

# files = {k: v for k, v in files.items() if v is not None}
# response = await client.post(
# "/api/projects/",
# files=files
# )
# assert response.status_code == 201
# return response.json()

# if __name__ == "__main__":
# """Main func if file invoked directly."""
# pytest.main()
36 changes: 36 additions & 0 deletions src/backend/tests/test_users_routes.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,36 @@
import pytest
from app.config import settings
import jwt
import pytest_asyncio
from datetime import datetime, timedelta
from loguru import logger as log


@pytest_asyncio.fixture(scope="function")
def token(user):
"""
Create a reset password token for a given user.
"""
payload = {
"sub": user.email_address,
"exp": datetime.utcnow()
+ timedelta(minutes=settings.RESET_PASSWORD_TOKEN_EXPIRE_MINUTES),
}
return jwt.encode(payload, settings.SECRET_KEY, algorithm=settings.ALGORITHM)


@pytest.mark.asyncio
async def test_reset_password_success(client, token):
"""
Test successful password reset using a valid token.
"""
new_password = "QPassword@12334"

response = await client.post(
f"/api/users/reset-password?token={token}&new_password={new_password}"
)

if response.status_code != 200:
log.debug("Response:", response.status_code, response.json())

assert response.status_code == 200
Loading
Loading