Skip to content

Commit 204f5c3

Browse files
Merge pull request #373 from hotosm/feature/notification
Feature : Notification System
2 parents d7f276e + ecfb6c7 commit 204f5c3

File tree

9 files changed

+261
-67
lines changed

9 files changed

+261
-67
lines changed

backend/aiproject/celery.py

+1
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818
# read config from Django settings, the CELERY namespace would make celery
1919
# config keys has `CELERY` prefix
2020
app.config_from_object("django.conf:settings", namespace="CELERY")
21+
app.conf.task_track_started = True
2122

2223
# discover and load tasks.py from from all registered Django apps
2324
app.autodiscover_tasks(lambda: settings.INSTALLED_APPS)

backend/aiproject/settings.py

+14
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,10 @@
3333
LOG_PATH = env("LOG_PATH", default=os.path.join(os.getcwd(), "log"))
3434

3535
HOSTNAME = env("HOSTNAME", default="127.0.0.1")
36+
37+
FRONTEND_URL = env("FRONTEND_URL", default="https://fair.hotosm.org")
38+
39+
3640
EXPORT_TOOL_API_URL = env(
3741
"EXPORT_TOOL_API_URL",
3842
default="https://api-prod.raw-data.hotosm.org/v1",
@@ -270,3 +274,13 @@
270274

271275

272276
TEST_RUNNER = "tests.test_runners.NoDestroyTestRunner"
277+
278+
279+
EMAIL_BACKEND = "django.core.mail.backends.smtp.EmailBackend"
280+
EMAIL_HOST = os.getenv("EMAIL_HOST", "smtp.gmail.com")
281+
EMAIL_PORT = int(os.getenv("EMAIL_PORT", 587))
282+
EMAIL_USE_TLS = os.getenv("EMAIL_USE_TLS", "True") == "True"
283+
EMAIL_USE_SSL = os.getenv("EMAIL_USE_SSL", "False") == "True"
284+
EMAIL_HOST_USER = os.getenv("EMAIL_HOST_USER", "example-email@example.com")
285+
EMAIL_HOST_PASSWORD = os.getenv("EMAIL_HOST_PASSWORD", "example-email-password")
286+
DEFAULT_FROM_EMAIL = os.getenv("DEFAULT_FROM_EMAIL", "no-reply@example.com")

backend/core/models.py

+23
Original file line numberDiff line numberDiff line change
@@ -173,3 +173,26 @@ def is_displayable(self):
173173

174174
def __str__(self):
175175
return self.message
176+
177+
178+
class UserNotification(models.Model):
179+
180+
user = models.ForeignKey(
181+
OsmUser,
182+
to_field="osm_id",
183+
on_delete=models.CASCADE,
184+
related_name="notifications",
185+
)
186+
is_read = models.BooleanField(default=False)
187+
created_at = models.DateTimeField(default=timezone.now)
188+
read_at = models.DateTimeField(null=True, blank=True)
189+
message = models.TextField(max_length=500)
190+
191+
def mark_as_read(self):
192+
if not self.is_read:
193+
self.is_read = True
194+
self.read_at = timezone.now()
195+
self.save()
196+
197+
def __str__(self):
198+
return f"Notification for {self.user.username}: {self.message[:50]}..."

backend/core/serializers.py

+32-16
Original file line numberDiff line numberDiff line change
@@ -91,13 +91,6 @@ def __init__(self, *args, **kwargs):
9191
):
9292
self.fields["dataset"] = DatasetSerializer(read_only=True)
9393

94-
# def get_training(self, obj):
95-
# if not hasattr(self, "_cached_training"):
96-
# self._cached_training = Training.objects.filter(
97-
# id=obj.published_training
98-
# ).first()
99-
# return self._cached_training
100-
10194
def get_thumbnail_url(self, obj):
10295
training = Training.objects.filter(id=obj.published_training).first()
10396

@@ -144,14 +137,6 @@ def get_geometry(self, obj):
144137
}
145138
return None
146139

147-
# def to_representation(self, instance):
148-
# """
149-
# Override to_representation to customize GeoJSON structure.
150-
# """
151-
# representation = super().to_representation(instance)
152-
# representation["properties"]["id"] = representation.pop("id")
153-
# return representation
154-
155140

156141
class AOISerializer(
157142
GeoFeatureModelSerializer
@@ -451,6 +436,7 @@ class UserStatsSerializer(serializers.ModelSerializer):
451436
feedbacks_count = serializers.SerializerMethodField()
452437
approved_predictions_count = serializers.SerializerMethodField()
453438
profile_completion_percentage = serializers.SerializerMethodField()
439+
unread_notifications_count = serializers.SerializerMethodField()
454440

455441
class Meta:
456442
model = OsmUser
@@ -460,13 +446,28 @@ class Meta:
460446
"email",
461447
"date_joined",
462448
"img_url",
449+
"notifications_delivery_methods",
450+
"newsletter_subscription",
451+
"account_deletion_requested",
452+
"models_count",
453+
"datasets_count",
454+
"feedbacks_count",
455+
"approved_predictions_count",
456+
"profile_completion_percentage",
457+
"unread_notifications_count",
458+
]
459+
read_only_fields = [
460+
"osm_id",
461+
"username",
462+
"date_joined",
463+
"img_url",
463464
"models_count",
464465
"datasets_count",
465466
"feedbacks_count",
466467
"approved_predictions_count",
467468
"profile_completion_percentage",
469+
"unread_notifications_count",
468470
]
469-
read_only_fields = ["osm_id", "username", "date_joined", "img_url"]
470471

471472
def get_models_count(self, obj):
472473
return Model.objects.filter(user=obj).count()
@@ -489,3 +490,18 @@ def get_profile_completion_percentage(self, obj):
489490
if obj.email is not None and obj.email != "":
490491
profile_percentage += 25
491492
return profile_percentage
493+
494+
def get_unread_notifications_count(self, obj):
495+
return UserNotification.objects.filter(user=obj, is_read=False).count()
496+
497+
498+
class UserNotificationSerializer(serializers.ModelSerializer):
499+
class Meta:
500+
model = UserNotification
501+
fields = ("id", "is_read", "created_at", "read_at", "message")
502+
read_only_fields = (
503+
"id",
504+
"created_at",
505+
"read_at",
506+
"message",
507+
)

backend/core/tasks.py

+51-3
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717
Label,
1818
Model,
1919
Training,
20+
UserNotification,
2021
)
2122
from core.serializers import (
2223
AOISerializer,
@@ -29,6 +30,7 @@
2930
from django.conf import settings
3031
from django.contrib.gis.db.models.aggregates import Extent
3132
from django.contrib.gis.geos import GEOSGeometry
33+
from django.core.mail import send_mail
3234
from django.shortcuts import get_object_or_404
3335
from django.utils import timezone
3436

@@ -514,14 +516,15 @@ def train_model(
514516

515517
training_instance = get_object_or_404(Training, id=training_id)
516518
model_instance = get_object_or_404(Model, id=training_instance.model.id)
519+
520+
send_notification(training_instance,"Started")
517521

518522
training_instance.status = "RUNNING"
519523
training_instance.started_at = timezone.now()
524+
training_instance.task_id = train_model.request.id
525+
520526
training_instance.save()
521527
os.makedirs(settings.LOG_PATH, exist_ok=True)
522-
if training_instance.task_id is None or training_instance.task_id.strip() == "":
523-
training_instance.task_id = train_model.request.id
524-
training_instance.save()
525528
log_file = os.path.join(settings.LOG_PATH, f"run_{train_model.request.id}.log")
526529

527530
if model_instance.base_model == "YOLO_V8_V1" and settings.YOLO_HOME is None:
@@ -567,10 +570,55 @@ def train_model(
567570
)
568571

569572
logging.info(f"Training task {training_id} completed successfully")
573+
send_notification(training_instance, "Completed")
570574
return response
571575

572576
except Exception as ex:
573577
training_instance.status = "FAILED"
574578
training_instance.finished_at = timezone.now()
575579
training_instance.save()
580+
send_notification(training_instance, "Failed")
576581
raise ex
582+
583+
def get_email_message(training_instance,status):
584+
585+
hostname = settings.FRONTEND_URL
586+
training_model_url = f"{hostname}/ai-models/{training_instance.model.id}"
587+
588+
message_template = (
589+
"Hi {username},\n\n"
590+
"Your training task (ID: {training_id}) of model {model_name} has {status}. You can view the details here:\n"
591+
"{training_model_url}\n\n"
592+
"Thank you for using fAIr - AI Assisted Mapping Tool.\n\n"
593+
"Best regards,\n"
594+
"The fAIr Dev Team\n\n"
595+
"Get Involved : https://www.hotosm.org/get-involved/\n"
596+
"https://github.com/hotosm/fAIr/"
597+
)
598+
599+
message = message_template.format(
600+
username=training_instance.user.username,
601+
training_id=training_instance.id,
602+
model_name=training_instance.model.name,
603+
status=status.lower(),
604+
training_model_url=training_model_url,
605+
hostname=hostname,
606+
607+
)
608+
subject = f"fAIr : Training {training_instance.id} {status.capitalize()}"
609+
return message, subject
610+
611+
612+
def send_notification(training_instance,status):
613+
if any(method in training_instance.user.notifications_delivery_methods for method in ["web", "email"]):
614+
UserNotification.objects.create(user=training_instance.user, message=f"Training {training_instance.id} has {status}.")
615+
if "email" in training_instance.user.notifications_delivery_methods:
616+
if training_instance.user.email and training_instance.user.email != '':
617+
message,subject=get_email_message(training_instance,status)
618+
send_mail(
619+
subject=subject,
620+
message=message,
621+
from_email=settings.DEFAULT_FROM_EMAIL,
622+
recipient_list=[training_instance.user.email],
623+
fail_silently=False,
624+
)

backend/core/urls.py

+10-4
Original file line numberDiff line numberDiff line change
@@ -3,8 +3,7 @@
33
from django.urls import path
44
from rest_framework import routers
55

6-
# now import the views.py file into this code
7-
from .views import ( # APIStatus,
6+
from .views import (
87
AOIViewSet,
98
ApprovedPredictionsViewSet,
109
BannerViewSet,
@@ -18,13 +17,17 @@
1817
GenerateGpxView,
1918
LabelUploadView,
2019
LabelViewSet,
20+
MarkAllNotificationsAsRead,
21+
MarkNotificationAsRead,
2122
ModelCentroidView,
2223
ModelViewSet,
2324
RawdataApiAOIView,
2425
RawdataApiFeedbackView,
26+
TerminateTrainingView,
2527
TrainingViewSet,
2628
TrainingWorkspaceDownloadView,
2729
TrainingWorkspaceView,
30+
UserNotificationViewSet,
2831
UsersView,
2932
download_training_data,
3033
geojson2osmconverter,
@@ -48,7 +51,7 @@
4851
router.register(r"feedback-aoi", FeedbackAOIViewset)
4952
router.register(r"feedback-label", FeedbackLabelViewset)
5053
router.register(r"banner", BannerViewSet)
51-
54+
router.register(r'notifications/me', UserNotificationViewSet, basename='notifications')
5255

5356
urlpatterns = [
5457
path("", include(router.urls)),
@@ -63,6 +66,7 @@
6366
# path("download/<int:dataset_id>/", download_training_data),
6467
path("training/status/<str:run_id>/", run_task_status),
6568
path("training/publish/<int:training_id>/", publish_training),
69+
path("training/terminate/<int:training_id>/", TerminateTrainingView.as_view(), name="cancel_training"),
6670
path("feedback/training/submit/", FeedbackView.as_view()),
6771
# path("status/", APIStatus.as_view()),
6872
path("geojson2osm/", geojson2osmconverter, name="geojson2osmconverter"),
@@ -71,12 +75,14 @@
7175
path(
7276
"feedback-aoi/gpx/<int:feedback_aoi_id>/", GenerateFeedbackAOIGpxView.as_view()
7377
),
74-
# path("workspace/", TrainingWorkspaceView.as_view()),
7578
path(
7679
"workspace/download/<path:lookup_dir>/", TrainingWorkspaceDownloadView.as_view()
7780
),
7881
path("workspace/<path:lookup_dir>/", TrainingWorkspaceView.as_view()),
7982
path("kpi/stats/", get_kpi_stats, name="get_kpi_stats"),
83+
path("notifications/mark-as-read/<int:notification_id>/", MarkNotificationAsRead.as_view(), name="mark_notification_as_read"),
84+
path("notifications/mark-all-as-read/", MarkAllNotificationsAsRead.as_view(), name="mark_all_notifications_as_read"),
85+
8086
]
8187
if settings.ENABLE_PREDICTION_API:
8288
urlpatterns.append(path("prediction/", PredictionView.as_view()))

0 commit comments

Comments
 (0)