Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix: Validate user authorized when mode ALL in enrollments [TECH-1589] #15583

Merged
merged 13 commits into from
Nov 7, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -64,7 +64,7 @@ public static void validateOrgUnitMode(
private static void validateUserCanSearchOrgUnitModeALL(User user) throws BadRequestException {
if (user != null
&& !(user.isSuper()
|| user.isAuthorized(F_TRACKED_ENTITY_INSTANCE_SEARCH_IN_ALL_ORGUNITS.name()))) {
|| user.isAuthorized(F_TRACKED_ENTITY_INSTANCE_SEARCH_IN_ALL_ORGUNITS))) {
throw new BadRequestException(
"Current user is not authorized to query across all organisation units");
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,6 @@
*/
package org.hisp.dhis.tracker.export.enrollment;

import static org.hisp.dhis.common.OrganisationUnitSelectionMode.ACCESSIBLE;
import static org.hisp.dhis.common.OrganisationUnitSelectionMode.CAPTURE;
import static org.hisp.dhis.common.OrganisationUnitSelectionMode.CHILDREN;
import static org.hisp.dhis.common.OrganisationUnitSelectionMode.DESCENDANTS;
Expand Down Expand Up @@ -292,14 +291,6 @@ public void validate(EnrollmentQueryParams params) throws IllegalQueryException
throw new IllegalQueryException("Params cannot be null");
}

User user = params.getUser();

if (params.isOrganisationUnitMode(ACCESSIBLE)
&& (user == null || !user.hasDataViewOrganisationUnitWithFallback())) {
violation =
"Current user must be associated with at least one organisation unit when selection mode is ACCESSIBLE";
}

if (params.hasProgram() && params.hasTrackedEntityType()) {
violation = "Program and tracked entity cannot be specified simultaneously";
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,8 @@
*/
package org.hisp.dhis.tracker.export.enrollment;

import static org.hisp.dhis.tracker.export.OperationsParamsValidator.validateOrgUnitMode;

import java.util.HashSet;
import java.util.Set;
import lombok.RequiredArgsConstructor;
Expand Down Expand Up @@ -72,6 +74,7 @@ public EnrollmentQueryParams map(EnrollmentOperationParams operationParams)

User user = currentUserService.getCurrentUser();
Set<OrganisationUnit> orgUnits = validateOrgUnits(operationParams.getOrgUnitUids(), user);
validateOrgUnitMode(operationParams.getOrgUnitMode(), user, program);

EnrollmentQueryParams params = new EnrollmentQueryParams();
params.setProgram(program);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,6 @@
import java.util.Set;
import javax.annotation.Nonnull;
import lombok.RequiredArgsConstructor;
import org.hisp.dhis.common.OrganisationUnitSelectionMode;
import org.hisp.dhis.common.QueryFilter;
import org.hisp.dhis.common.UID;
import org.hisp.dhis.feedback.BadRequestException;
Expand Down Expand Up @@ -76,9 +75,7 @@ public TrackedEntityQueryParams map(TrackedEntityOperationParams operationParams
validateTrackedEntityType(operationParams.getTrackedEntityTypeUid());

User user = operationParams.getUser();
Set<OrganisationUnit> orgUnits =
validateOrgUnits(
user, operationParams.getOrganisationUnits(), operationParams.getOrgUnitMode());
Set<OrganisationUnit> orgUnits = validateOrgUnits(user, operationParams.getOrganisationUnits());

TrackedEntityQueryParams params = new TrackedEntityQueryParams();
mapAttributeFilters(params, operationParams.getFilters());
Expand Down Expand Up @@ -135,8 +132,7 @@ private void mapAttributeFilters(
}
}

private Set<OrganisationUnit> validateOrgUnits(
User user, Set<String> orgUnitIds, OrganisationUnitSelectionMode orgUnitMode)
private Set<OrganisationUnit> validateOrgUnits(User user, Set<String> orgUnitIds)
throws BadRequestException, ForbiddenException {
Set<OrganisationUnit> orgUnits = new HashSet<>();
for (String orgUnitUid : orgUnitIds) {
Expand All @@ -156,10 +152,6 @@ private Set<OrganisationUnit> validateOrgUnits(
orgUnits.add(orgUnit);
}

if (orgUnitMode == OrganisationUnitSelectionMode.CAPTURE && user != null) {
orgUnits.addAll(user.getOrganisationUnits());
}

return orgUnits;
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,8 @@
*/
package org.hisp.dhis.tracker.export.enrollment;

import static org.hisp.dhis.common.OrganisationUnitSelectionMode.ACCESSIBLE;
import static org.hisp.dhis.common.OrganisationUnitSelectionMode.SELECTED;
import static org.hisp.dhis.utils.Assertions.assertContainsOnly;
import static org.hisp.dhis.utils.Assertions.assertIsEmpty;
import static org.junit.jupiter.api.Assertions.assertEquals;
Expand Down Expand Up @@ -118,6 +120,8 @@ void setUp() {
orgUnit2.getUid(), user.getTeiSearchOrganisationUnitsWithFallback()))
.thenReturn(true);

user.setTeiSearchOrganisationUnits(Set.of(orgUnit1, orgUnit2));

program = new Program();
program.setUid(PROGRAM_UID);
when(programService.getProgram(PROGRAM_UID)).thenReturn(program);
Expand All @@ -135,7 +139,8 @@ void setUp() {
@Test
void shouldMapWithoutFetchingNullParamsWhenParamsAreNotSpecified()
throws BadRequestException, ForbiddenException {
EnrollmentOperationParams operationParams = EnrollmentOperationParams.EMPTY;
EnrollmentOperationParams operationParams =
EnrollmentOperationParams.builder().orgUnitMode(ACCESSIBLE).build();

mapper.map(operationParams);

Expand All @@ -151,10 +156,17 @@ void shouldMapOrgUnitsWhenOrgUnitUidsAreSpecified()
EnrollmentOperationParams operationParams =
EnrollmentOperationParams.builder()
.orgUnitUids(Set.of(ORG_UNIT_1_UID, ORG_UNIT_2_UID))
.orgUnitMode(SELECTED)
.programUid(program.getUid())
.build();
when(trackerAccessManager.canAccess(user, program, orgUnit1)).thenReturn(true);
when(trackerAccessManager.canAccess(user, program, orgUnit2)).thenReturn(true);
when(organisationUnitService.isInUserHierarchy(
orgUnit1.getUid(), user.getTeiSearchOrganisationUnitsWithFallback()))
.thenReturn(true);
when(organisationUnitService.isInUserHierarchy(
orgUnit2.getUid(), user.getTeiSearchOrganisationUnitsWithFallback()))
.thenReturn(true);

EnrollmentQueryParams params = mapper.map(operationParams);

Expand All @@ -166,6 +178,7 @@ void shouldThrowExceptionWhenOrgUnitNotFound() {
EnrollmentOperationParams operationParams =
EnrollmentOperationParams.builder()
.orgUnitUids(Set.of("JW6BrFd0HLu", ORG_UNIT_2_UID))
.orgUnitMode(SELECTED)
.programUid(PROGRAM_UID)
.build();

Expand Down Expand Up @@ -193,7 +206,7 @@ void shouldThrowExceptionWhenOrgUnitNotInScope() {
@Test
void shouldMapProgramWhenProgramUidIsSpecified() throws BadRequestException, ForbiddenException {
EnrollmentOperationParams requestParams =
EnrollmentOperationParams.builder().programUid(PROGRAM_UID).build();
EnrollmentOperationParams.builder().programUid(PROGRAM_UID).orgUnitMode(ACCESSIBLE).build();

EnrollmentQueryParams params = mapper.map(requestParams);

Expand All @@ -214,7 +227,10 @@ void shouldThrowExceptionWhenProgramNotFound() {
void shouldMapTrackedEntityTypeWhenTrackedEntityTypeUidIsSpecified()
throws BadRequestException, ForbiddenException {
EnrollmentOperationParams operationParams =
EnrollmentOperationParams.builder().trackedEntityTypeUid(TRACKED_ENTITY_TYPE_UID).build();
EnrollmentOperationParams.builder()
.trackedEntityTypeUid(TRACKED_ENTITY_TYPE_UID)
.orgUnitMode(ACCESSIBLE)
.build();

EnrollmentQueryParams params = mapper.map(operationParams);

Expand All @@ -235,7 +251,10 @@ void shouldThrowExceptionWhenTrackedEntityTypeNotFound() {
void shouldMapTrackedEntityWhenTrackedEntityUidIsSpecified()
throws BadRequestException, ForbiddenException {
EnrollmentOperationParams operationParams =
EnrollmentOperationParams.builder().trackedEntityUid(TRACKED_ENTITY_UID).build();
EnrollmentOperationParams.builder()
.trackedEntityUid(TRACKED_ENTITY_UID)
.orgUnitMode(ACCESSIBLE)
.build();

EnrollmentQueryParams params = mapper.map(operationParams);

Expand All @@ -259,6 +278,7 @@ void shouldMapOrderInGivenOrder() throws BadRequestException, ForbiddenException
EnrollmentOperationParams.builder()
.orderBy("enrollmentDate", SortDirection.ASC)
.orderBy("created", SortDirection.DESC)
.orgUnitMode(ACCESSIBLE)
.build();

EnrollmentQueryParams params = mapper.map(operationParams);
Expand All @@ -273,9 +293,10 @@ void shouldMapOrderInGivenOrder() throws BadRequestException, ForbiddenException
@Test
void shouldMapNullOrderingParamsWhenNoOrderingParamsAreSpecified()
throws BadRequestException, ForbiddenException {
EnrollmentOperationParams requestParams = EnrollmentOperationParams.EMPTY;
EnrollmentOperationParams operationParams =
EnrollmentOperationParams.builder().orgUnitMode(ACCESSIBLE).build();

EnrollmentQueryParams params = mapper.map(requestParams);
EnrollmentQueryParams params = mapper.map(operationParams);

assertIsEmpty(params.getOrder());
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -522,6 +522,7 @@ void shouldReturnPaginatedEnrollmentsGivenNonDefaultPageSize()
EnrollmentOperationParams operationParams =
EnrollmentOperationParams.builder()
.orgUnitUids(Set.of(orgUnit.getUid()))
.orgUnitMode(SELECTED)
.orderBy("enrollmentDate", SortDirection.ASC)
.build();

Expand Down Expand Up @@ -553,6 +554,7 @@ void shouldReturnPaginatedEnrollmentsGivenNonDefaultPageSizeAndTotalPages()
EnrollmentOperationParams operationParams =
EnrollmentOperationParams.builder()
.orgUnitUids(Set.of(orgUnit.getUid()))
.orgUnitMode(SELECTED)
.orderBy("enrollmentDate", SortDirection.ASC)
.build();

Expand Down Expand Up @@ -590,7 +592,10 @@ void shouldOrderEnrollmentsByPrimaryKeyDescByDefault()
.toList();

EnrollmentOperationParams params =
EnrollmentOperationParams.builder().orgUnitUids(Set.of(orgUnit.getUid())).build();
EnrollmentOperationParams.builder()
.orgUnitUids(Set.of(orgUnit.getUid()))
.orgUnitMode(SELECTED)
.build();

List<String> enrollments = getEnrollments(params);

Expand All @@ -603,6 +608,7 @@ void shouldOrderEnrollmentsByEnrolledAtAsc()
EnrollmentOperationParams params =
EnrollmentOperationParams.builder()
.orgUnitUids(Set.of(orgUnit.getUid()))
.orgUnitMode(SELECTED)
.orderBy("enrollmentDate", SortDirection.ASC)
.build();

Expand All @@ -617,6 +623,7 @@ void shouldOrderEnrollmentsByEnrolledAtDesc()
EnrollmentOperationParams params =
EnrollmentOperationParams.builder()
.orgUnitUids(Set.of(orgUnit.getUid()))
.orgUnitMode(SELECTED)
.orderBy("enrollmentDate", SortDirection.DESC)
.build();

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,10 @@
*/
package org.hisp.dhis.tracker.export.enrollment;

import static org.hisp.dhis.common.OrganisationUnitSelectionMode.ACCESSIBLE;
import static org.hisp.dhis.common.OrganisationUnitSelectionMode.ALL;
import static org.hisp.dhis.common.OrganisationUnitSelectionMode.CAPTURE;
import static org.hisp.dhis.common.OrganisationUnitSelectionMode.SELECTED;
import static org.hisp.dhis.tracker.TrackerTestUtils.oneHourAfter;
import static org.hisp.dhis.tracker.TrackerTestUtils.oneHourBefore;
import static org.hisp.dhis.tracker.TrackerTestUtils.uids;
Expand Down Expand Up @@ -69,6 +72,7 @@
import org.hisp.dhis.trackedentityattributevalue.TrackedEntityAttributeValue;
import org.hisp.dhis.user.User;
import org.hisp.dhis.user.UserService;
import org.junit.jupiter.api.Assertions;
import org.junit.jupiter.api.Test;
import org.springframework.beans.factory.annotation.Autowired;

Expand Down Expand Up @@ -410,7 +414,10 @@ void shouldGetEnrollmentsWhenUserHasReadAccessToProgramAndNoOrgUnitNorOrgUnitMod
manager.updateNoAcl(programA);

EnrollmentOperationParams params =
EnrollmentOperationParams.builder().programUid(programA.getUid()).build();
EnrollmentOperationParams.builder()
.programUid(programA.getUid())
.orgUnitMode(ACCESSIBLE)
.build();

List<Enrollment> enrollments = enrollmentService.getEnrollments(params);

Expand Down Expand Up @@ -447,6 +454,7 @@ void shouldGetEnrollmentWhenEnrollmentsAndOtherParamsAreSpecified()
EnrollmentOperationParams.builder()
.programUid(programA.getUid())
.enrollmentUids(Set.of(enrollmentA.getUid()))
.orgUnitMode(ACCESSIBLE)
.build();

List<Enrollment> enrollments = enrollmentService.getEnrollments(params);
Expand All @@ -464,6 +472,7 @@ void shouldGetEnrollmentsByTrackedEntityWhenUserHasAccessToTrackedEntityType()
EnrollmentOperationParams params =
EnrollmentOperationParams.builder()
.orgUnitUids(Set.of(trackedEntityA.getOrganisationUnit().getUid()))
.orgUnitMode(SELECTED)
.trackedEntityUid(trackedEntityA.getUid())
.build();

Expand All @@ -485,6 +494,7 @@ void shouldFailGettingEnrollmentsByTrackedEntityWhenUserHasNoAccessToTrackedEnti
EnrollmentOperationParams params =
EnrollmentOperationParams.builder()
.orgUnitUids(Set.of(trackedEntityA.getOrganisationUnit().getUid()))
.orgUnitMode(SELECTED)
.trackedEntityUid(trackedEntityA.getUid())
.build();

Expand All @@ -501,6 +511,7 @@ void shouldReturnEnrollmentIfEnrollmentWasUpdatedBeforePassedDateAndTime()
EnrollmentOperationParams operationParams =
EnrollmentOperationParams.builder()
.orgUnitUids(Set.of(orgUnitA.getUid()))
.orgUnitMode(SELECTED)
.lastUpdated(oneHourBeforeLastUpdated)
.build();

Expand All @@ -517,6 +528,7 @@ void shouldReturnEmptyIfEnrollmentWasUpdatedAfterPassedDateAndTime()
EnrollmentOperationParams operationParams =
EnrollmentOperationParams.builder()
.orgUnitUids(Set.of(orgUnitA.getUid()))
.orgUnitMode(SELECTED)
.lastUpdated(oneHourAfterLastUpdated)
.build();

Expand All @@ -534,6 +546,7 @@ void shouldReturnEnrollmentIfEnrollmentStartedBeforePassedDateAndTime()
EnrollmentOperationParams operationParams =
EnrollmentOperationParams.builder()
.orgUnitUids(Set.of(orgUnitA.getUid()))
.orgUnitMode(SELECTED)
.programUid(programA.getUid())
.programStartDate(oneHourBeforeEnrollmentDate)
.build();
Expand All @@ -552,6 +565,7 @@ void shouldReturnEmptyIfEnrollmentStartedAfterPassedDateAndTime()
EnrollmentOperationParams operationParams =
EnrollmentOperationParams.builder()
.orgUnitUids(Set.of(orgUnitA.getUid()))
.orgUnitMode(SELECTED)
.programUid(programA.getUid())
.programStartDate(oneHourAfterEnrollmentDate)
.build();
Expand All @@ -570,6 +584,7 @@ void shouldReturnEnrollmentIfEnrollmentEndedAfterPassedDateAndTime()
EnrollmentOperationParams operationParams =
EnrollmentOperationParams.builder()
.orgUnitUids(Set.of(orgUnitA.getUid()))
.orgUnitMode(SELECTED)
.programUid(programA.getUid())
.programEndDate(oneHourAfterEnrollmentDate)
.build();
Expand All @@ -588,6 +603,7 @@ void shouldReturnEmptyIfEnrollmentEndedBeforePassedDateAndTime()
EnrollmentOperationParams operationParams =
EnrollmentOperationParams.builder()
.orgUnitUids(Set.of(orgUnitA.getUid()))
.orgUnitMode(SELECTED)
.programUid(programA.getUid())
.programEndDate(oneHourBeforeEnrollmentDate)
.build();
Expand All @@ -597,6 +613,31 @@ void shouldReturnEmptyIfEnrollmentEndedBeforePassedDateAndTime()
assertIsEmpty(enrollments);
}

@Test
void shouldFailWhenOrgUnitModeAllAndUserNotAuthorized() {
EnrollmentOperationParams operationParams =
EnrollmentOperationParams.builder().orgUnitMode(ALL).build();

BadRequestException exception =
Assertions.assertThrows(
BadRequestException.class, () -> enrollmentService.getEnrollments(operationParams));
Assertions.assertEquals(
"Current user is not authorized to query across all organisation units",
exception.getMessage());
}

@Test
void shouldReturnAllEnrollmentsWhenOrgUnitModeAllAndUserAuthorized()
throws ForbiddenException, BadRequestException, NotFoundException {
injectSecurityContext(admin);

EnrollmentOperationParams operationParams =
EnrollmentOperationParams.builder().orgUnitMode(ALL).build();

List<Enrollment> enrollments = enrollmentService.getEnrollments(operationParams);
assertContainsOnly(List.of(enrollmentA, enrollmentB, enrollmentChildA), enrollments);
}

private static List<String> attributeUids(Enrollment enrollment) {
return enrollment.getTrackedEntity().getTrackedEntityAttributeValues().stream()
.map(v -> v.getAttribute().getUid())
Expand Down
Loading
Loading