from logging import getLogger

from django.apps import apps
from django.db.models import Exists, NOT_PROVIDED, OuterRef, Subquery
from django.db.transaction import atomic
from django_pglocks import advisory_lock

from datahub.company.models import Company, CompanyExportCountry
from datahub.core.queues.job_scheduler import job_scheduler

logger = getLogger(__name__)


def replace_null_with_default(model_label, field_name, default=None, batch_size=5000):
    """
    Task that replaces NULL values for a model field with the default argument if specified,
    or the field's default value otherwise.

    This is designed to perform updates in small batches to avoid lengthy locks on a large
    number of rows.
    """
    model = apps.get_model(model_label)
    field = model._meta.get_field(field_name)

    resolved_default = default  # so that the input argument is not changed

    if resolved_default is None:
        if field.default in (NOT_PROVIDED, None):
            raise ValueError(f'{field_name} does not have a non-null default value')
        resolved_default = field.default

    if callable(resolved_default):
        raise ValueError(f'callable defaults for {field_name} are not supported')

    if not field.null:
        raise ValueError(f'{field_name} is not nullable')

    # Unevaluated subquery that selects a batch of rows
    subquery = model.objects.filter(
        **{field_name: None},
    ).values(
        'pk',
    )[:batch_size]

    # Update the batch of rows to use the default value instead
    num_updated = model.objects.filter(
        pk__in=Subquery(subquery),
    ).update(
        **{field_name: resolved_default},
    )

    logger.info(
        f'NULL replaced with {resolved_default!r} for {num_updated} objects, model {model_label}, '
        f'field {field_name}',
    )

    # If there are definitely no more rows needing updating, return
    if num_updated < batch_size:
        return

    # Schedule another task to update another batch of rows
    job = job_scheduler(
        function=replace_null_with_default,
        function_args=(model_label, field_name),
        function_kwargs={'default': default, 'batch_size': batch_size},
    )
    logger.info(f'Task {job.id} replace_null_with_default')
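
# A minimal usage sketch (not part of the original module) of how this task might be queued
# via job_scheduler, mirroring the recursive call above. The model label and field name
# ('company.Company', 'export_potential') are hypothetical examples only.
#
#     job_scheduler(
#         function=replace_null_with_default,
#         function_args=('company.Company', 'export_potential'),
#         function_kwargs={'default': None, 'batch_size': 1000},
#     )
#
# Passing default=None makes the task fall back to the field's own (non-callable) default,
# as handled at the top of the function.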


def copy_foreign_key_to_m2m_field(
    model_label,
    source_fk_field_name,
    target_m2m_field_name,
    batch_size=5000,
):
    """
    Task that copies non-null values from a foreign key to a to-many field (for objects where
    the to-many field is empty).

    Usage example:

        copy_foreign_key_to_m2m_field('interaction.Interaction', 'contact', 'contacts')

    Note: This does not create reversion revisions on the model referenced by model_label. For
    new fields, the new versions would simply show the new field being added, so would not be
    particularly useful. If you do need revisions to be created, this task is not suitable.
    """
    lock_name = (
        f'leeloo-copy_foreign_key_to_m2m_field-{model_label}-{source_fk_field_name}'
        f'-{target_m2m_field_name}'
    )

    with advisory_lock(lock_name, wait=False) as lock_held:
        if not lock_held:
            logger.warning(
                f'Another copy_foreign_key_to_m2m_field task is in progress for '
                f'({model_label}, {source_fk_field_name}, {target_m2m_field_name}). Aborting...',
            )
            return

        num_processed = _copy_foreign_key_to_m2m_field(
            model_label,
            source_fk_field_name,
            target_m2m_field_name,
            batch_size=batch_size,
        )

        # If there are definitely no more rows needing processing, return
        if num_processed < batch_size:
            return

    # Schedule another task to update another batch of rows.
    #
    # This must be outside of the atomic block, otherwise it will probably run before the
    # current changes have been committed.
    #
    # (Similarly, the lock should also be released before the next task is scheduled.)
    job = job_scheduler(
        function=copy_foreign_key_to_m2m_field,
        function_args=(model_label, source_fk_field_name, target_m2m_field_name),
        function_kwargs={'batch_size': batch_size},
    )
    logger.info(f'Task {job.id} copy_foreign_key_to_m2m_field')
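
# A minimal sketch (not from the original module) of kicking off the copy for the example
# given in the docstring above. Only the first batch needs to be scheduled explicitly; the
# task re-schedules itself until a batch comes back smaller than batch_size.
#
#     job_scheduler(
#         function=copy_foreign_key_to_m2m_field,
#         function_args=('interaction.Interaction', 'contact', 'contacts'),
#         function_kwargs={'batch_size': 5000},
#     )
#
# The advisory lock above means a second run for the same (model, source, target) combination
# while one is in progress simply logs a warning and aborts.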


@atomic
def _copy_foreign_key_to_m2m_field(
    model_label,
    source_fk_field_name,
    target_m2m_field_name,
    batch_size=5000,
):
    """
    The main logic for the copy_foreign_key_to_m2m_field task.

    Processes a single batch in a transaction.
    """
    model = apps.get_model(model_label)
    source_fk_field = model._meta.get_field(source_fk_field_name)
    target_m2m_field = model._meta.get_field(target_m2m_field_name)
    m2m_model = target_m2m_field.remote_field.through

    # e.g. 'interaction_id' for Interaction.contacts
    m2m_column_name = target_m2m_field.m2m_column_name()
    # e.g. 'contact_id' for Interaction.contacts
    m2m_reverse_column_name = target_m2m_field.m2m_reverse_name()

    has_no_m2m_values_subquery = ~Exists(
        m2m_model.objects.filter(**{m2m_column_name: OuterRef('pk')}),
    )

    # Select a batch of rows. The rows are locked to avoid race conditions.
    batch_queryset = model.objects.select_for_update().filter(
        has_no_m2m_values_subquery,
        **{
            f'{source_fk_field_name}__isnull': False,
        },
    ).values(
        'pk',
        source_fk_field.attname,
    )[:batch_size]

    objects_to_create = [
        m2m_model(
            **{
                m2m_column_name: row['pk'],
                m2m_reverse_column_name: row[source_fk_field.attname],
            },
        ) for row in batch_queryset
    ]

    # Create many-to-many objects for the batch
    created_objects = m2m_model.objects.bulk_create(objects_to_create)
    num_created = len(created_objects)

    logger.info(
        f'{num_created} {model_label}.{target_m2m_field_name} many-to-many objects created',
    )

    return len(objects_to_create)
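
# A minimal sketch (hypothetical values, not from the original module) of what a single batch
# does for the Interaction.contacts example: each selected interaction row produces one
# through-model row copying its contact_id, and the return value tells the public task
# whether another batch should be scheduled.
#
#     processed = _copy_foreign_key_to_m2m_field(
#         'interaction.Interaction',
#         'contact',
#         'contacts',
#         batch_size=1000,
#     )
#     if processed < 1000:
#         pass  # nothing left to copy; the public task stops re-scheduling at this point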


def copy_export_countries_to_company_export_country_model(
    status,
    batch_size=5000,
):
    """
    Task that copies all export countries from the Company model to the CompanyExportCountry
    model for the given status.
    """
    key_switch = {
        'future_interest': 'future_interest_countries',
        'currently_exporting': 'export_to_countries',
    }
    num_updated = _copy_export_countries(key_switch[status], status, batch_size)

    # If there are definitely no more rows needing processing, return
    if num_updated < batch_size:
        return

    job = job_scheduler(
        function=copy_export_countries_to_company_export_country_model,
        function_kwargs={
            'batch_size': batch_size,
            'status': status,
        },
    )
    logger.info(f'Task {job.id} copy_export_countries_to_company_export_country_model')
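
# A minimal sketch (not from the original module) of scheduling both copies. The accepted
# status values are the keys of key_switch above, each mapping to a legacy Company field.
#
#     for status in ('future_interest', 'currently_exporting'):
#         job_scheduler(
#             function=copy_export_countries_to_company_export_country_model,
#             function_kwargs={'status': status, 'batch_size': 5000},
#         )
#
# Any other status raises a KeyError when key_switch[status] is looked up.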


@atomic
def _copy_export_countries(key, status, batch_size):
    """
    Main logic for copying export countries from the Company model to the
    CompanyExportCountry model.

    Processes a single batch of companies in a transaction.
    """
    export_countries = _get_company_countries(key, status, batch_size)
    num_updated = _copy_company_countries(
        key,
        export_countries,
        status,
    )
    logger.info(
        f'Company.{key} copied to CompanyExportCountry '
        f'for {num_updated} Company export countries',
    )
    return num_updated


def _get_company_countries(source_field, status, batch_size):
    """
    Selects (and locks) a batch of companies that have countries in the legacy source_field
    but no CompanyExportCountry rows with the given status.
    """
    no_company_country_subquery = ~Exists(
        CompanyExportCountry.objects.filter(
            company_id=OuterRef('pk'),
            status=status,
        ),
    )
    has_existing_old_countries = Exists(
        Company.objects.filter(
            **{
                'pk': OuterRef('pk'),
                f'{source_field}__isnull': False,
            },
        ),
    )
    batch_queryset = Company.objects.select_for_update().filter(
        no_company_country_subquery,
        has_existing_old_countries,
    ).only(
        'pk',
    )
    return batch_queryset[:batch_size]


def _copy_company_countries(source_field, company_with_uncopied_countries, status):
    """
    Creates (or updates) a CompanyExportCountry row for each country in the legacy
    source_field of each company in the batch, and returns the number of companies processed.
    """
    company_export_country_model = apps.get_model('company', 'CompanyExportCountry')
    num_updated = 0

    for company in company_with_uncopied_countries:
        num_updated += 1
        for country in getattr(company, source_field).all():
            export_country, created = company_export_country_model.objects.get_or_create(
                company=company,
                country=country,
                defaults={
                    'status': status,
                },
            )
            if not created and export_country.status != status:
                export_country.status = status
                export_country.save()

    return num_updated