Skip to content

Commit

Permalink
chore: add type annotations to notifications
Browse files Browse the repository at this point in the history
Also fixes some of the unhandled conditions.
  • Loading branch information
nijel committed Feb 28, 2025
1 parent 52d7b17 commit e6d788f
Show file tree
Hide file tree
Showing 2 changed files with 67 additions and 37 deletions.
96 changes: 64 additions & 32 deletions weblate/accounts/notifications.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
from collections import defaultdict
from copy import copy
from email.utils import formataddr
from typing import TYPE_CHECKING, Any
from typing import TYPE_CHECKING, Any, cast
from uuid import uuid4

from dateutil.relativedelta import relativedelta
Expand All @@ -31,7 +31,16 @@
from weblate.auth.models import User
from weblate.lang.models import Language
from weblate.logger import LOGGER
from weblate.trans.models import Alert, Change, Component, Project, Translation
from weblate.trans.models import (
Alert,
Announcement,
Change,
Comment,
Component,
Project,
Translation,
Unit,
)
from weblate.utils.errors import report_error
from weblate.utils.markdown import get_mention_users
from weblate.utils.ratelimit import rate_limit
Expand Down Expand Up @@ -127,9 +136,9 @@ def __init__(
self.perm_cache = {}

def get_language_filter(
self, change: Change, translation: Translation
self, change: Change | None, translation: Translation | None
) -> Language | None:
if self.filter_languages:
if self.filter_languages and translation is not None:
return translation.language
return None

Expand Down Expand Up @@ -184,7 +193,14 @@ def filter_subscriptions(
.prefetch_related("user", "user__profile", "user__profile__watched")
)

def get_subscriptions(self, change, project, component, translation, users):
def get_subscriptions(
self,
change: Change | None,
project: Project | None,
component: Component | None,
translation: Translation | None,
users: list[int] | None,
) -> Iterable[Subscription]:
lang_filter = self.get_language_filter(change, translation)
cache_key: tuple[int | str | None, ...] = (
lang_filter.id if lang_filter else None,
Expand Down Expand Up @@ -219,13 +235,13 @@ def is_admin(self, user: User, project):

def get_users(
self,
frequency,
change=None,
project=None,
component=None,
translation=None,
users=None,
):
frequency: NotificationFrequency,
change: Change | None = None,
project: Project | None = None,
component: Component | None = None,
translation: Translation | None = None,
users: list[int] | None = None,
) -> Iterable[User]:
if self.has_required_attrs(change):
return
if change is not None:
Expand Down Expand Up @@ -426,8 +442,9 @@ def notify_immediate(self, change) -> None:
subscription=user.current_subscription,
)
# Delete onetime subscription
if user.current_subscription.onetime:
user.current_subscription.delete()
current_subscription = cast("Subscription", user.current_subscription)
if current_subscription.onetime:
current_subscription.delete()

def send_digest(self, language, email, changes, subscription=None) -> None:
with override("en" if language is None else language):
Expand Down Expand Up @@ -611,8 +628,14 @@ class NewCommentNotificaton(Notification):
filter_languages = True
required_attr = "comment"

def get_language_filter(self, change, translation):
if not change.comment.unit.is_source:
def get_language_filter(
self, change: Change | None, translation: Translation | None
) -> Language | None:
if (
translation is not None
and change is not None
and not cast("Unit", change.unit).is_source
):
return translation.language
return None

Expand All @@ -636,14 +659,14 @@ class MentionCommentNotificaton(Notification):

def get_users(
self,
frequency,
change=None,
project=None,
component=None,
translation=None,
users=None,
):
if self.has_required_attrs(change):
frequency: NotificationFrequency,
change: Change | None = None,
project: Project | None = None,
component: Component | None = None,
translation: Translation | None = None,
users: list[int] | None = None,
) -> Iterable[User]:
if change is None or self.has_required_attrs(change):
return []
return super().get_users(
frequency,
Expand All @@ -652,7 +675,9 @@ def get_users(
component,
translation,
list(
get_mention_users(change.comment.comment).values_list("id", flat=True)
get_mention_users(cast("Comment", change.comment).comment).values_list(
"id", flat=True
)
),
)

Expand All @@ -668,17 +693,20 @@ class LastAuthorCommentNotificaton(Notification):

def get_users(
self,
frequency: int,
frequency: NotificationFrequency,
change: Change | None = None,
project: Project | None = None,
component: Component | None = None,
translation: Translation | None = None,
users: list[int] | None = None,
):
last_author = change.unit.get_last_content_change()[0]
users = [] if last_author.is_anonymous else [last_author.pk]
) -> Iterable[User]:
change_users: list[int] = []
if change is not None:
last_author = cast("Unit", change.unit).get_last_content_change()[0]
if not last_author.is_anonymous:
change_users.append(last_author.pk)
return super().get_users(
frequency, change, project, component, translation, users
frequency, change, project, component, translation, change_users
)


Expand Down Expand Up @@ -743,8 +771,12 @@ class NewAnnouncementNotificaton(Notification):
def should_skip(self, user: User, change) -> bool:
return not change.announcement.notify

def get_language_filter(self, change, translation):
return change.announcement.language
def get_language_filter(
self, change: Change | None, translation: Translation | None
) -> Language | None:
if change is None:
return None
return cast("Announcement", change.announcement).language


@register_notification
Expand Down
8 changes: 3 additions & 5 deletions weblate/auth/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,15 +60,13 @@
from weblate.utils.validators import CRUD_RE, validate_fullname, validate_username

if TYPE_CHECKING:
from collections.abc import (
Iterable,
Mapping,
)
from collections.abc import Iterable, Mapping

from social_core.backends.base import BaseAuth
from social_django.models import DjangoStorage
from social_django.strategy import DjangoStrategy

from weblate.accounts.models import Subscription
from weblate.auth.permissions import PermissionResult
from weblate.wladmin.models import SupportStatusDict

Expand Down Expand Up @@ -539,7 +537,7 @@ def __init__(self, *args, **kwargs) -> None:
self.extra_data: dict[str, str] = {}
self.cla_cache: dict[tuple[int, int], bool] = {}
self._permissions: PermissionsDictType = {}
self.current_subscription = None
self.current_subscription: Subscription | None = None
for name in self.DUMMY_FIELDS:
if name in kwargs:
self.extra_data[name] = kwargs.pop(name)
Expand Down

0 comments on commit e6d788f

Please sign in to comment.