Commit bcc8713

Add base identification and ingestion tasks
1 parent 79d9986 commit bcc8713

2 files changed (+513 -0 lines)


datahub/ingest/tasks.py

+237
@@ -0,0 +1,237 @@
import json
import logging

import smart_open

from dateutil import parser
from django.conf import settings
from redis import Redis
from rq import Queue, Worker
from rq.job import Job

from datahub.core.queues.job_scheduler import job_scheduler
from datahub.ingest.boto3 import S3ObjectProcessor
from datahub.ingest.models import IngestedObject


logger = logging.getLogger(__name__)


class QueueChecker:
    """Checks the Redis queue for a specific ingestion task."""

    def __init__(self, queue_name: str) -> None:
        self.redis = Redis.from_url(settings.REDIS_BASE_URL)
        self.queue = Queue(queue_name, connection=self.redis)

    def match_job(self, job: Job, ingestion_task_function: callable, object_key: str) -> bool:
        """Determines if the job matches the ingestion task name and object key."""
        function_name = f'{ingestion_task_function.__module__}.{ingestion_task_function.__name__}'
        return job.func_name == function_name and job.kwargs.get('object_key') == object_key

    def is_job_queued(self, ingestion_task_function: callable, object_key: str) -> bool:
        """Check if a job is queued or running."""
        # Check queued jobs
        if any(
            self.match_job(job, ingestion_task_function, object_key) for job in self.queue.jobs
        ):
            return True
        # Check running jobs
        for worker in Worker.all(queue=self.queue):
            job = worker.get_current_job()
            if job and self.match_job(job, ingestion_task_function, object_key):
                return True
        return False

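# For example, QueueChecker('long-running').is_job_queued(base_ingestion_task, object_key)
# returns True while an ingestion job for that object key is still waiting in the queue or
# running on a worker; BaseObjectIdentificationTask uses this to avoid scheduling the same
# object twice.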

class BaseObjectIdentificationTask:
    """Base class to identify new objects in S3 and determine if they should be ingested.

    An example of how this class should be used:
    ```
    def base_identification_task() -> None:
        '''Function to be scheduled and called by an RQ worker to identify objects to ingest.'''
        logger.info('Base identification task started...')
        identification_task = BaseObjectIdentificationTask(prefix='prefix/')
        identification_task.identify_new_objects(base_ingestion_task)
        logger.info('Base identification task finished.')
    ```
    """

    def __init__(self, prefix: str):
        self.long_queue_checker: QueueChecker = QueueChecker(queue_name='long-running')
        self.s3_processor: S3ObjectProcessor = S3ObjectProcessor(prefix=prefix)

    def identify_new_objects(self, ingestion_task_function: callable) -> None:
        """Entry point method to identify new objects and, if valid, schedule their ingestion."""
        latest_object_key = self.s3_processor.get_most_recent_object()

        if not latest_object_key:
            logger.info('No objects found')
            return

        if self.long_queue_checker.is_job_queued(
            ingestion_task_function, latest_object_key,
        ):
            logger.info(f'{latest_object_key} has already been queued for ingestion')
            return

        if self.s3_processor.has_object_been_ingested(latest_object_key):
            logger.info(f'{latest_object_key} has already been ingested')
            return

        job_scheduler(
            function=ingestion_task_function,
            function_kwargs={
                'object_key': latest_object_key,
                's3_processor': self.s3_processor,
            },
            queue_name=self.long_queue_checker.queue.name,
            description=f'Ingest {latest_object_key}',
        )
        logger.info(f'Scheduled ingestion of {latest_object_key}')


def base_ingestion_task(
    object_key: str,
    s3_processor: S3ObjectProcessor,
) -> None:
    """Function to be scheduled by the BaseObjectIdentificationTask.identify_new_objects method.

    Once executed by an RQ worker, it will ingest the specified object.

    This function serves as an example and is only used in tests of the base classes.
    It is not, and should not be, scheduled for execution in `cron-scheduler.py`.
    """
    logger.info('Base ingestion task started...')
    ingestion_task = BaseObjectIngestionTask(
        object_key=object_key,
        s3_processor=s3_processor,
    )
    ingestion_task.ingest_object()
    logger.info('Base ingestion task finished.')


class BaseObjectIngestionTask:
    """Base class to ingest a specified object from S3."""

    def __init__(
        self,
        object_key: str,
        s3_processor: S3ObjectProcessor,
    ) -> None:
        self.object_key = object_key
        self.s3_processor = s3_processor
        self.last_ingestion_datetime = self.s3_processor.get_last_ingestion_datetime()
        self.skipped_counter = 0
        self.created_ids = []
        self.updated_ids = []
        self.errors = []

    def ingest_object(self) -> None:
        """Process all records in the object key specified when the class instance was created."""
        try:
            with smart_open.open(
                f's3://{self.s3_processor.bucket}/{self.object_key}',
            ) as s3_object:
                for line in s3_object:
                    deserialized_line = json.loads(line)
                    record = self._get_record_from_line(deserialized_line)
                    if self._should_process_record(record):
                        self._process_record(record)
                    else:
                        self.skipped_counter += 1
        except Exception as e:
            logger.error(f'An error occurred trying to process {self.object_key}: {str(e)}')
            raise e

        # Record ingestion
        last_modified = self.s3_processor.get_object_last_modified_datetime(self.object_key)
        IngestedObject.objects.create(object_key=self.object_key, object_created=last_modified)
        logger.info(f'{self.object_key} ingested.')

        # Log metrics
        if self.created_ids:
            logger.info(
                f'{len(self.created_ids)} records created: {self.created_ids}',
            )
        if self.updated_ids:
            logger.info(
                f'{len(self.updated_ids)} records updated: {self.updated_ids}',
            )
        if self.errors:
            logger.warning(f'{len(self.errors)} records failed validation: {self.errors}')
        if self.skipped_counter:
            logger.info(
                f'{self.skipped_counter} records skipped.',
            )

    def _get_record_from_line(self, deserialized_line: dict) -> dict:
        """Extracts the record from the deserialized line.

        This method should be overridden if the record is nested.
        """
        return deserialized_line

    def _should_process_record(self, record: dict) -> bool:
        """Determine if a record should be processed.

        This method uses the incoming record's last modified date to check
        whether the record should be processed. If the incoming data provides
        that date under a different field name, please override the
        `_get_modified_datetime_str` method to specify which field should be used.

        If the record has no modified (or similar) field, please override this
        method to set the desired rules to determine whether it should be processed.
        """
        if self.last_ingestion_datetime is None:
            return True
        try:
            modified_datetime_str = self._get_modified_datetime_str(record)
            modified_datetime = parser.parse(modified_datetime_str)
        except ValueError as e:
            logger.error(
                f'An error occurred determining the last modified datetime: {str(e)}',
            )
            # If unable to parse the datetime string, assume the record should be processed.
            return True
        return modified_datetime.timestamp() >= self.last_ingestion_datetime.timestamp()

    def _get_modified_datetime_str(self, record: dict) -> str:
        """Gets the last modified datetime string from the incoming record."""
        return record['modified']

    def _process_record(self, record: dict) -> None:
        """Processes a single record.

        This method should take a single record, update an existing instance or create
        a new one, and return None.

        Depending on preference, you can use a DRF serializer or a dictionary of mappings.
        Similarly, you can append information to the created, updated, and errors lists for logging.

        See below for an example using a DRF serializer and logging metrics:
        ```
        serializer = SerializerClass(data=record)
        if serializer.is_valid():
            # pop the `id` field from validated data so that it does not attempt to update it
            primary_key = UUID(serializer.validated_data.pop('id'))
            queryset = ModelClass.objects.filter(pk=primary_key)
            instance, created = queryset.update_or_create(
                pk=primary_key,
                defaults=serializer.validated_data,
            )
            if created:
                self.created_ids.append(str(instance.id))
            else:
                self.updated_ids.append(str(instance.id))
        else:
            self.errors.append({
                'record': record,
                'errors': serializer.errors,
            })
        ```
        """
        raise NotImplementedError(
            'Please override the _process_record method and tailor to your use case.',
        )
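
Taken together, the hooks above describe the intended extension pattern: subclass BaseObjectIngestionTask, override the methods that do not fit your data, and pair it with an identification task that identify_new_objects can schedule. A minimal sketch of what that could look like is below; the Company model, CompanySerializer, feed layout, field names, and prefix are all hypothetical and only illustrate where each override fits.

```
# Illustrative sketch only: `Company` and `CompanySerializer` stand in for a real
# Django model and DRF serializer, and the prefix and field names are made up.
from uuid import UUID

from datahub.ingest.boto3 import S3ObjectProcessor
from datahub.ingest.tasks import BaseObjectIdentificationTask, BaseObjectIngestionTask


class CompanyIngestionTask(BaseObjectIngestionTask):
    def _get_record_from_line(self, deserialized_line: dict) -> dict:
        # This hypothetical feed nests each record under an `object` key.
        return deserialized_line['object']

    def _get_modified_datetime_str(self, record: dict) -> str:
        # The feed calls its last-modified field `last_updated` rather than `modified`.
        return record['last_updated']

    def _process_record(self, record: dict) -> None:
        serializer = CompanySerializer(data=record)
        if serializer.is_valid():
            # Pop `id` from validated data so update_or_create does not try to change it.
            primary_key = UUID(serializer.validated_data.pop('id'))
            instance, created = Company.objects.update_or_create(
                pk=primary_key,
                defaults=serializer.validated_data,
            )
            if created:
                self.created_ids.append(str(instance.id))
            else:
                self.updated_ids.append(str(instance.id))
        else:
            self.errors.append({'record': record, 'errors': serializer.errors})


def company_ingestion_task(object_key: str, s3_processor: S3ObjectProcessor) -> None:
    """Scheduled by company_identification_task; ingests a single object."""
    CompanyIngestionTask(object_key=object_key, s3_processor=s3_processor).ingest_object()


def company_identification_task() -> None:
    """Scheduled periodically (e.g. in cron-scheduler.py) to find new objects to ingest."""
    identification_task = BaseObjectIdentificationTask(prefix='hypothetical/prefix/')
    identification_task.identify_new_objects(company_ingestion_task)
```

Because identify_new_objects checks both the long-running queue (via QueueChecker) and existing IngestedObject records before calling job_scheduler, the identification task can run on a schedule without ingesting the same object twice.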
