Tomislav's blog

Better living through optimized Django

Every engineer that loves Django and has a blog has at least one of these posts.

Django's ORM is excellent, but given enough time it's easy for approaches that weren't mistakes to grow into mistakes. This is a great thing, because it usually means your company didn't go bankrupt, you're still here and can fix things, and the company is doing well because the scale increased (hopefully your compensation has as well).

This is a recap of my recent experience optimizing Celery tasks that started out non-problematic but, with the passage of time, became problematic, causing server and database stability issues.

prefetch_related + iterator() problem

Up until Django 4.1, calling iterator() on a queryset with prefetch_related() silently dropped the prefetched data, resulting in an N+1 queries problem.

In Django 4.1, iterator() started accepting a chunk_size argument that gives us the best of both worlds -- we avoid pulling the entire dataset into memory while also avoiding the N+1 queries problem. Or at least it turns the N+1 into N / chunk_size + 1, which is considerably better. But chunk_size defaults to None, which reverts to the old behavior and silently discards the prefetched data.
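
For illustration, on Django 4.1+ something like this keeps the prefetched data while streaming the queryset in chunks:

profiles = Profile.objects.prefetch_related('user__groups')

# with chunk_size set, Django runs the prefetch queries once per chunk of
# 1000 profiles instead of silently dropping the prefetch
for profile in profiles.iterator(chunk_size=1000):
    group_names = [g.name for g in profile.user.groups.all()]  # served from the prefetch cache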

The pseudocode from the problematic task looked something like this:

profiles = Profile.objects.filter(
    community=community
).select_related(
    # ...a bunch of joins here
).prefetch_related(
    'user__groups',
    # ...a bunch of other M2Ms
)

BATCH_SIZE = 1000
records = []
for profile in profiles.iterator():  # no chunk_size: the prefetched data is silently dropped
    group_names = [g.name for g in profile.user.groups.all()]  # N+1: one query per profile
    primary_group = determine_primary_group(group_names)
    denorm_profile = DenormProfileRecord(
        user_id=profile.user_id,
        primary_group=primary_group,
        groups=group_names,
    )
    records.append(denorm_profile)
    if len(records) == BATCH_SIZE:
        DenormProfileRecord.objects.bulk_create(records)
        records = []
DenormProfileRecord.objects.bulk_create(records)

This task took hours and used up a lot of database resources. When it was originally written, there were not that many profiles and not many other tasks demanding database resources, so it worked well. But we were lucky: people registered in increasing numbers, other business processes took their own chunk of the database's resources, and this really became a bottleneck.

My first optimization attempt was:

  1. Increase the batch size to 10k, to reduce the number of bulk_create() calls
  2. Use only() on the Profile queryset to avoid fetching unnecessary data
  3. Since I believed only() reduced the memory requirements enough, drop the iterator() call to let prefetch_related() do its thing (sketched below)
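
In code, the first attempt amounted to roughly this (a sketch, not the exact diff):

BATCH_SIZE = 10000  # fewer bulk_create() round trips

profiles = Profile.objects.filter(
    community=community
).select_related(
    # ...a bunch of joins here
).prefetch_related(
    'user__groups',
    # ...a bunch of other M2Ms
).only(
    # ...just the fields used below
)

# no iterator() call: prefetch_related() now works, but every profile in the
# community is materialized as a model instance in memory at the same time
for profile in profiles:
    ...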

This optimization turned out to be ... less than optimal.

One week and one server and database outage later, I was forced to revisit my optimization.

My first failing was not examining the entire context in which this code runs. It's part of a Celery task executed by a Celery worker with --concurrency=4, meaning it's possible that we try to refresh 4 big communities at the same time.

Second, I failed to account for some communities having hundreds of thousands of profiles. Removing the iterator() call means all of those profiles are loaded into memory at once.

Third, I underestimated the difference in memory consumption between Python model instances (which are still constructed when you use only()) and Python built-in types.

The second optimization attempt was:

  1. We don't really need model instances from the prefetched relations, just certain values -- so we can get much better performance by using the PostgreSQL-specific ArrayAgg
  2. only() only marginally reduced the memory footprint, and since we don't actually need the Profile model instances either, we can get a huge benefit from enumerating all required fields in a values_list() call and avoid constructing the model instances completely

The final version looked something like:

from django.contrib.postgres.aggregates import ArrayAgg
from django.db.models import Q

profiles = Profile.objects.filter(
    community=community
).annotate(
    group_names=ArrayAgg(
        'user__groups__name',
        filter=Q(user__groups__isnull=False),
        distinct=True,
        default=[],
    ),
    # ...other values that can be aggregated as well
).values_list(
    'user_id',
    'group_names',
    # ...all other fields that were necessary for DenormProfileRecord
    named=True,
)

BATCH_SIZE = 10000
records = []
for profile in profiles.iterator():
    primary_group = determine_primary_group(profile.group_names)
    denorm_profile = DenormProfileRecord(
        user_id=profile.user_id,
        primary_group=primary_group,
        groups=profile.group_names,
    )
    records.append(denorm_profile)
    if len(records) == BATCH_SIZE:
        DenormProfileRecord.objects.bulk_create(records)
        records = []
DenormProfileRecord.objects.bulk_create(records)

This way we could keep the iterator() call since there were no prefetch_related() calls left. The values_list() optimization wasn't strictly necessary because only a small chunk of Profile rows is in memory at any given time, but I kept it just in case.

This reduced the memory strain on the server from "fills up the RAM and swap and causes OOM killer to go on a rampage" to "unnoticeable". The runtime dropped from several hours to ~30s.

Optimizing Redis access

This one isn't really Django ORM related, but it was done in the same batch of optimizations, so I'll touch on it. We utilize Redis to keep a sorted set of video name prefixes, allowing live autocomplete while typing in the search box on the site.

Populating this Redis sorted set is done on a daily basis: it's completely dropped and recreated from scratch.

The function looks something like:

from redis.client import Redis

AUTOCOMPLETE_REDIS_KEY = 'autocomplete'
MAX_PREFIX_LENGTH = 8


def update_autocomplete_prefixes():
    redis = Redis.from_url(REDIS_CACHE_LOCATION)
    redis.delete(AUTOCOMPLETE_REDIS_KEY)

    for video_title in Video.objects.values_list('title', flat=True):
        title = video_title.lower().strip()
        for i in range(1, MAX_PREFIX_LENGTH + 1):
            # one network round trip per prefix, per video
            redis.zadd(AUTOCOMPLETE_REDIS_KEY, {title[0:i]: 0})

Upon investigating the redis library we use, it turns out each zadd() call is a single network request. As the number of videos grew, the number of network requests grew as well, until this task took about 15 minutes to complete.

The optimization approach here was to collect all Redis updates in a single dictionary and push it in a single network call. This approach also allowed moving the delete call much closer to the single zadd() call, reducing the window during which the autocomplete prefixes were only partially available.

One small database-related improvement was pushing the strip() and lower() calls to the database by utilizing the Trim() and Lower() database functions.

The rewritten task looks something like:

from django.db.models.functions import Lower, Trim
from redis.client import Redis

AUTOCOMPLETE_REDIS_KEY = 'autocomplete'
MAX_PREFIX_LENGTH = 8


def update_autocomplete_prefixes():
    prefixes = {}
    titles = Video.objects.annotate(
        clean_title=Trim(Lower('title'))
    ).values_list('clean_title', flat=True)
    for video_title in titles:
        prefixes |= {video_title[0:i]: 0 for i in range(1, MAX_PREFIX_LENGTH + 1)}

    redis = Redis.from_url(REDIS_CACHE_LOCATION)
    redis.delete(AUTOCOMPLETE_REDIS_KEY)
    # a single network round trip with the whole mapping
    redis.zadd(AUTOCOMPLETE_REDIS_KEY, prefixes)

This reduced the runtime of the task from 15 minutes to ~5 seconds.

I also entertained the idea of using memoryviews to avoid constructing new string objects for each prefix, but there were risks associated with handling unicode characters (which were present in the video titles, and memoryviews operate on bytes), and I'm not really familiar with how the redis Python library handles the passed data (it could quite easily cast these memoryviews back to strings, annulling any gains).

Optimizing deletion of old records

For debugging purposes, we retain a copy of each email sent to our users. Since we like to keep things simple, this data is kept in a table in PostgreSQL, and old records are purged from the table daily. The retention policy is 2 weeks, so every day a Celery task identifies old records and deletes them:

def remove_old_emails():
    old_mailer_ids = list(
        Mailer.objects.filter(
            sent__lte=timezone.now() - relativedelta(weeks=2)
        ).values_list('id', flat=True)
    )
    old_emails = Email.objects.filter(mailer_id__in=old_mailer_ids)
    old_emails._raw_delete(old_emails.db)

This regularly took 20-30 minutes, even with the _raw_delete() optimization. The table is not humongous; it definitely should not take that long.

For historical reasons (when the table was humongous and a join would kill the database) the table doesn't have any foreign key constraints, and all other tables are referenced through soft foreign keys (e.g. mailer_id is an integer column with an index, without a foreign key constraint). In the meantime we introduced a data retention policy to manage the table's size.
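
For illustration (field names simplified, not the actual model), the relevant part of the Email model looks something like this:

from django.db import models


class Email(models.Model):
    # a "soft" foreign key: just an indexed integer column, with no foreign
    # key constraint enforced by the database
    mailer_id = models.IntegerField(db_index=True)
    ...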

The problematic part quickly emerged upon inspecting the SQL query: the list of old mailer IDs has 100k members and is growing daily, since it's a list of all mailers ever sent. This makes Postgres' life hard, and every run of the query degrades to a full table sequential scan.

The solution is clear: reduce the list of mailer IDs to something manageable. Since the task runs daily, it's safe to reduce the list to mailers that were sent between 2 weeks ago and 2 weeks and 1 day ago. We wanted some redundancy, so we increased that range to 2 weeks and 3 days ago, in case something prevents the task from running for a day or two.
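
A sketch of the adjusted task, reusing the models from the snippet above:

def remove_old_emails():
    cutoff = timezone.now() - relativedelta(weeks=2)
    old_mailer_ids = list(
        Mailer.objects.filter(
            # only mailers that recently crossed the retention boundary,
            # with a 3-day safety margin in case the task skips a run or two
            sent__range=(cutoff - relativedelta(days=3), cutoff),
        ).values_list('id', flat=True)
    )
    old_emails = Email.objects.filter(mailer_id__in=old_mailer_ids)
    old_emails._raw_delete(old_emails.db)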

Postgres started using an index scan instead of a sequential scan, and things sped up drastically -- the runtime dropped from 20-30 minutes to 3-5 minutes.

Optimizing creation of new many-to-many records

Using Model.objects.create() in a for loop is a surefire way to degrade the performance of your database -- every create() call is a network request to the database with an INSERT command.

for_date = date(2024, 8, 24)
impression_data = Impression.objects.filter(
    created__date=for_date
).values(
    'content_type_id', 'object_id'
).annotate(total_impressions=Count('id'))
for content_impression in impression_data:
    # one INSERT and one network round trip per row
    DailyImpressionStats.objects.create(**content_impression)

One of the basic ways to improve performance, if memory allows it, is to accumulate unsaved model instances in memory and then create them all at once using a bulk_create() call:

for_date = date(2024, 8, 24)
impression_data = Impression.objects.filter(
    created__date=for_date
).values(
    'content_type_id', 'object_id'
).annotate(total_impressions=Count('id'))
daily_impressions_list = []
for content_impression in impression_data:
    daily_impressions_list.append(DailyImpressionStats(**content_impression))

DailyImpressionStats.objects.bulk_create(daily_impressions_list)

This is a fairly common optimization, but here's a follow-up problem: what if we also have to set a many-to-many relationship on the model we want to bulk create? At first, it seems like we cannot use the bulk_create() approach anymore:

for content_impression in impression_data:
    daily_impression = DailyImpressionStats(**content_impression)
    tags = get_tags_for_content(
        content_type=content_impression['content_type_id'], 
        object_id=content_impression['object_id'],
    )
    daily_impression.tags.set(tags)  # ERROR: daily_impression has to be saved before we can set the M2M relationship

Luckily, digging a bit deeper into how Django implements the many-to-many relationship offers an answer. When we define a many-to-many relationship between models, Django creates an intermediary table with an accompanying through model, accessible via Model.m2m_field.through. This allows us to accumulate the through model instances in another list and bulk create them as well.

If the id field is declared as an AutoField, PostgreSQL, MariaDB and SQLite set it on the model instances when using bulk_create(). Our many-to-many records reference these instances, so we can first bulk create the model instances and then bulk create the many-to-many instances:

for_date = date(2024, 8, 24)
impression_data = Impression.objects.filter(
    created__date=for_date
).values(
    'content_type_id', 'object_id'
).annotate(total_impressions=Count('id'))
daily_impressions_list = []
daily_impressions_tag_list = []
for content_impression in impression_data:
    daily_impression = DailyImpressionStats(**content_impression)
    daily_impressions_list.append(daily_impression)
    tags = get_tags_for_content(
        content_type=content_impression['content_type_id'],
        object_id=content_impression['object_id'],
    )
    daily_impressions_tag_list.extend(
        # the auto-created through model's FK fields are named after the
        # lowercased model names
        DailyImpressionStats.tags.through(
            dailyimpressionstats=daily_impression, tag=tag
        )
        for tag in tags
    )

# the parent rows get their primary keys assigned here, so the through rows
# built above can reference them
DailyImpressionStats.objects.bulk_create(daily_impressions_list)
DailyImpressionStats.tags.through.objects.bulk_create(daily_impressions_tag_list)

Epilogue

It's easy to dismiss the problems with the original code snippets as "skill issues", and sometimes they really are. But we need to keep in mind that, equally often, code starts out performant and ends up a bottleneck. Conditions change, scale increases, the tech stack evolves. If you tried to do some of the described optimizations in the initial code push, I would probably be the first one to invoke YAGNI and ask for a simplification. Business first, tech second -- every minute you spend optimizing a query is a minute in which the business might not make enough money to live to see the day your optimization pays off.

It's not important to write optimal code on the first go; it's important to be able to write it once it becomes problematic, and in the meantime, build things that matter.