Skip to content

Commit

Permalink
feat: added tests for ambulance APIs
Browse files Browse the repository at this point in the history
  • Loading branch information
aeswibon committed Sep 28, 2023
1 parent 3e4f6fc commit b8fbb3c
Show file tree
Hide file tree
Showing 4 changed files with 257 additions and 10 deletions.
9 changes: 5 additions & 4 deletions care/facility/api/serializers/ambulance.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ class Meta:


class AmbulanceSerializer(serializers.ModelSerializer):
id = serializers.UUIDField(source="external_id", read_only=True)
drivers = serializers.ListSerializer(child=AmbulanceDriverSerializer())

primary_district_object = DistrictSerializer(
Expand All @@ -31,7 +32,7 @@ class Meta:
"secondary_district_object",
"third_district_object",
)
exclude = ("created_by",)
exclude = ("created_by", "external_id")

def validate(self, obj):
validated = super().validate(obj)
Expand All @@ -46,7 +47,7 @@ def create(self, validated_data):
drivers = validated_data.pop("drivers", [])
validated_data.pop("created_by", None)

ambulance = super(AmbulanceSerializer, self).create(validated_data)
ambulance = super().create(validated_data)

for d in drivers:
d["ambulance"] = ambulance
Expand All @@ -55,12 +56,12 @@ def create(self, validated_data):

def update(self, instance, validated_data):
validated_data.pop("drivers", [])
ambulance = super(AmbulanceSerializer, self).update(instance, validated_data)
ambulance = super().update(instance, validated_data)
return ambulance


class DeleteDriverSerializer(serializers.Serializer):
driver_id = serializers.IntegerField()
driver_id = serializers.UUIDField()

def update(self, instance, validated_data):
raise NotImplementedError
Expand Down
17 changes: 12 additions & 5 deletions care/facility/api/viewsets/ambulance.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,7 @@ class AmbulanceViewSet(
):
permission_classes = (IsAuthenticated,)
serializer_class = AmbulanceSerializer
lookup_field = "external_id"
queryset = Ambulance.objects.filter(deleted=False).select_related(
"primary_district", "secondary_district", "third_district"
)
Expand All @@ -65,9 +66,12 @@ def get_serializer_class(self):

@extend_schema(tags=["ambulance"])
@action(methods=["POST"], detail=True)
def add_driver(self, request):
def add_driver(self, request, *args, **kwargs):
"""
Endpoint to add a driver to an ambulance
"""
ambulance = self.get_object()
serializer = self.get_serializer(data=request.data)
serializer = self.get_serializer(ambulance, data=request.data)
serializer.is_valid(raise_exception=True)

driver = ambulance.ambulancedriver_set.create(**serializer.validated_data)
Expand All @@ -78,13 +82,16 @@ def add_driver(self, request):

@extend_schema(tags=["ambulance"])
@action(methods=["DELETE"], detail=True)
def remove_driver(self, request):
def remove_driver(self, request, *args, **kwargs):
"""
Endpoint to remove a driver from an ambulance
"""
ambulance = self.get_object()
serializer = self.get_serializer(data=request.data)
serializer = self.get_serializer(ambulance, data=request.data)
serializer.is_valid(raise_exception=True)

driver = ambulance.ambulancedriver_set.filter(
id=serializer.validated_data["driver_id"]
external_id=serializer.validated_data["driver_id"]
).first()
if not driver:
raise serializers.ValidationError({"driver_id": "Detail not found"})
Expand Down
221 changes: 221 additions & 0 deletions care/facility/tests/test_ambulance_api.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,221 @@
"""
Test module for Ambulance API
"""

from rest_framework.test import APITestCase

from care.facility.models.ambulance import Ambulance
from care.utils.tests.test_utils import TestUtils


class TestAmbulance(TestUtils, APITestCase):
"""
Test class for Ambulance
"""

@classmethod
def setUpTestData(cls) -> None:
cls.state = cls.create_state()
cls.district = cls.create_district(cls.state)
cls.local_body = cls.create_local_body(cls.district)
cls.super_user = cls.create_super_user("su", cls.district)
cls.user = cls.create_user(
"user", district=cls.district, local_body=cls.local_body
)
cls.facility = cls.create_facility(cls.super_user, cls.district, cls.local_body)
cls.patient = cls.create_patient(
cls.district, cls.facility, local_body=cls.local_body
)

def setUp(self):
super().setUp()
self.ambulance = self.create_ambulance(district=self.district, user=self.user)

def get_base_url(self) -> str:
return "/api/v1/ambulance"

def get_url(self, entry_id=None, action=None):
"""
Constructs the url for ambulance api
"""
base_url = f"{self.get_base_url()}/"
if entry_id is not None:
base_url += f"{entry_id}/"
if action is not None:
base_url += f"{action}/"
return base_url

def get_detail_representation(self, obj=None) -> dict:
return {
"vehicle_number": obj.vehicle_number,
"ambulance_type": obj.ambulance_type,
"owner_name": obj.owner_name,
"owner_phone_number": obj.owner_phone_number,
"owner_is_smart_phone": obj.owner_is_smart_phone,
"deleted": obj.deleted,
"has_oxygen": obj.has_oxygen,
"has_ventilator": obj.has_ventilator,
"has_suction_machine": obj.has_suction_machine,
"has_defibrillator": obj.has_defibrillator,
"insurance_valid_till_year": obj.insurance_valid_till_year,
"has_free_service": obj.has_free_service,
"primary_district": obj.primary_district.id,
"primary_district_object": {
"id": obj.primary_district.id,
"name": obj.primary_district.name,
"state": obj.primary_district.state.id,
},
"secondary_district": obj.secondary_district,
"third_district": obj.third_district,
"secondary_district_object": None,
"third_district_object": None,
}

def get_list_representation(self, obj=None) -> dict:
return {
"drivers": list(obj.drivers),
**self.get_detail_representation(obj),
}

def get_create_representation(self) -> dict:
"""
Returns a representation of a ambulance create request body
"""
return {
"vehicle_number": "WW73O2195",
"owner_name": "string",
"owner_phone_number": "+918800900466",
"owner_is_smart_phone": True,
"has_oxygen": True,
"has_ventilator": True,
"has_suction_machine": True,
"has_defibrillator": True,
"insurance_valid_till_year": 2020,
"ambulance_type": 1,
"primary_district": self.district.id,
}

def test_create_ambulance(self):
"""
Test to create ambulance
"""

# Test with invalid data
res = self.client.post(
self.get_url(action="create"), data=self.get_create_representation()
)
self.assertEqual(res.status_code, 400)
self.assertEqual(res.json()["drivers"][0], "This field is required.")

data = {
"drivers": [
{
"name": "string",
"phone_number": "+919013526849",
"is_smart_phone": True,
}
],
}
data.update(self.get_create_representation())
res = self.client.post(self.get_url(action="create"), data=data, format="json")
self.assertEqual(res.status_code, 400)
self.assertEqual(
res.json()["non_field_errors"][0],
"The ambulance must provide a price or be marked as free",
)

# Test with valid data
data.update({"price_per_km": 100})
res = self.client.post(self.get_url(action="create"), data=data, format="json")
self.assertEqual(res.status_code, 201)
self.assertTrue(
Ambulance.objects.filter(vehicle_number=data["vehicle_number"]).exists()
)

def test_list_ambulance(self):
"""
Test to list ambulance
"""
res = self.client.get(self.get_url())
self.assertEqual(res.status_code, 200)
self.assertEqual(res.json()["count"], 1)
self.assertDictContainsSubset(
self.get_list_representation(self.ambulance), res.json()["results"][0]
)

def test_retrieve_ambulance(self):
"""
Test to retrieve ambulance
"""
res = self.client.get(self.get_url(entry_id=self.ambulance.external_id))
self.assertEqual(res.status_code, 200)
self.assertDictContainsSubset(
self.get_detail_representation(self.ambulance), res.json()
)

def test_update_ambulance(self):
"""
Test to update ambulance
"""

res = self.client.patch(
self.get_url(entry_id=self.ambulance.external_id),
data={"vehicle_number": "WW73O2200", "has_free_service": True},
)
self.assertEqual(res.status_code, 200)
self.ambulance.refresh_from_db()
self.assertEqual(self.ambulance.vehicle_number, "WW73O2200")

def test_delete_ambulance(self):
"""
Test to delete ambulance
"""
res = self.client.delete(self.get_url(entry_id=self.ambulance.external_id))
self.assertEqual(res.status_code, 204)
self.ambulance.refresh_from_db()
self.assertTrue(self.ambulance.deleted)

def test_add_driver(self):
"""
Test to add driver
"""

res = self.client.post(
self.get_url(entry_id=self.ambulance.external_id, action="add_driver"),
data={
"name": "string",
"phone_number": "+919013526800",
"is_smart_phone": True,
},
)

self.assertEqual(res.status_code, 201)
self.assertTrue(
self.ambulance.drivers.filter(phone_number="+919013526800").exists()
)

def test_remove_driver(self):
"""
Test to remove driver
"""

res = self.client.post(
self.get_url(entry_id=self.ambulance.external_id, action="add_driver"),
data={
"name": "string",
"phone_number": "+919013526800",
"is_smart_phone": True,
},
)

driver_id = res.json()["external_id"]

res = self.client.delete(
self.get_url(
entry_id=self.ambulance.external_id,
action="remove_driver",
),
data={"driver_id": driver_id},
)
self.assertEqual(res.status_code, 204)
self.assertFalse(self.ambulance.drivers.exists())
20 changes: 19 additions & 1 deletion care/utils/tests/test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
PatientRegistration,
User,
)
from care.facility.models.ambulance import Ambulance
from care.facility.models.asset import Asset, AssetLocation
from care.facility.models.facility import FacilityUser
from care.users.models import District, State
Expand Down Expand Up @@ -109,7 +110,24 @@ def create_local_body(cls, district: District, **kwargs) -> LocalBody:
return LocalBody.objects.create(**data)

@classmethod
def get_user_data(cls, district: District, user_type: str = None):
def create_ambulance(cls, district: District = None, user: User = None, **kwargs):
return Ambulance.objects.create(
vehicle_number="KL01AB1234",
owner_name="Foo",
owner_phone_number="9998887776",
primary_district=district or cls.district,
has_oxygen=True,
has_ventilator=True,
has_suction_machine=True,
has_defibrillator=True,
insurance_valid_till_year=2021,
price_per_km=10,
has_free_service=False,
created_by=user or cls.user,
)

@classmethod
def get_user_data(cls, district: District = None, user_type: str = None):
"""
Returns the data to be used for API testing
Expand Down

0 comments on commit b8fbb3c

Please sign in to comment.