import json
import logging

import smart_open

from dateutil import parser
from django.conf import settings
from redis import Redis
from rq import Queue, Worker
from rq.job import Job

from datahub.core.queues.job_scheduler import job_scheduler
from datahub.ingest.boto3 import S3ObjectProcessor
from datahub.ingest.models import IngestedObject


logger = logging.getLogger(__name__)


class QueueChecker:
| 21 | + """Checks the Redis Queue for specific ingestion task.""" |

    def __init__(self, queue_name: str) -> None:
        self.redis = Redis.from_url(settings.REDIS_BASE_URL)
        self.queue = Queue(queue_name, connection=self.redis)

    def match_job(self, job: Job, ingestion_task_function: callable, object_key: str) -> bool:
        """Determines if the job matches the ingestion task name and object key."""
        function_name = f'{ingestion_task_function.__module__}.{ingestion_task_function.__name__}'
        return job.func_name == function_name and job.kwargs.get('object_key') == object_key

    def is_job_queued(self, ingestion_task_function: callable, object_key: str) -> bool:
        """Check if a job is queued or running."""
        # Check queued jobs
        if any(
            self.match_job(job, ingestion_task_function, object_key) for job in self.queue.jobs
        ):
            return True
        # Check running jobs
        for worker in Worker.all(queue=self.queue):
            job = worker.get_current_job()
            if job and self.match_job(job, ingestion_task_function, object_key):
                return True
        return False


class BaseObjectIdentificationTask:
    """Base class to identify new objects in S3 and determine if they should be ingested.

    An example of how this class should be used:
    ```
    def base_identification_task() -> None:
        '''Function to be scheduled and called by an RQ worker to identify objects to ingest.'''
        logger.info('Base identification task started...')
        identification_task = BaseObjectIdentificationTask(prefix='prefix/')
        identification_task.identify_new_objects(base_ingestion_task)
        logger.info('Base identification task finished.')
    ```
    """

    def __init__(self, prefix: str):
        self.long_queue_checker: QueueChecker = QueueChecker(queue_name='long-running')
        self.s3_processor: S3ObjectProcessor = S3ObjectProcessor(prefix=prefix)

    def identify_new_objects(self, ingestion_task_function: callable) -> None:
        """Entry point method to identify new objects and, if valid, schedule their ingestion."""
        latest_object_key = self.s3_processor.get_most_recent_object()

        if not latest_object_key:
            logger.info('No objects found')
            return

        if self.long_queue_checker.is_job_queued(
            ingestion_task_function, latest_object_key,
        ):
            logger.info(f'{latest_object_key} has already been queued for ingestion')
            return

        if self.s3_processor.has_object_been_ingested(latest_object_key):
            logger.info(f'{latest_object_key} has already been ingested')
            return

        job_scheduler(
            function=ingestion_task_function,
            function_kwargs={
                'object_key': latest_object_key,
                's3_processor': self.s3_processor,
            },
            queue_name=self.long_queue_checker.queue.name,
            description=f'Ingest {latest_object_key}',
        )
        logger.info(f'Scheduled ingestion of {latest_object_key}')


def base_ingestion_task(
    object_key: str,
    s3_processor: S3ObjectProcessor,
) -> None:
    """Function to be scheduled by the BaseObjectIdentificationTask.identify_new_objects method.

    Once executed by an RQ worker, it will ingest the specified object.

    This function serves as an example and is only used in tests of the base classes.
    It is not, and should not be, scheduled for execution in `cron-scheduler.py`.
    """
    logger.info('Base ingestion task started...')
    ingestion_task = BaseObjectIngestionTask(
        object_key=object_key,
        s3_processor=s3_processor,
    )
    ingestion_task.ingest_object()
    logger.info('Base ingestion task finished.')


class BaseObjectIngestionTask:
    """Base class to ingest a specified object from S3."""

    def __init__(
        self,
        object_key: str,
        s3_processor: S3ObjectProcessor,
    ) -> None:
        self.object_key = object_key
        self.s3_processor = s3_processor
        self.last_ingestion_datetime = self.s3_processor.get_last_ingestion_datetime()
        self.skipped_counter = 0
        self.created_ids = []
        self.updated_ids = []
        self.errors = []

    def ingest_object(self) -> None:
        """Process all records in the object key specified when the class instance was created."""
        try:
            with smart_open.open(
                f's3://{self.s3_processor.bucket}/{self.object_key}',
            ) as s3_object:
                for line in s3_object:
                    deserialized_line = json.loads(line)
                    record = self._get_record_from_line(deserialized_line)
                    if self._should_process_record(record):
                        self._process_record(record)
                    else:
                        self.skipped_counter += 1
        except Exception as e:
            logger.error(f'An error occurred trying to process {self.object_key}: {str(e)}')
            raise e

        # Record ingestion
        last_modified = self.s3_processor.get_object_last_modified_datetime(self.object_key)
        IngestedObject.objects.create(object_key=self.object_key, object_created=last_modified)
        logger.info(f'{self.object_key} ingested.')

        # Log metrics
        if self.created_ids:
            logger.info(
                f'{len(self.created_ids)} records created: {self.created_ids}',
            )
        if self.updated_ids:
            logger.info(
                f'{len(self.updated_ids)} records updated: {self.updated_ids}',
            )
        if self.errors:
            logger.warning(f'{len(self.errors)} records failed validation: {self.errors}')
        if self.skipped_counter:
            logger.info(
                f'{self.skipped_counter} records skipped.',
            )

    def _get_record_from_line(self, deserialized_line: dict) -> dict:
        """Extracts the record from the deserialized line.

        This method should be overridden if the record is nested.
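        For example, if each line nested the record under a hypothetical
        'object' key, a minimal override could look like:
        ```
        def _get_record_from_line(self, deserialized_line: dict) -> dict:
            # 'object' is an assumed key name; use whatever key the source data nests records under
            return deserialized_line['object']
        ```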
        """
        return deserialized_line

    def _should_process_record(self, record: dict) -> bool:
        """Determine if a record should be processed.

        This method uses the incoming record's last modified date to check
        whether the record should be processed. If the incoming data stores
        this date under a differently named field, please override the
        `_get_modified_datetime_str` method to specify which field should be used.

        If the record has no modified (or similar) field, please override this
        method with the desired rules for determining whether it should be processed.
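        For example, a minimal override assuming the incoming data uses a
        hypothetical 'last_updated' field instead of 'modified':
        ```
        def _get_modified_datetime_str(self, record: dict) -> str:
            # 'last_updated' is an assumed field name for illustration only
            return record['last_updated']
        ```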
        """
        if self.last_ingestion_datetime is None:
            return True
        try:
            modified_datetime_str = self._get_modified_datetime_str(record)
            modified_datetime = parser.parse(modified_datetime_str)
        except ValueError as e:
            logger.error(
                f'An error occurred determining the last modified datetime: {str(e)}',
            )
            # If unable to parse datetime string, assume record should be processed.
            return True
        return modified_datetime.timestamp() >= self.last_ingestion_datetime.timestamp()

    def _get_modified_datetime_str(self, record: dict) -> str:
        """Gets the last modified datetime string from the incoming record."""
        return record['modified']

    def _process_record(self, record: dict) -> None:
        """Processes a single record.

        This method should take a single record, update an existing instance or create a new one,
        and return None.

        Depending on preference, you can use a DRF serializer or a dictionary of mappings.
        Similarly, you can append information to the created, updated, and errors lists for logging.

        See below for an example using a DRF serializer and logging metrics:
        ```
        serializer = SerializerClass(data=record)
        if serializer.is_valid():
            # pop the `id` field from validated data so that it does not attempt to update it
            primary_key = UUID(serializer.validated_data.pop('id'))
            queryset = ModelClass.objects.filter(pk=primary_key)
            instance, created = queryset.update_or_create(
                pk=primary_key,
                defaults=serializer.validated_data,
            )
            if created:
                self.created_ids.append(str(instance.id))
            else:
                self.updated_ids.append(str(instance.id))
        else:
            self.errors.append({
                'record': record,
                'errors': serializer.errors,
            })
        ```
        """
        raise NotImplementedError(
            'Please override the _process_record method and tailor it to your use case.',
        )