Skip to content

Commit 306cfa1

Browse files
Merge pull request #377 from hotosm/hotfix/notification-training-id
HOTFix : Training Instance in Notification
2 parents b6e9bfa + 9b5a15d commit 306cfa1

File tree

5 files changed

+130
-68
lines changed

5 files changed

+130
-68
lines changed

backend/core/models.py

+1
Original file line numberDiff line numberDiff line change
@@ -187,6 +187,7 @@ class UserNotification(models.Model):
187187
created_at = models.DateTimeField(default=timezone.now)
188188
read_at = models.DateTimeField(null=True, blank=True)
189189
message = models.TextField(max_length=500)
190+
training = models.ForeignKey(Training, to_field="id", on_delete=models.DO_NOTHING)
190191

191192
def mark_as_read(self):
192193
if not self.is_read:

backend/core/serializers.py

+17-2
Original file line numberDiff line numberDiff line change
@@ -496,12 +496,27 @@ def get_unread_notifications_count(self, obj):
496496

497497

498498
class UserNotificationSerializer(serializers.ModelSerializer):
499+
training_model = serializers.SerializerMethodField()
500+
499501
class Meta:
500502
model = UserNotification
501-
fields = ("id", "is_read", "created_at", "read_at", "message")
503+
fields = (
504+
"id",
505+
"is_read",
506+
"created_at",
507+
"read_at",
508+
"message",
509+
"training_model",
510+
)
502511
read_only_fields = (
503512
"id",
504513
"created_at",
505514
"read_at",
506515
"message",
507-
)
516+
"training_model",
517+
)
518+
519+
def get_training_model(self, obj):
520+
if obj.training and obj.training.model:
521+
return obj.training.model.id
522+
return None

backend/core/tasks.py

+3-47
Original file line numberDiff line numberDiff line change
@@ -30,11 +30,10 @@
3030
from django.conf import settings
3131
from django.contrib.gis.db.models.aggregates import Extent
3232
from django.contrib.gis.geos import GEOSGeometry
33-
from django.core.mail import send_mail
3433
from django.shortcuts import get_object_or_404
3534
from django.utils import timezone
3635

37-
from .utils import S3Uploader
36+
from .utils import S3Uploader, send_notification
3837

3938
logging.basicConfig(
4039
format="%(asctime)s - %(levelname)s - %(message)s", level=logging.INFO
@@ -516,8 +515,8 @@ def train_model(
516515

517516
training_instance = get_object_or_404(Training, id=training_id)
518517
model_instance = get_object_or_404(Model, id=training_instance.model.id)
519-
520-
send_notification(training_instance,"Started")
518+
519+
send_notification(training_instance, "Started")
521520

522521
training_instance.status = "RUNNING"
523522
training_instance.started_at = timezone.now()
@@ -579,46 +578,3 @@ def train_model(
579578
training_instance.save()
580579
send_notification(training_instance, "Failed")
581580
raise ex
582-
583-
def get_email_message(training_instance,status):
584-
585-
hostname = settings.FRONTEND_URL
586-
training_model_url = f"{hostname}/ai-models/{training_instance.model.id}"
587-
588-
message_template = (
589-
"Hi {username},\n\n"
590-
"Your training task (ID: {training_id}) of model {model_name} has {status}. You can view the details here:\n"
591-
"{training_model_url}\n\n"
592-
"Thank you for using fAIr - AI Assisted Mapping Tool.\n\n"
593-
"Best regards,\n"
594-
"The fAIr Dev Team\n\n"
595-
"Get Involved : https://www.hotosm.org/get-involved/\n"
596-
"https://github.com/hotosm/fAIr/"
597-
)
598-
599-
message = message_template.format(
600-
username=training_instance.user.username,
601-
training_id=training_instance.id,
602-
model_name=training_instance.model.name,
603-
status=status.lower(),
604-
training_model_url=training_model_url,
605-
hostname=hostname,
606-
607-
)
608-
subject = f"fAIr : Training {training_instance.id} {status.capitalize()}"
609-
return message, subject
610-
611-
612-
def send_notification(training_instance,status):
613-
if any(method in training_instance.user.notifications_delivery_methods for method in ["web", "email"]):
614-
UserNotification.objects.create(user=training_instance.user, message=f"Training {training_instance.id} has {status}.")
615-
if "email" in training_instance.user.notifications_delivery_methods:
616-
if training_instance.user.email and training_instance.user.email != '':
617-
message,subject=get_email_message(training_instance,status)
618-
send_mail(
619-
subject=subject,
620-
message=message,
621-
from_email=settings.DEFAULT_FROM_EMAIL,
622-
recipient_list=[training_instance.user.email],
623-
fail_silently=False,
624-
)

backend/core/utils.py

+52-1
Original file line numberDiff line numberDiff line change
@@ -15,11 +15,12 @@
1515
import requests
1616
from botocore.exceptions import ClientError, NoCredentialsError
1717
from django.conf import settings
18+
from django.core.mail import send_mail
1819
from django.http import HttpResponseRedirect
1920
from gpxpy.gpx import GPX, GPXTrack, GPXTrackSegment, GPXWaypoint
2021
from tqdm import tqdm
2122

22-
from .models import AOI, FeedbackAOI, FeedbackLabel, Label
23+
from .models import AOI, FeedbackAOI, FeedbackLabel, Label, UserNotification
2324
from .serializers import FeedbackLabelSerializer, LabelSerializer
2425

2526

@@ -450,3 +451,53 @@ def _upload_directory(self, directory_path, bucket_name):
450451
"total_files_uploaded": total_files,
451452
"s3_path": f"s3://{bucket_name}/{self.parent}/",
452453
}
454+
455+
456+
def get_email_message(training_instance, status):
457+
458+
hostname = settings.FRONTEND_URL
459+
training_model_url = f"{hostname}/ai-models/{training_instance.model.id}"
460+
461+
message_template = (
462+
"Hi {username},\n\n"
463+
"Your training task (ID: {training_id}) of model {model_name} has {status}. You can view the details here:\n"
464+
"{training_model_url}\n\n"
465+
"Thank you for using fAIr - AI Assisted Mapping Tool.\n\n"
466+
"Best regards,\n"
467+
"The fAIr Dev Team\n\n"
468+
"Get Involved : https://www.hotosm.org/get-involved/\n"
469+
"https://github.com/hotosm/fAIr/"
470+
)
471+
472+
message = message_template.format(
473+
username=training_instance.user.username,
474+
training_id=training_instance.id,
475+
model_name=training_instance.model.name,
476+
status=status.lower(),
477+
training_model_url=training_model_url,
478+
hostname=hostname,
479+
)
480+
subject = f"fAIr : Training {training_instance.id} {status.capitalize()}"
481+
return message, subject
482+
483+
484+
def send_notification(training_instance, status):
485+
if any(
486+
method in training_instance.user.notifications_delivery_methods
487+
for method in ["web", "email"]
488+
):
489+
UserNotification.objects.create(
490+
user=training_instance.user,
491+
message=f"Training {training_instance.id} has {status}.",
492+
training=training_instance,
493+
)
494+
if "email" in training_instance.user.notifications_delivery_methods:
495+
if training_instance.user.email and training_instance.user.email != "":
496+
message, subject = get_email_message(training_instance, status)
497+
send_mail(
498+
subject=subject,
499+
message=message,
500+
from_email=settings.DEFAULT_FROM_EMAIL,
501+
recipient_list=[training_instance.user.email],
502+
fail_silently=False,
503+
)

backend/core/views.py

+57-18
Original file line numberDiff line numberDiff line change
@@ -78,6 +78,7 @@
7878
process_rawdata,
7979
request_rawdata,
8080
s3_object_exists,
81+
send_notification,
8182
)
8283

8384
if settings.ENABLE_PREDICTION_API:
@@ -983,13 +984,14 @@ class UserNotificationViewSet(ReadOnlyModelViewSet):
983984
permission_classes = [IsOsmAuthenticated]
984985
serializer_class = UserNotificationSerializer
985986
filter_backends = [DjangoFilterBackend, filters.OrderingFilter]
986-
filterset_fields = ['is_read']
987-
ordering = ['-created_at']
988-
ordering_fields = ['created_at', 'read_at','is_read']
987+
filterset_fields = ["is_read"]
988+
ordering = ["-created_at"]
989+
ordering_fields = ["created_at", "read_at", "is_read"]
989990

990991
def get_queryset(self):
991992
return UserNotification.objects.filter(user=self.request.user)
992993

994+
993995
class MarkNotificationAsRead(APIView):
994996
authentication_classes = [OsmAuthentication]
995997
permission_classes = [IsOsmAuthenticated]
@@ -999,19 +1001,29 @@ class MarkNotificationAsRead(APIView):
9991001
)
10001002
def post(self, request, notification_id, format=None):
10011003
try:
1002-
notification = UserNotification.objects.get(id=notification_id, user=request.user)
1004+
notification = UserNotification.objects.get(
1005+
id=notification_id, user=request.user
1006+
)
10031007

10041008
if notification.is_read:
1005-
return Response({"detail": "Notification is already marked as read."}, status=status.HTTP_200_OK)
1009+
return Response(
1010+
{"detail": "Notification is already marked as read."},
1011+
status=status.HTTP_200_OK,
1012+
)
10061013

10071014
notification.is_read = True
10081015
notification.read_at = timezone.now()
10091016
notification.save()
10101017

1011-
return Response({"detail": "Notification marked as read."}, status=status.HTTP_200_OK)
1018+
return Response(
1019+
{"detail": "Notification marked as read."}, status=status.HTTP_200_OK
1020+
)
10121021

10131022
except UserNotification.DoesNotExist:
1014-
return Response({"detail": "Notification not found."}, status=status.HTTP_404_NOT_FOUND)
1023+
return Response(
1024+
{"detail": "Notification not found."}, status=status.HTTP_404_NOT_FOUND
1025+
)
1026+
10151027

10161028
class MarkAllNotificationsAsRead(APIView):
10171029
authentication_classes = [OsmAuthentication]
@@ -1023,20 +1035,29 @@ class MarkAllNotificationsAsRead(APIView):
10231035
200: openapi.Response(
10241036
description="All unread notifications marked as read.",
10251037
examples={
1026-
"application/json": {"detail": "All unread notifications marked as read."}
1038+
"application/json": {
1039+
"detail": "All unread notifications marked as read."
1040+
}
10271041
},
10281042
),
10291043
},
10301044
)
10311045
def post(self, request, format=None):
1032-
unread_notifications = UserNotification.objects.filter(user=request.user, is_read=False)
1046+
unread_notifications = UserNotification.objects.filter(
1047+
user=request.user, is_read=False
1048+
)
10331049

10341050
if not unread_notifications.exists():
1035-
return Response({"detail": "No unread notifications found."}, status=status.HTTP_404_NOT_FOUND)
1051+
return Response(
1052+
{"detail": "No unread notifications found."},
1053+
status=status.HTTP_404_NOT_FOUND,
1054+
)
10361055

10371056
unread_notifications.update(is_read=True, read_at=timezone.now())
1038-
return Response({"detail": "All unread notifications marked as read."}, status=status.HTTP_200_OK)
1039-
1057+
return Response(
1058+
{"detail": "All unread notifications marked as read."},
1059+
status=status.HTTP_200_OK,
1060+
)
10401061

10411062

10421063
class TerminateTrainingView(APIView):
@@ -1049,17 +1070,35 @@ def post(self, request, training_id, format=None):
10491070

10501071
task_id = training_instance.task_id
10511072
if not task_id:
1052-
return Response({"detail": "No task associated with this training."}, status=status.HTTP_400_BAD_REQUEST)
1073+
return Response(
1074+
{"detail": "No task associated with this training."},
1075+
status=status.HTTP_400_BAD_REQUEST,
1076+
)
10531077

1054-
task = AsyncResult(task_id,app=current_app)
1055-
if task.state in ["PENDING", "STARTED", "RETRY"]:
1078+
task = AsyncResult(task_id, app=current_app)
1079+
if (
1080+
task.state in ["PENDING", "STARTED", "RETRY", "FAILURE"]
1081+
and training_instance.status != "FAILED"
1082+
):
10561083
current_app.control.revoke(task_id, terminate=True)
10571084
training_instance.status = "FAILED"
10581085
training_instance.finished_at = now()
10591086
training_instance.save()
1060-
return Response({"detail": "Training task cancelled successfully."}, status=status.HTTP_200_OK)
1087+
send_notification(training_instance, "Cancelled")
1088+
return Response(
1089+
{"detail": "Training task cancelled successfully."},
1090+
status=status.HTTP_200_OK,
1091+
)
10611092
else:
1062-
return Response({"detail": f"Task cannot be cancelled. Current state: {task.state}"}, status=status.HTTP_400_BAD_REQUEST)
1093+
return Response(
1094+
{
1095+
"detail": f"Task cannot be cancelled. Current state: {task.state}"
1096+
},
1097+
status=status.HTTP_400_BAD_REQUEST,
1098+
)
10631099

10641100
except Training.DoesNotExist:
1065-
return Response({"detail": "Training not found or do not belong to you"}, status=status.HTTP_404_NOT_FOUND)
1101+
return Response(
1102+
{"detail": "Training not found or do not belong to you"},
1103+
status=status.HTTP_404_NOT_FOUND,
1104+
)

0 commit comments

Comments
 (0)