The basic setup
Suppose you have these models:
# models.py
from django.db import models

class Category(models.Model):
    name = models.CharField(max_length=100)

class Blogpost(models.Model):
    title = models.CharField(max_length=100)
    date = models.DateTimeField()  # the views below order by this field
    categories = models.ManyToManyField(Category)
Suppose you hook these up to Django REST Framework and list all Blogpost items. Something like this:
# urls.py
from rest_framework import routers
from . import views

router = routers.DefaultRouter()
router.register(r'blogposts', views.BlogpostViewSet)
urlpatterns = router.urls

# views.py
from rest_framework import viewsets
from . import serializers
from .models import Blogpost

class BlogpostViewSet(viewsets.ModelViewSet):
    queryset = Blogpost.objects.all().order_by('date')
    serializer_class = serializers.BlogpostSerializer
What's the problem?
Then, if you request this list (e.g. curl http://localhost:8000/api/blogposts/), what happens on the database is something like this:
SELECT "app_blogpost"."id", "app_blogpost"."title" FROM "app_blogpost";
SELECT "app_category"."id", "app_category"."name" FROM "app_category" INNER JOIN "app_blogpost_categories" ON ("app_category"."id" = "app_blogpost_categories"."category_id") WHERE "app_blogpost_categories"."blogpost_id" = 1025;
SELECT "app_category"."id", "app_category"."name" FROM "app_category" INNER JOIN "app_blogpost_categories" ON ("app_category"."id" = "app_blogpost_categories"."category_id") WHERE "app_blogpost_categories"."blogpost_id" = 193;
SELECT "app_category"."id", "app_category"."name" FROM "app_category" INNER JOIN "app_blogpost_categories" ON ("app_category"."id" = "app_blogpost_categories"."category_id") WHERE "app_blogpost_categories"."blogpost_id" = 757;
SELECT "app_category"."id", "app_category"."name" FROM "app_category" INNER JOIN "app_blogpost_categories" ON ("app_category"."id" = "app_blogpost_categories"."category_id") WHERE "app_blogpost_categories"."blogpost_id" = 853;
SELECT "app_category"."id", "app_category"."name" FROM "app_category" INNER JOIN "app_blogpost_categories" ON ("app_category"."id" = "app_blogpost_categories"."category_id") WHERE "app_blogpost_categories"."blogpost_id" = 1116;
SELECT "app_category"."id", "app_category"."name" FROM "app_category" INNER JOIN "app_blogpost_categories" ON ("app_category"."id" = "app_blogpost_categories"."category_id") WHERE "app_blogpost_categories"."blogpost_id" = 1126;
SELECT "app_category"."id", "app_category"."name" FROM "app_category" INNER JOIN "app_blogpost_categories" ON ("app_category"."id" = "app_blogpost_categories"."category_id") WHERE "app_blogpost_categories"."blogpost_id" = 964;
SELECT "app_category"."id", "app_category"."name" FROM "app_category" INNER JOIN "app_blogpost_categories" ON ("app_category"."id" = "app_blogpost_categories"."category_id") WHERE "app_blogpost_categories"."blogpost_id" = 591;
SELECT "app_category"."id", "app_category"."name" FROM "app_category" INNER JOIN "app_blogpost_categories" ON ("app_category"."id" = "app_blogpost_categories"."category_id") WHERE "app_blogpost_categories"."blogpost_id" = 1112;
SELECT "app_category"."id", "app_category"."name" FROM "app_category" INNER JOIN "app_blogpost_categories" ON ("app_category"."id" = "app_blogpost_categories"."category_id") WHERE "app_blogpost_categories"."blogpost_id" = 1034;
...
Obviously, it depends on how you define that serializers.BlogpostSerializer class. Basically, though, as it loops over the Blogpost records, it has to make a query to the many-to-many table (app_blogpost_categories in this example) for each and every one.
That's not going to be performant. In fact, it can be dangerous for your database if the queryset of blogposts gets big, say 100 or 1,000 records. Fetching 1,000 rows from the app_blogpost table might be cheap-ish, but doing 1,000 SELECTs with a JOIN is never going to be cheap. It adds up horribly.
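Here is a minimal sketch that shows how the cost adds up. It uses Python's built-in sqlite3 module instead of Django, and all of the following are assumptions made for illustration:

* The tables mirror the SQL above.
* The data is made up.
* `naive_listing` is a hypothetical stand-in for what a serializer that includes categories effectively does.

It counts the executed SELECT statements with sqlite3's `set_trace_callback`.

```python
import sqlite3


def make_db(n_posts):
    """Build an in-memory database shaped like the Django tables above (made-up data)."""
    conn = sqlite3.connect(":memory:")
    conn.executescript("""
        CREATE TABLE app_category (id INTEGER PRIMARY KEY, name TEXT);
        CREATE TABLE app_blogpost (id INTEGER PRIMARY KEY, title TEXT);
        CREATE TABLE app_blogpost_categories (
            id INTEGER PRIMARY KEY, blogpost_id INTEGER, category_id INTEGER
        );
        INSERT INTO app_category (id, name) VALUES (1, 'Category X'), (2, 'Category Z');
    """)
    for i in range(1, n_posts + 1):
        conn.execute("INSERT INTO app_blogpost (id, title) VALUES (?, ?)", (i, f"Post {i}"))
        # Odd-numbered posts get Category Z, even-numbered posts get Category X.
        conn.execute(
            "INSERT INTO app_blogpost_categories (blogpost_id, category_id) VALUES (?, ?)",
            (i, 1 + i % 2),
        )
    conn.commit()
    return conn


def count_selects(conn, fn):
    """Run fn(conn) and return how many SELECT statements SQLite executed."""
    statements = []
    conn.set_trace_callback(statements.append)
    try:
        fn(conn)
    finally:
        conn.set_trace_callback(None)
    return sum(1 for sql in statements if sql.lstrip().upper().startswith("SELECT"))


def naive_listing(conn):
    """One query for the blogposts, then one JOIN query per blogpost (the N+1 pattern)."""
    result = []
    posts = conn.execute("SELECT id, title FROM app_blogpost ORDER BY id").fetchall()
    for post_id, title in posts:
        names = [
            name
            for (name,) in conn.execute(
                "SELECT app_category.name FROM app_category "
                "INNER JOIN app_blogpost_categories "
                "ON app_category.id = app_blogpost_categories.category_id "
                "WHERE app_blogpost_categories.blogpost_id = ?",
                (post_id,),
            )
        ]
        result.append({"id": post_id, "title": title, "categories": names})
    return result


for n in (10, 100, 1000):
    print(n, "blogposts ->", count_selects(make_db(n), naive_listing), "SELECTs")
```

With this setup the count is n + 1: 11, 101 and 1,001 SELECTs. Against a real database server, each of those statements is also a separate network round trip.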
How you solve it
The trick is to do only 1 query on the many-to-many field's table (app_blogpost_categories), 1 query on the app_blogpost table and 1 query on the app_category table.
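First, override the ViewSet.list method. In there you can do exactly what you need. For this to pay off, serializers.BlogpostSerializer must not serialize categories itself (e.g. list only id and title in its fields). Otherwise super().list() still runs one query per record.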
Here's the framework for this change:
from rest_framework import viewsets
from . import serializers
from .models import Blogpost

class BlogpostViewSet(viewsets.ModelViewSet):
    serializer_class = serializers.BlogpostSerializer

    def get_queryset(self):
        return Blogpost.objects.all().order_by('date')

    def list(self, request, *args, **kwargs):
        response = super().list(request, *args, **kwargs)
        # This is where the categories get added to response.data.
        return response
Next, we need a mapping of Category.id -1-> Category.name (one name per ID). Ideally it should only cover the categories involved in the Blogpost records that matter. The simplest way is to look up every category:
category_names = {}
for category in Category.objects.all():
    category_names[category.id] = category.name
But to avoid looking up names for categories you never need, narrow it down with the Blogpost queryset, i.e.:
qs = self.get_queryset()
all_categories = Category.objects.filter(
    id__in=Blogpost.categories.through.objects.filter(
        blogpost__in=qs
    ).values('category_id')
)
category_names = {}
for category in all_categories:
    category_names[category.id] = category.name
Now you have a dictionary of all the Category IDs that matter.
Note! The above "optimization" assumes that it's worth it. It pays off when the number of Category records in your database is huge and the Blogpost queryset is heavily filtered, because then extracting only a subset saves work. If, on the other hand, you only have around 100 different categories in your database, just use the first variant. It simply looks them all up without any fancy subqueries.
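You can compare the two variants outside Django with a small sqlite3 example. The data is made up. The nested IN subqueries are roughly what the Category.objects.filter(id__in=...) call above compiles to. The `WHERE title = 'First'` clause is a hypothetical stand-in for a filtered Blogpost queryset.

```python
import sqlite3

conn = sqlite3.connect(":memory:")
conn.executescript("""
    CREATE TABLE app_category (id INTEGER PRIMARY KEY, name TEXT);
    CREATE TABLE app_blogpost (id INTEGER PRIMARY KEY, title TEXT);
    CREATE TABLE app_blogpost_categories (
        id INTEGER PRIMARY KEY, blogpost_id INTEGER, category_id INTEGER
    );
    INSERT INTO app_category (id, name)
        VALUES (1, 'Category X'), (2, 'Category Y'), (3, 'Category Z');
    INSERT INTO app_blogpost (id, title) VALUES (10, 'First'), (11, 'Second');
    INSERT INTO app_blogpost_categories (blogpost_id, category_id)
        VALUES (10, 1), (10, 3), (11, 2);
""")

# Variant 1: look up every category. Simple, and fine when there are few of them.
all_names = dict(conn.execute("SELECT id, name FROM app_category ORDER BY id"))

# Variant 2: only the categories used by the (filtered) blogpost queryset.
subset_names = dict(conn.execute("""
    SELECT id, name FROM app_category
    WHERE id IN (
        SELECT category_id FROM app_blogpost_categories
        WHERE blogpost_id IN (SELECT id FROM app_blogpost WHERE title = 'First')
    )
    ORDER BY id
"""))

print(all_names)     # {1: 'Category X', 2: 'Category Y', 3: 'Category Z'}
print(subset_names)  # {1: 'Category X', 3: 'Category Z'}
```

The subset skips Category Y, which no post in the filtered queryset uses. With only a handful of categories, the difference is negligible.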
Next is the mapping of Blogpost.id -N-> Category.name (one ID to many names). To do that, you build up a dictionary (int to list of strings), like this:
from collections import defaultdict

categories_map = defaultdict(list)
for m2m in Blogpost.categories.through.objects.filter(blogpost__in=qs):
    categories_map[m2m.blogpost_id].append(
        category_names[m2m.category_id]
    )
So what we now have is a dictionary whose values are lists of strings, e.g. ['Category X', 'Category Z']. Its keys are the IDs in self.get_queryset(), at least those of blogposts with one or more categories.
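Here is the same mapping step in plain Python. The sample data is hypothetical: it stands in for the category_names dictionary and for the (blogpost_id, category_id) rows of the through table.

```python
from collections import defaultdict

# Hypothetical sample data: category_names from the previous step, and
# (blogpost_id, category_id) pairs as they would come out of the through table.
category_names = {1: "Category X", 3: "Category Z"}
through_rows = [(10, 1), (10, 3), (11, 3)]

categories_map = defaultdict(list)
for blogpost_id, category_id in through_rows:
    categories_map[blogpost_id].append(category_names[category_id])

print(dict(categories_map))        # {10: ['Category X', 'Category Z'], 11: ['Category Z']}
print(categories_map.get(12, []))  # [] (blogpost 12 has no categories)
```

Blogposts without any category never become keys, which is why the lookup falls back to an empty list with .get(). The order of names in each list follows the order in which the through-table rows come back. Without an order_by() on that query, this order is not guaranteed.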
Lastly, we need to put these back into the serialized response. This feels a little hackish, but it works. It assumes the list isn't paginated, so that response.data is a plain list of dicts:
for each in response.data:
    each['categories'] = categories_map.get(each['id'], [])
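With pagination switched on, response.data is no longer a list. DRF's built-in paginators wrap the items in a dict under a 'results' key. The helper attach_categories below is a hypothetical sketch that handles both shapes. It is plain Python, so you can try it without Django, and it assumes each serialized item includes its id.

```python
def attach_categories(data, categories_map):
    """Add a 'categories' list to every serialized item, in place.

    `data` is either a plain list of dicts (no pagination) or a dict that keeps
    the items under 'results' (the envelope DRF's built-in paginators use).
    Items whose id is missing from categories_map get an empty list.
    """
    items = data["results"] if isinstance(data, dict) else data
    for item in items:
        item["categories"] = categories_map.get(item["id"], [])
    return data


# Hypothetical sample data in both shapes.
categories_map = {10: ["Category X", "Category Z"]}
unpaginated = [{"id": 10, "title": "First"}, {"id": 12, "title": "Third"}]
paginated = {"count": 1, "next": None, "previous": None,
             "results": [{"id": 10, "title": "First"}]}

attach_categories(unpaginated, categories_map)
attach_categories(paginated, categories_map)
print(unpaginated[1]["categories"])           # []
print(paginated["results"][0]["categories"])  # ['Category X', 'Category Z']
```

Inside list(), this would be called as attach_categories(response.data, categories_map). With pagination, self.get_queryset() still covers every blogpost, not just the current page. The two extra queries therefore fetch more rows than the page needs.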
The whole solution looks something like this:
from collections import defaultdict

from rest_framework import viewsets

from . import serializers
from .models import Blogpost, Category

class BlogpostViewSet(viewsets.ModelViewSet):
    serializer_class = serializers.BlogpostSerializer

    def get_queryset(self):
        return Blogpost.objects.all().order_by('date')

    def list(self, request, *args, **kwargs):
        # Serializes the blogposts: 1 query, provided the serializer leaves out categories.
        response = super().list(request, *args, **kwargs)
        qs = self.get_queryset()

        # Category.id -> Category.name, only for categories used by these blogposts.
        all_categories = Category.objects.filter(
            id__in=Blogpost.categories.through.objects.filter(
                blogpost__in=qs
            ).values('category_id')
        )
        category_names = {}
        for category in all_categories:
            category_names[category.id] = category.name

        # Blogpost.id -> [Category.name, ...], read from the through table.
        categories_map = defaultdict(list)
        for m2m in Blogpost.categories.through.objects.filter(blogpost__in=qs):
            categories_map[m2m.blogpost_id].append(
                category_names[m2m.category_id]
            )

        # Attach the names to each serialized blogpost (unpaginated response).
        for each in response.data:
            each['categories'] = categories_map.get(each['id'], [])
        return response
It's arguably not very pretty, but doing 3 tight queries instead of one query per record is much better. The number of queries is O(1), constant no matter how many blogposts you list, instead of O(n).
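You can check the constant query count with a made-up sqlite3 setup (plain Python, not Django). batched_listing is a hypothetical function that follows the three-step recipe:

1. One query for the blogposts.
2. One query for the referenced categories.
3. One query for the through-table rows.

```python
import sqlite3
from collections import defaultdict


def make_db(n_posts):
    """Build an in-memory database shaped like the Django tables (made-up data)."""
    conn = sqlite3.connect(":memory:")
    conn.executescript("""
        CREATE TABLE app_category (id INTEGER PRIMARY KEY, name TEXT);
        CREATE TABLE app_blogpost (id INTEGER PRIMARY KEY, title TEXT);
        CREATE TABLE app_blogpost_categories (
            id INTEGER PRIMARY KEY, blogpost_id INTEGER, category_id INTEGER
        );
        INSERT INTO app_category (id, name) VALUES (1, 'Category X'), (2, 'Category Z');
    """)
    for i in range(1, n_posts + 1):
        conn.execute("INSERT INTO app_blogpost (id, title) VALUES (?, ?)", (i, f"Post {i}"))
        # Odd-numbered posts get Category Z, even-numbered posts get Category X.
        conn.execute(
            "INSERT INTO app_blogpost_categories (blogpost_id, category_id) VALUES (?, ?)",
            (i, 1 + i % 2),
        )
    conn.commit()
    return conn


def count_selects(conn, fn):
    """Run fn(conn) and return how many SELECT statements SQLite executed."""
    statements = []
    conn.set_trace_callback(statements.append)
    try:
        fn(conn)
    finally:
        conn.set_trace_callback(None)
    return sum(1 for sql in statements if sql.lstrip().upper().startswith("SELECT"))


def batched_listing(conn):
    """Three SELECTs in total, however many blogposts there are."""
    posts = conn.execute("SELECT id, title FROM app_blogpost ORDER BY id").fetchall()
    # Only the categories referenced by these blogposts (nested subqueries).
    category_names = dict(conn.execute(
        "SELECT id, name FROM app_category WHERE id IN ("
        " SELECT category_id FROM app_blogpost_categories"
        " WHERE blogpost_id IN (SELECT id FROM app_blogpost))"
    ))
    # Blogpost id -> list of category names, from the through table.
    categories_map = defaultdict(list)
    for blogpost_id, category_id in conn.execute(
        "SELECT blogpost_id, category_id FROM app_blogpost_categories"
        " WHERE blogpost_id IN (SELECT id FROM app_blogpost)"
    ):
        categories_map[blogpost_id].append(category_names[category_id])
    return [
        {"id": post_id, "title": title, "categories": categories_map.get(post_id, [])}
        for post_id, title in posts
    ]


for n in (10, 100, 1000):
    print(n, "blogposts ->", count_selects(make_db(n), batched_listing), "SELECTs")
```

Every size needs exactly 3 SELECTs. The subqueries run inside those statements rather than as separate round trips.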
Discussion
Perhaps the best solution is to not run into this problem at all, i.e. don't serialize any many-to-many fields.
Or, if you use pagination very conservatively and only allow, say, 10 items per page, then it won't be so expensive to do one query per record for each many-to-many field.
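A back-of-the-envelope count makes the trade-off explicit. These two hypothetical functions count the queries for one page of results. They ignore the extra COUNT query that DRF's page-number and limit/offset paginators typically run. batched_queries extrapolates the three-query recipe to several many-to-many fields, which is an assumption beyond the single-field example above.

```python
def naive_queries(page_size, m2m_fields=1):
    # One query for the page of rows, plus one query per row per many-to-many field.
    return 1 + page_size * m2m_fields


def batched_queries(m2m_fields=1):
    # One query for the rows, plus a through-table query and a related-table
    # query per many-to-many field, independent of the page size.
    return 1 + 2 * m2m_fields


print(naive_queries(10))    # 11
print(naive_queries(1000))  # 1001
print(batched_queries())    # 3
```

With 10 items per page, 11 queries is usually tolerable. The gap only becomes painful as the page size grows.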