Skip to content

Commit

Permalink
Blocklist: Add backend elements:
Browse files Browse the repository at this point in the history
- Add default blocklist to user when created
- Tags are created if added to a user blocklist
- Add matching migration to DB to add the user blocklist table
- Various other things
  • Loading branch information
Soblow authored and Lugrim committed May 5, 2024
1 parent e5f61d2 commit 82721c0
Show file tree
Hide file tree
Showing 9 changed files with 220 additions and 9 deletions.
2 changes: 2 additions & 0 deletions server/szurubooru/api/info_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,8 @@ def get_info(ctx: rest.Context, _params: Dict[str, str] = {}) -> rest.Response:
"tagNameRegex": config.config["tag_name_regex"],
"tagCategoryNameRegex": config.config["tag_category_name_regex"],
"defaultUserRank": config.config["default_rank"],
"defaultTagBlocklist": config.config["default_tag_blocklist"],
"defaultTagBlocklistForAnonymous": config.config["default_tag_blocklist_for_anonymous"],
"enableSafety": config.config["enable_safety"],
"contactEmail": config.config["contact_email"],
"canSendMails": bool(config.config["smtp"]["host"]),
Expand Down
2 changes: 1 addition & 1 deletion server/szurubooru/api/tag_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
from typing import Dict, List, Optional

from szurubooru import db, model, rest, search
from szurubooru.func import auth, serialization, snapshots, tags, versions
from szurubooru.func import auth, serialization, snapshots, tags, versions, users

_search_executor = search.Executor(search.configs.TagSearchConfig())

Expand Down
32 changes: 29 additions & 3 deletions server/szurubooru/api/user_api.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
from typing import Any, Dict
from typing import Any, Dict, List

from szurubooru import model, rest, search
from szurubooru.func import auth, serialization, users, versions
from szurubooru import db, model, rest, search
from szurubooru.func import auth, serialization, snapshots, users, versions, tags

_search_executor = search.Executor(search.configs.UserSearchConfig())

Expand All @@ -17,6 +17,18 @@ def _serialize(
)


def _create_tag_if_needed(tag_names: List[str], user: model.User) -> None:
# Taken from tag_api.py
if not tag_names:
return
_existing_tags, new_tags = tags.get_or_create_tags_by_names(tag_names)
if len(new_tags):
auth.verify_privilege(user, "tags:create")
db.session.flush()
for tag in new_tags:
snapshots.create(tag, user)


@rest.routes.get("/users/?")
def get_users(
ctx: rest.Context, _params: Dict[str, str] = {}
Expand Down Expand Up @@ -50,6 +62,10 @@ def create_user(
)
ctx.session.add(user)
ctx.session.commit()
to_add, _ = users.update_user_blocklist(user, None)
for e in to_add:
ctx.session.add(e)
ctx.session.commit()

return _serialize(ctx, user, force_show_email=True)

Expand Down Expand Up @@ -80,6 +96,16 @@ def update_user(ctx: rest.Context, params: Dict[str, str]) -> rest.Response:
if ctx.has_param("rank"):
auth.verify_privilege(ctx.user, "users:edit:%s:rank" % infix)
users.update_user_rank(user, ctx.get_param_as_string("rank"), ctx.user)
if ctx.has_param("blocklist"):
auth.verify_privilege(ctx.user, "users:edit:%s:blocklist" % infix)
blocklist = ctx.get_param_as_string_list("blocklist")
_create_tag_if_needed(blocklist, user) # Non-existing tags are created.
blocklist_tags = tags.get_tags_by_names(blocklist)
to_add, to_remove = users.update_user_blocklist(user, blocklist_tags)
for e in to_remove:
ctx.session.delete(e)
for e in to_add:
ctx.session.add(e)
if ctx.has_param("avatarStyle"):
auth.verify_privilege(ctx.user, "users:edit:%s:avatar" % infix)
users.update_user_avatar(
Expand Down
21 changes: 21 additions & 0 deletions server/szurubooru/func/tags.py
Original file line number Diff line number Diff line change
Expand Up @@ -159,6 +159,9 @@ def get_tag_by_name(name: str) -> model.Tag:


def get_tags_by_names(names: List[str]) -> List[model.Tag]:
"""
Returns a list of all tags which names include all the letters from the input list
"""
names = util.icase_unique(names)
if len(names) == 0:
return []
Expand All @@ -175,6 +178,24 @@ def get_tags_by_names(names: List[str]) -> List[model.Tag]:
)


def get_tags_by_exact_names(names: List[str]) -> List[model.Tag]:
"""
Returns a list of tags matching the names from the input list
"""
entries = []
if len(names) == 0:
return []
names = [name.lower() for name in names]
entries = (
db.session.query(model.Tag)
.join(model.TagName)
.filter(
sa.func.lower(model.TagName.name).in_(names)
)
.all())
return entries


def get_or_create_tags_by_names(
names: List[str],
) -> Tuple[List[model.Tag], List[model.Tag]]:
Expand Down
67 changes: 66 additions & 1 deletion server/szurubooru/func/users.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,12 @@
import copy
import re
from datetime import datetime
from typing import Any, Callable, Dict, List, Optional, Union

import sqlalchemy as sa

from szurubooru import config, db, errors, model, rest
from szurubooru.func import auth, files, images, serialization, util
from szurubooru.func import auth, files, images, serialization, util, tags


class UserNotFoundError(errors.NotFoundError):
Expand Down Expand Up @@ -107,6 +108,7 @@ def _serializers(self) -> Dict[str, Callable[[], Any]]:
"lastLoginTime": self.serialize_last_login_time,
"version": self.serialize_version,
"rank": self.serialize_rank,
"blocklist": self.serialize_blocklist,
"avatarStyle": self.serialize_avatar_style,
"avatarUrl": self.serialize_avatar_url,
"commentCount": self.serialize_comment_count,
Expand Down Expand Up @@ -138,6 +140,9 @@ def serialize_avatar_style(self) -> Any:
def serialize_avatar_url(self) -> Any:
return get_avatar_url(self.user)

def serialize_blocklist(self) -> Any:
return [tags.serialize_tag(tag) for tag in get_blocklist_tag_from_user(self.user)]

def serialize_comment_count(self) -> Any:
return self.user.comment_count

Expand Down Expand Up @@ -294,6 +299,66 @@ def update_user_rank(
user.rank = rank


def get_blocklist_from_user(user: model.User) -> List[model.UserTagBlocklist]:
"""
Return the UserTagBlocklist objects related to given user
"""
rez = (db.session.query(model.UserTagBlocklist)
.filter(
model.UserTagBlocklist.user_id == user.user_id
)
.all())
return rez


def get_blocklist_tag_from_user(user: model.User) -> List[model.UserTagBlocklist]:
"""
Return the Tags blocklisted by given user
"""
rez = (db.session.query(model.UserTagBlocklist.tag_id)
.filter(
model.UserTagBlocklist.user_id == user.user_id
))
rez2 = (db.session.query(model.Tag)
.filter(
model.Tag.tag_id.in_(rez)
).all())
return rez2


def update_user_blocklist(user: model.User, new_blocklist_tags: Optional[List[model.Tag]]) -> List[List[model.UserTagBlocklist]]:
"""
Modify blocklist for given user.
If new_blocklist_tags is None, set the blocklist to configured default tag blocklist.
"""
assert user
to_add: List[model.UserTagBlocklist] = []
to_remove: List[model.UserTagBlocklist] = []

if new_blocklist_tags is None: # We're creating the user, use default config blocklist
if 'default_tag_blocklist' in config.config.keys():
for e in tags.get_tags_by_exact_names(config.config['default_tag_blocklist'].split(' ')):
to_add.append(model.UserTagBlocklist(user_id=user.user_id, tag_id=e.tag_id))
else:
new_blocklist_ids: List[int] = [e.tag_id for e in new_blocklist_tags]
previous_blocklist_tags: List[model.Tag] = get_blocklist_from_user(user)
previous_blocklist_ids: List[int] = [e.tag_id for e in previous_blocklist_tags]
original_previous_blocklist_ids = copy.copy(previous_blocklist_ids)

## Remove tags no longer in the new list
for i in range(len(original_previous_blocklist_ids)):
old_tag_id = original_previous_blocklist_ids[i]
if old_tag_id not in new_blocklist_ids:
to_remove.append(previous_blocklist_tags[i])
previous_blocklist_ids.remove(old_tag_id)

## Add tags not yet in the original list
for new_tag_id in new_blocklist_ids:
if new_tag_id not in previous_blocklist_ids:
to_add.append(model.UserTagBlocklist(user_id=user.user_id, tag_id=new_tag_id))
return to_add, to_remove


def update_user_avatar(
user: model.User, avatar_style: str, avatar_content: Optional[bytes] = None
) -> None:
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,30 @@
'''
Add blocklist related fields
add_blocklist
Revision ID: 9ba5e3a6ee7c
Created at: 2023-05-20 22:28:10.824954
'''

import sqlalchemy as sa
from alembic import op

revision = '9ba5e3a6ee7c'
down_revision = 'adcd63ff76a2'
branch_labels = None
depends_on = None


def upgrade():
op.create_table(
"user_tag_blocklist",
sa.Column("user_id", sa.Integer(), nullable=False),
sa.Column("tag_id", sa.Integer(), nullable=False),
sa.ForeignKeyConstraint(["user_id"], ["user.id"]),
sa.ForeignKeyConstraint(["tag_id"], ["tag.id"]),
sa.PrimaryKeyConstraint("user_id", "tag_id"),
)

def downgrade():
op.drop_table('user_tag_blocklist')
2 changes: 1 addition & 1 deletion server/szurubooru/model/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,4 +16,4 @@
from szurubooru.model.snapshot import Snapshot
from szurubooru.model.tag import Tag, TagImplication, TagName, TagSuggestion
from szurubooru.model.tag_category import TagCategory
from szurubooru.model.user import User, UserToken
from szurubooru.model.user import UserTagBlocklist, User, UserToken
41 changes: 41 additions & 0 deletions server/szurubooru/model/user.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,46 @@
from szurubooru.model.post import Post, PostFavorite, PostScore


class UserTagBlocklist(Base):
__tablename__ = "user_tag_blocklist"

user_id = sa.Column(
"user_id",
sa.Integer,
sa.ForeignKey("user.id"),
primary_key=True,
nullable=False,
index=True,
)
tag_id = sa.Column(
"tag_id",
sa.Integer,
sa.ForeignKey("tag.id"),
primary_key=True,
nullable=False,
index=True,
)

tag = sa.orm.relationship(
"Tag",
backref=sa.orm.backref("user_tag_blocklist", cascade="all, delete-orphan"),
)
user = sa.orm.relationship(
"User",
backref=sa.orm.backref("user_tag_blocklist", cascade="all, delete-orphan"),
)

def __init__(self, user_id: int=None, tag_id: int=None, user=None, tag=None) -> None:
if user_id is not None:
self.user_id = user_id
if tag_id is not None:
self.tag_id = tag_id
if user is not None:
self.user = user
if tag is not None:
self.tag = tag


class User(Base):
__tablename__ = "user"

Expand Down Expand Up @@ -35,6 +75,7 @@ class User(Base):
"avatar_style", sa.Unicode(32), nullable=False, default=AVATAR_GRAVATAR
)

blocklist = sa.orm.relationship("UserTagBlocklist")
comments = sa.orm.relationship("Comment")

@property
Expand Down
32 changes: 29 additions & 3 deletions server/szurubooru/search/configs/post_search_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,9 @@

import sqlalchemy as sa

from szurubooru import db, errors, model
from szurubooru.func import util
from szurubooru.search import criteria, tokens
from szurubooru import config, db, errors, model
from szurubooru.func import tags, users, util
from szurubooru.search import criteria, parser, tokens
from szurubooru.search.configs import util as search_util
from szurubooru.search.configs.base_search_config import (
BaseSearchConfig,
Expand Down Expand Up @@ -178,6 +178,32 @@ def on_search_query_parsed(self, search_query: SearchQuery) -> SaQuery:
new_special_tokens.append(token)
search_query.special_tokens = new_special_tokens

blocklist_to_use = ""

if self.user: # Ensure there's a user object
if (self.user.rank == model.User.RANK_ANONYMOUS) and config.config["default_tag_blocklist_for_anonymous"]:
# Anonymous user, if configured to use default blocklist, do so
blocklist_to_use = config.config["default_tag_blocklist"]
else:
# Registered user, use their blocklist
user_blocklist_tags = users.get_blocklist_tag_from_user(self.user)
if user_blocklist_tags:
user_blocklist = db.session.query(model.Tag.first_name).filter(
model.Tag.tag_id.in_([e.tag_id for e in user_blocklist_tags])
).all()
blocklist_to_use = [e[0] for e in user_blocklist]
blocklist_to_use = " ".join(blocklist_to_use)

if len(blocklist_to_use) > 0:
# TODO Sort an already parsed and checked version instead?
blocklist_query = parser.Parser().parse(blocklist_to_use)
search_query_orig_list = [e.criterion.original_text for e in search_query.anonymous_tokens]
for t in blocklist_query.anonymous_tokens:
if t.criterion.original_text in search_query_orig_list:
continue
t.negated = True
search_query.anonymous_tokens.append(t)

def create_around_query(self) -> SaQuery:
return db.session.query(model.Post).options(sa.orm.lazyload("*"))

Expand Down

0 comments on commit 82721c0

Please sign in to comment.