diff --git a/client/html/user_edit.tpl b/client/html/user_edit.tpl index d518aabba..22dcd8350 100644 --- a/client/html/user_edit.tpl +++ b/client/html/user_edit.tpl @@ -68,6 +68,12 @@ <% } %> + + <% if (ctx.canEditBlocklist) { %> +
  • + <%= ctx.makeTextInput({text: 'Blocklist'}) %> +
  • + <% } %>
    diff --git a/client/js/controllers/user_controller.js b/client/js/controllers/user_controller.js index 8cf46584d..bd41aa1e7 100644 --- a/client/js/controllers/user_controller.js +++ b/client/js/controllers/user_controller.js @@ -89,6 +89,7 @@ class UserController { canEditAvatar: api.hasPrivilege( `users:edit:${infix}:avatar` ), + canEditBlocklist: api.hasPrivilege(`users:edit:${infix}:blocklist`), canEditAnything: api.hasPrivilege(`users:edit:${infix}`), canListTokens: api.hasPrivilege( `userTokens:list:${infix}` diff --git a/client/js/models/user.js b/client/js/models/user.js index 28dc3efe5..40b6ebecb 100644 --- a/client/js/models/user.js +++ b/client/js/models/user.js @@ -3,11 +3,19 @@ const api = require("../api.js"); const uri = require("../util/uri.js"); const events = require("../events.js"); +const misc = require("../util/misc.js"); class User extends events.EventTarget { constructor() { + const TagList = require("./tag_list.js"); + super(); this._orig = {}; + + for (let obj of [this, this._orig]) { + obj._blocklist = new TagList(); + } + this._updateFromResponse({}); } @@ -71,6 +79,10 @@ class User extends events.EventTarget { throw "Invalid operation"; } + get blocklist() { + return this._blocklist; + } + set name(value) { this._name = value; } @@ -95,6 +107,10 @@ class User extends events.EventTarget { this._password = value; } + set blocklist(value) { + this._blocklist = value || ""; + } + static fromResponse(response) { const ret = new User(); ret._updateFromResponse(response); @@ -121,6 +137,11 @@ class User extends events.EventTarget { if (this._rank !== this._orig._rank) { detail.rank = this._rank; } + if (misc.arraysDiffer(this._blocklist, this._orig._blocklist)) { + detail.blocklist = this._blocklist.map( + (relation) => relation.names[0] + ); + } if (this._avatarStyle !== this._orig._avatarStyle) { detail.avatarStyle = this._avatarStyle; } @@ -187,6 +208,10 @@ class User extends events.EventTarget { _dislikedPostCount: response.dislikedPostCount, }; + for (let obj of [this, this._orig]) { + obj._blocklist.sync(response.blocklist); + } + Object.assign(this, map); Object.assign(this._orig, map); diff --git a/client/js/views/user_edit_view.js b/client/js/views/user_edit_view.js index 4886726a2..bee0bb4fe 100644 --- a/client/js/views/user_edit_view.js +++ b/client/js/views/user_edit_view.js @@ -4,6 +4,8 @@ const events = require("../events.js"); const api = require("../api.js"); const views = require("../util/views.js"); const FileDropperControl = require("../controls/file_dropper_control.js"); +const TagInputControl = require("../controls/tag_input_control.js") +const misc = require("../util/misc.js"); const template = views.getTemplate("user-edit"); @@ -41,6 +43,13 @@ class UserEditView extends events.EventTarget { }); } + if (this._blocklistFieldNode) { + new TagInputControl( + this._blocklistFieldNode, + this._user.blocklist + ); + } + this._formNode.addEventListener("submit", (e) => this._evtSubmit(e)); } @@ -83,6 +92,10 @@ class UserEditView extends events.EventTarget { ? this._rankInputNode.value : undefined, + blocklist: this._blocklistFieldNode + ? misc.splitByWhitespace(this._blocklistFieldNode.value) + : undefined, + avatarStyle: this._avatarStyleInputNode ? this._avatarStyleInputNode.value : undefined, @@ -101,6 +114,10 @@ class UserEditView extends events.EventTarget { return this._hostNode.querySelector("form"); } + get _blocklistFieldNode() { + return this._formNode.querySelector(".blocklist input"); + } + get _rankInputNode() { return this._formNode.querySelector("[name=rank]"); } diff --git a/server/config.yaml.dist b/server/config.yaml.dist index 193aac3ac..69be858e2 100644 --- a/server/config.yaml.dist +++ b/server/config.yaml.dist @@ -67,6 +67,12 @@ webhooks: default_rank: regular +# default blocklisted tags (space separated) +default_tag_blocklist: '' + +# Apply blocklist for anonymous viewers too +default_tag_blocklist_for_anonymous: yes + privileges: 'users:create:self': anonymous # Registration permission 'users:create:any': administrator @@ -76,11 +82,13 @@ privileges: 'users:edit:any:pass': moderator 'users:edit:any:email': moderator 'users:edit:any:avatar': moderator + 'users:edit:any:blocklist': moderator 'users:edit:any:rank': moderator 'users:edit:self:name': regular 'users:edit:self:pass': regular 'users:edit:self:email': regular 'users:edit:self:avatar': regular + 'users:edit:self:blocklist': regular 'users:edit:self:rank': moderator # one can't promote themselves or anyone to upper rank than their own. 'users:delete:any': administrator 'users:delete:self': regular diff --git a/server/szurubooru/api/info_api.py b/server/szurubooru/api/info_api.py index 757b09cfe..53b639627 100644 --- a/server/szurubooru/api/info_api.py +++ b/server/szurubooru/api/info_api.py @@ -43,6 +43,8 @@ def get_info(ctx: rest.Context, _params: Dict[str, str] = {}) -> rest.Response: "tagNameRegex": config.config["tag_name_regex"], "tagCategoryNameRegex": config.config["tag_category_name_regex"], "defaultUserRank": config.config["default_rank"], + "defaultTagBlocklist": config.config["default_tag_blocklist"], + "defaultTagBlocklistForAnonymous": config.config["default_tag_blocklist_for_anonymous"], "enableSafety": config.config["enable_safety"], "contactEmail": config.config["contact_email"], "canSendMails": bool(config.config["smtp"]["host"]), diff --git a/server/szurubooru/api/tag_api.py b/server/szurubooru/api/tag_api.py index 6b4c807e2..724e71d73 100644 --- a/server/szurubooru/api/tag_api.py +++ b/server/szurubooru/api/tag_api.py @@ -2,7 +2,7 @@ from typing import Dict, List, Optional from szurubooru import db, model, rest, search -from szurubooru.func import auth, serialization, snapshots, tags, versions +from szurubooru.func import auth, serialization, snapshots, tags, versions, users _search_executor = search.Executor(search.configs.TagSearchConfig()) diff --git a/server/szurubooru/api/user_api.py b/server/szurubooru/api/user_api.py index a6196cb8c..11a884595 100644 --- a/server/szurubooru/api/user_api.py +++ b/server/szurubooru/api/user_api.py @@ -1,7 +1,7 @@ -from typing import Any, Dict +from typing import Any, Dict, List -from szurubooru import model, rest, search -from szurubooru.func import auth, serialization, users, versions +from szurubooru import db, model, rest, search +from szurubooru.func import auth, serialization, snapshots, users, versions, tags _search_executor = search.Executor(search.configs.UserSearchConfig()) @@ -17,6 +17,18 @@ def _serialize( ) +def _create_tag_if_needed(tag_names: List[str], user: model.User) -> None: + # Taken from tag_api.py + if not tag_names: + return + _existing_tags, new_tags = tags.get_or_create_tags_by_names(tag_names) + if len(new_tags): + auth.verify_privilege(user, "tags:create") + db.session.flush() + for tag in new_tags: + snapshots.create(tag, user) + + @rest.routes.get("/users/?") def get_users( ctx: rest.Context, _params: Dict[str, str] = {} @@ -50,6 +62,10 @@ def create_user( ) ctx.session.add(user) ctx.session.commit() + to_add, _ = users.update_user_blocklist(user, None) + for e in to_add: + ctx.session.add(e) + ctx.session.commit() return _serialize(ctx, user, force_show_email=True) @@ -80,6 +96,16 @@ def update_user(ctx: rest.Context, params: Dict[str, str]) -> rest.Response: if ctx.has_param("rank"): auth.verify_privilege(ctx.user, "users:edit:%s:rank" % infix) users.update_user_rank(user, ctx.get_param_as_string("rank"), ctx.user) + if ctx.has_param("blocklist"): + auth.verify_privilege(ctx.user, "users:edit:%s:blocklist" % infix) + blocklist = ctx.get_param_as_string_list("blocklist") + _create_tag_if_needed(blocklist, user) # Non-existing tags are created. + blocklist_tags = tags.get_tags_by_names(blocklist) + to_add, to_remove = users.update_user_blocklist(user, blocklist_tags) + for e in to_remove: + ctx.session.delete(e) + for e in to_add: + ctx.session.add(e) if ctx.has_param("avatarStyle"): auth.verify_privilege(ctx.user, "users:edit:%s:avatar" % infix) users.update_user_avatar( diff --git a/server/szurubooru/func/tags.py b/server/szurubooru/func/tags.py index 28a2a76bc..10eaebf13 100644 --- a/server/szurubooru/func/tags.py +++ b/server/szurubooru/func/tags.py @@ -159,6 +159,9 @@ def get_tag_by_name(name: str) -> model.Tag: def get_tags_by_names(names: List[str]) -> List[model.Tag]: + """ + Returns a list of all tags which names include all the letters from the input list + """ names = util.icase_unique(names) if len(names) == 0: return [] @@ -175,6 +178,24 @@ def get_tags_by_names(names: List[str]) -> List[model.Tag]: ) +def get_tags_by_exact_names(names: List[str]) -> List[model.Tag]: + """ + Returns a list of tags matching the names from the input list + """ + entries = [] + if len(names) == 0: + return [] + names = [name.lower() for name in names] + entries = ( + db.session.query(model.Tag) + .join(model.TagName) + .filter( + sa.func.lower(model.TagName.name).in_(names) + ) + .all()) + return entries + + def get_or_create_tags_by_names( names: List[str], ) -> Tuple[List[model.Tag], List[model.Tag]]: diff --git a/server/szurubooru/func/users.py b/server/szurubooru/func/users.py index 5cbe3cc0f..c1be991e6 100644 --- a/server/szurubooru/func/users.py +++ b/server/szurubooru/func/users.py @@ -1,3 +1,4 @@ +import copy import re from datetime import datetime from typing import Any, Callable, Dict, List, Optional, Union @@ -5,7 +6,7 @@ import sqlalchemy as sa from szurubooru import config, db, errors, model, rest -from szurubooru.func import auth, files, images, serialization, util +from szurubooru.func import auth, files, images, serialization, util, tags class UserNotFoundError(errors.NotFoundError): @@ -107,6 +108,7 @@ def _serializers(self) -> Dict[str, Callable[[], Any]]: "lastLoginTime": self.serialize_last_login_time, "version": self.serialize_version, "rank": self.serialize_rank, + "blocklist": self.serialize_blocklist, "avatarStyle": self.serialize_avatar_style, "avatarUrl": self.serialize_avatar_url, "commentCount": self.serialize_comment_count, @@ -138,6 +140,9 @@ def serialize_avatar_style(self) -> Any: def serialize_avatar_url(self) -> Any: return get_avatar_url(self.user) + def serialize_blocklist(self) -> Any: + return [tags.serialize_tag(tag) for tag in get_blocklist_tag_from_user(self.user)] + def serialize_comment_count(self) -> Any: return self.user.comment_count @@ -294,6 +299,66 @@ def update_user_rank( user.rank = rank +def get_blocklist_from_user(user: model.User) -> List[model.UserTagBlocklist]: + """ + Return the UserTagBlocklist objects related to given user + """ + rez = (db.session.query(model.UserTagBlocklist) + .filter( + model.UserTagBlocklist.user_id == user.user_id + ) + .all()) + return rez + + +def get_blocklist_tag_from_user(user: model.User) -> List[model.UserTagBlocklist]: + """ + Return the Tags blocklisted by given user + """ + rez = (db.session.query(model.UserTagBlocklist.tag_id) + .filter( + model.UserTagBlocklist.user_id == user.user_id + )) + rez2 = (db.session.query(model.Tag) + .filter( + model.Tag.tag_id.in_(rez) + ).all()) + return rez2 + + +def update_user_blocklist(user: model.User, new_blocklist_tags: Optional[List[model.Tag]]) -> List[List[model.UserTagBlocklist]]: + """ + Modify blocklist for given user. + If new_blocklist_tags is None, set the blocklist to configured default tag blocklist. + """ + assert user + to_add: List[model.UserTagBlocklist] = [] + to_remove: List[model.UserTagBlocklist] = [] + + if new_blocklist_tags is None: # We're creating the user, use default config blocklist + if 'default_tag_blocklist' in config.config.keys(): + for e in tags.get_tags_by_exact_names(config.config['default_tag_blocklist'].split(' ')): + to_add.append(model.UserTagBlocklist(user_id=user.user_id, tag_id=e.tag_id)) + else: + new_blocklist_ids: List[int] = [e.tag_id for e in new_blocklist_tags] + previous_blocklist_tags: List[model.Tag] = get_blocklist_from_user(user) + previous_blocklist_ids: List[int] = [e.tag_id for e in previous_blocklist_tags] + original_previous_blocklist_ids = copy.copy(previous_blocklist_ids) + + ## Remove tags no longer in the new list + for i in range(len(original_previous_blocklist_ids)): + old_tag_id = original_previous_blocklist_ids[i] + if old_tag_id not in new_blocklist_ids: + to_remove.append(previous_blocklist_tags[i]) + previous_blocklist_ids.remove(old_tag_id) + + ## Add tags not yet in the original list + for new_tag_id in new_blocklist_ids: + if new_tag_id not in previous_blocklist_ids: + to_add.append(model.UserTagBlocklist(user_id=user.user_id, tag_id=new_tag_id)) + return to_add, to_remove + + def update_user_avatar( user: model.User, avatar_style: str, avatar_content: Optional[bytes] = None ) -> None: diff --git a/server/szurubooru/migrations/versions/9ba5e3a6ee7c_add_blocklist.py b/server/szurubooru/migrations/versions/9ba5e3a6ee7c_add_blocklist.py new file mode 100644 index 000000000..317f2c962 --- /dev/null +++ b/server/szurubooru/migrations/versions/9ba5e3a6ee7c_add_blocklist.py @@ -0,0 +1,30 @@ +''' +Add blocklist related fields + +add_blocklist + +Revision ID: 9ba5e3a6ee7c +Created at: 2023-05-20 22:28:10.824954 +''' + +import sqlalchemy as sa +from alembic import op + +revision = '9ba5e3a6ee7c' +down_revision = 'adcd63ff76a2' +branch_labels = None +depends_on = None + + +def upgrade(): + op.create_table( + "user_tag_blocklist", + sa.Column("user_id", sa.Integer(), nullable=False), + sa.Column("tag_id", sa.Integer(), nullable=False), + sa.ForeignKeyConstraint(["user_id"], ["user.id"]), + sa.ForeignKeyConstraint(["tag_id"], ["tag.id"]), + sa.PrimaryKeyConstraint("user_id", "tag_id"), + ) + +def downgrade(): + op.drop_table('user_tag_blocklist') diff --git a/server/szurubooru/model/__init__.py b/server/szurubooru/model/__init__.py index 21a178ef4..de28af4c5 100644 --- a/server/szurubooru/model/__init__.py +++ b/server/szurubooru/model/__init__.py @@ -16,4 +16,4 @@ from szurubooru.model.snapshot import Snapshot from szurubooru.model.tag import Tag, TagImplication, TagName, TagSuggestion from szurubooru.model.tag_category import TagCategory -from szurubooru.model.user import User, UserToken +from szurubooru.model.user import UserTagBlocklist, User, UserToken diff --git a/server/szurubooru/model/user.py b/server/szurubooru/model/user.py index 41a9b30b5..5186f4290 100644 --- a/server/szurubooru/model/user.py +++ b/server/szurubooru/model/user.py @@ -5,6 +5,46 @@ from szurubooru.model.post import Post, PostFavorite, PostScore +class UserTagBlocklist(Base): + __tablename__ = "user_tag_blocklist" + + user_id = sa.Column( + "user_id", + sa.Integer, + sa.ForeignKey("user.id"), + primary_key=True, + nullable=False, + index=True, + ) + tag_id = sa.Column( + "tag_id", + sa.Integer, + sa.ForeignKey("tag.id"), + primary_key=True, + nullable=False, + index=True, + ) + + tag = sa.orm.relationship( + "Tag", + backref=sa.orm.backref("user_tag_blocklist", cascade="all, delete-orphan"), + ) + user = sa.orm.relationship( + "User", + backref=sa.orm.backref("user_tag_blocklist", cascade="all, delete-orphan"), + ) + + def __init__(self, user_id: int=None, tag_id: int=None, user=None, tag=None) -> None: + if user_id is not None: + self.user_id = user_id + if tag_id is not None: + self.tag_id = tag_id + if user is not None: + self.user = user + if tag is not None: + self.tag = tag + + class User(Base): __tablename__ = "user" @@ -35,6 +75,7 @@ class User(Base): "avatar_style", sa.Unicode(32), nullable=False, default=AVATAR_GRAVATAR ) + blocklist = sa.orm.relationship("UserTagBlocklist") comments = sa.orm.relationship("Comment") @property diff --git a/server/szurubooru/search/configs/post_search_config.py b/server/szurubooru/search/configs/post_search_config.py index 8d4672d46..d482bcc64 100644 --- a/server/szurubooru/search/configs/post_search_config.py +++ b/server/szurubooru/search/configs/post_search_config.py @@ -2,9 +2,9 @@ import sqlalchemy as sa -from szurubooru import db, errors, model -from szurubooru.func import util -from szurubooru.search import criteria, tokens +from szurubooru import config, db, errors, model +from szurubooru.func import tags, users, util +from szurubooru.search import criteria, parser, tokens from szurubooru.search.configs import util as search_util from szurubooru.search.configs.base_search_config import ( BaseSearchConfig, @@ -178,6 +178,32 @@ def on_search_query_parsed(self, search_query: SearchQuery) -> SaQuery: new_special_tokens.append(token) search_query.special_tokens = new_special_tokens + blocklist_to_use = "" + + if self.user: # Ensure there's a user object + if (self.user.rank == model.User.RANK_ANONYMOUS) and config.config["default_tag_blocklist_for_anonymous"]: + # Anonymous user, if configured to use default blocklist, do so + blocklist_to_use = config.config["default_tag_blocklist"] + else: + # Registered user, use their blocklist + user_blocklist_tags = users.get_blocklist_tag_from_user(self.user) + if user_blocklist_tags: + user_blocklist = db.session.query(model.Tag.first_name).filter( + model.Tag.tag_id.in_([e.tag_id for e in user_blocklist_tags]) + ).all() + blocklist_to_use = [e[0] for e in user_blocklist] + blocklist_to_use = " ".join(blocklist_to_use) + + if len(blocklist_to_use) > 0: + # TODO Sort an already parsed and checked version instead? + blocklist_query = parser.Parser().parse(blocklist_to_use) + search_query_orig_list = [e.criterion.original_text for e in search_query.anonymous_tokens] + for t in blocklist_query.anonymous_tokens: + if t.criterion.original_text in search_query_orig_list: + continue + t.negated = True + search_query.anonymous_tokens.append(t) + def create_around_query(self) -> SaQuery: return db.session.query(model.Post).options(sa.orm.lazyload("*")) diff --git a/server/szurubooru/tests/api/test_info.py b/server/szurubooru/tests/api/test_info.py index 37099e8d4..c804867f6 100644 --- a/server/szurubooru/tests/api/test_info.py +++ b/server/szurubooru/tests/api/test_info.py @@ -26,6 +26,8 @@ def test_info_api( "tag_name_regex": "3", "tag_category_name_regex": "4", "default_rank": "5", + "default_tag_blocklist": "testTag", + "default_tag_blocklist_for_anonymous": True, "privileges": { "test_key1": "test_value1", "test_key2": "test_value2", @@ -48,6 +50,8 @@ def test_info_api( "tagNameRegex": "3", "tagCategoryNameRegex": "4", "defaultUserRank": "5", + "defaultTagBlocklist": "testTag", + "defaultTagBlocklistForAnonymous": True, "privileges": { "testKey1": "test_value1", "testKey2": "test_value2", diff --git a/server/szurubooru/tests/api/test_post_blocklist.py b/server/szurubooru/tests/api/test_post_blocklist.py new file mode 100644 index 000000000..7c4cf9cdd --- /dev/null +++ b/server/szurubooru/tests/api/test_post_blocklist.py @@ -0,0 +1,139 @@ +from datetime import datetime +from unittest.mock import patch + +import pytest + +from szurubooru import api, db, errors, model +from szurubooru.func import posts + + +## TODO: Add following tests: +## - Retrieve posts without blocklist active for current registered user +## - Retrieve posts with blocklist active for current registered user +## - Retrieve posts without blocklist active for anonymous user +## - Retrieve posts with blocklist active for anonymous user +## - Creation of user with default blocklist (test that user_blocklist entries are properly added to db, with right infos) +## - Modification of user with/without blocklist changes +## - Retrieve posts with a query including a blocklisted tag (it should include results with the tag) +## - Behavior when creating user with default blocklist and tags from this list don't exist (blocklist entry shouldn't be added) +## - Test all small functions used across blocklist features + + +def test_blocklist(user_factory, post_factory, context_factory, config_injector, user_blocklist_factory, tag_factory): + """ + Test that user blocklist is applied on post retrieval + """ + tag1 = tag_factory(names=['tag1']) + tag2 = tag_factory(names=['tag2']) + tag3 = tag_factory(names=['tag3']) + post1 = post_factory(id=11, tags=[tag1, tag2]) + post2 = post_factory(id=12, tags=[tag1]) + post3 = post_factory(id=13, tags=[tag2]) + post4 = post_factory(id=14, tags=[tag3]) + post5 = post_factory(id=15) + user1 = user_factory(rank=model.User.RANK_REGULAR) + blocklist1 = user_blocklist_factory(tag=tag1, user=user1) + config_injector({ + "privileges": { + "posts:list": model.User.RANK_REGULAR, + } + }) + db.session.add_all([tag1, tag2, tag3, user1, blocklist1, post1, post2, post3, post4, post5]) + db.session.flush() + # We can't check that the posts we retrieve are the ones we want + with patch("szurubooru.func.posts.serialize_post"): + posts.serialize_post.side_effect = ( + lambda post, *_args, **_kwargs: "serialized post %d" % post.post_id + ) + result = api.post_api.get_posts( + context_factory( + params={"query": "", "offset": 0}, + user=user1, + ) + ) + assert result == { + "query": "", + "offset": 0, + "limit": 100, + "total": 3, + "results": ["serialized post 15", "serialized post 14", "serialized post 13"], + } + + +# def test_blocklist_no_anonymous(user_factory, post_factory, context_factory, config_injector, tag_factory): +# """ +# Test that default blocklist isn't applied on anonymous users on post retrieval if disabled in configuration +# """ +# tag1 = tag_factory(names=['tag1']) +# post1 = post_factory(id=21, tags=[tag1]) +# post2 = post_factory(id=22, tags=[tag1]) +# post3 = post_factory(id=23) +# user1 = user_factory(rank=model.User.RANK_ANONYMOUS) +# config_injector({ +# "default_tag_blocklist": "tag1", +# "default_tag_blocklist_for_anonymous": False, +# "privileges": { +# "posts:list": model.User.RANK_ANONYMOUS, +# } +# }) +# db.session.add_all([tag1, post1, post2, post3]) +# db.session.flush() +# with patch("szurubooru.func.posts.serialize_post"): +# posts.serialize_post.side_effect = ( +# lambda post, *_args, **_kwargs: "serialized post %d" % post.post_id +# ) +# result = api.post_api.get_posts( +# context_factory( +# params={"query": "", "offset": 0}, +# user=user1, +# ) +# ) +# assert result == { +# "query": "", +# "offset": 0, +# "limit": 100, +# "total": 3, +# "results": ["serialized post 23", "serialized post 22", "serialized post 21"], +# } + + +def test_blocklist_anonymous(user_factory, post_factory, context_factory, config_injector, tag_factory): + """ + Test that default blocklist is applied on anonymous users on post retrieval if enabled in configuration + """ + tag1 = tag_factory(names=['tag1']) + tag2 = tag_factory(names=['tag2']) + tag3 = tag_factory(names=['tag3']) + post1 = post_factory(id=31, tags=[tag1, tag2]) + post2 = post_factory(id=32, tags=[tag1]) + post3 = post_factory(id=33, tags=[tag2]) + post4 = post_factory(id=34, tags=[tag3]) + post5 = post_factory(id=35) + config_injector({ + "default_tag_blocklist": "tag3", + "default_tag_blocklist_for_anonymous": True, + "privileges": { + "posts:list": model.User.RANK_ANONYMOUS, + } + }) + db.session.add_all([tag1, tag2, tag3, post1, post2, post3, post4, post5]) + db.session.flush() + with patch("szurubooru.func.posts.serialize_post"): + posts.serialize_post.side_effect = ( + lambda post, *_args, **_kwargs: "serialized post %d" % post.post_id + ) + result = api.post_api.get_posts( + context_factory( + params={"query": "", "offset": 0}, + user=user_factory(rank=model.User.RANK_ANONYMOUS), + ) + ) + assert result == { + "query": "", + "offset": 0, + "limit": 100, + "total": 4, + "results": ["serialized post 35", "serialized post 33", "serialized post 32", "serialized post 31"], + } + +## TODO: Test when we add blocklist items to the query diff --git a/server/szurubooru/tests/api/test_user_creating.py b/server/szurubooru/tests/api/test_user_creating.py index d55e1f7fa..2b2f22e80 100644 --- a/server/szurubooru/tests/api/test_user_creating.py +++ b/server/szurubooru/tests/api/test_user_creating.py @@ -21,6 +21,8 @@ def test_creating_user(user_factory, context_factory, fake_datetime): "szurubooru.func.users.update_user_rank" ), patch( "szurubooru.func.users.update_user_avatar" + ), patch( + "szurubooru.func.users.update_user_blocklist" ), patch( "szurubooru.func.users.serialize_user" ), fake_datetime( @@ -28,6 +30,7 @@ def test_creating_user(user_factory, context_factory, fake_datetime): ): users.serialize_user.return_value = "serialized user" users.create_user.return_value = user + users.update_user_blocklist.return_value = ([],[]) result = api.user_api.create_user( context_factory( params={ @@ -50,6 +53,7 @@ def test_creating_user(user_factory, context_factory, fake_datetime): assert not users.update_user_email.called users.update_user_rank.called_once_with(user, "moderator") users.update_user_avatar.called_once_with(user, "manual", b"...") + users.update_user_blocklist.called_once_with(user, None) @pytest.mark.parametrize("field", ["name", "password"]) diff --git a/server/szurubooru/tests/api/test_user_updating.py b/server/szurubooru/tests/api/test_user_updating.py index 304e4892f..412d4b466 100644 --- a/server/szurubooru/tests/api/test_user_updating.py +++ b/server/szurubooru/tests/api/test_user_updating.py @@ -14,11 +14,13 @@ def inject_config(config_injector): "users:edit:self:name": model.User.RANK_REGULAR, "users:edit:self:pass": model.User.RANK_REGULAR, "users:edit:self:email": model.User.RANK_REGULAR, + "users:edit:self:blocklist": model.User.RANK_REGULAR, "users:edit:self:rank": model.User.RANK_MODERATOR, "users:edit:self:avatar": model.User.RANK_MODERATOR, "users:edit:any:name": model.User.RANK_MODERATOR, "users:edit:any:pass": model.User.RANK_MODERATOR, "users:edit:any:email": model.User.RANK_MODERATOR, + "users:edit:any:blocklist": model.User.RANK_MODERATOR, "users:edit:any:rank": model.User.RANK_ADMINISTRATOR, "users:edit:any:avatar": model.User.RANK_ADMINISTRATOR, }, diff --git a/server/szurubooru/tests/conftest.py b/server/szurubooru/tests/conftest.py index 280987caf..e2780ce1d 100644 --- a/server/szurubooru/tests/conftest.py +++ b/server/szurubooru/tests/conftest.py @@ -136,6 +136,20 @@ def factory( return factory +@pytest.fixture +def user_blocklist_factory(user_factory, tag_factory): + def factory(tag=None, user=None): + if user is None: + user = user_factory() + if tag is None: + tag = tag_factory() + return model.UserTagBlocklist( + tag=tag, user=user + ) + + return factory + + @pytest.fixture def tag_category_factory(): def factory(name=None, color="dummy", order=1, default=False): @@ -172,6 +186,7 @@ def factory( id=None, safety=model.Post.SAFETY_SAFE, type=model.Post.TYPE_IMAGE, + tags=[], checksum="...", ): post = model.Post() @@ -182,6 +197,7 @@ def factory( post.flags = [] post.mime_type = "application/octet-stream" post.creation_time = datetime(1996, 1, 1) + post.tags = tags return post return factory diff --git a/server/szurubooru/tests/func/test_users.py b/server/szurubooru/tests/func/test_users.py index 94e9c7c1b..908e39f07 100644 --- a/server/szurubooru/tests/func/test_users.py +++ b/server/szurubooru/tests/func/test_users.py @@ -158,6 +158,7 @@ def test_serialize_user(user_factory): "avatarUrl": "https://example.com/avatar.png", "likedPostCount": 66, "dislikedPostCount": 33, + "blocklist": [], "commentCount": 0, "favoritePostCount": 0, "uploadedPostCount": 0, @@ -235,7 +236,7 @@ def test_create_user_for_first_user(fake_datetime): "szurubooru.func.users.update_user_password" ), patch("szurubooru.func.users.update_user_email"), fake_datetime( "1997-01-01" - ): + ), patch("szurubooru.func.users.update_user_blocklist"): user = users.create_user("name", "password", "email") assert user.creation_time == datetime(1997, 1, 1) assert user.last_login_time is None @@ -251,7 +252,8 @@ def test_create_user_for_subsequent_users(user_factory, config_injector): db.session.flush() with patch("szurubooru.func.users.update_user_name"), patch( "szurubooru.func.users.update_user_email" - ), patch("szurubooru.func.users.update_user_password"): + ), patch("szurubooru.func.users.update_user_password" + ), patch("szurubooru.func.users.update_user_blocklist"): user = users.create_user("name", "password", "email") assert user.rank == model.User.RANK_REGULAR