diff --git a/README.md b/README.md index 779f87e..35d5466 100644 --- a/README.md +++ b/README.md @@ -43,7 +43,8 @@ from drf_multiple_model.views import ObjectMultipleModelAPIView * pagination * Filtering -- either per queryset or on all querysets * custom model labeling - +* django-filters support +* For full configuration options, filtering tools, and more, see [the documentation](https://django-rest-multiple-models.readthedocs.org/en/latest/). # Basic Usage @@ -68,11 +69,21 @@ class PlaySerializer(serializers.ModelSerializer): class Meta: model = Play fields = ('genre','title','pages') - + +class PlayFilter(django_filters.FilterSet): + class Meta: + model = Play + fields = ['genre', 'pages'] + class PoemSerializer(serializers.ModelSerializer): class Meta: model = Poem fields = ('title','stanzas') + +class PoemFilter(django_filters.FilterSet): + class Meta: + model = Poem + fields = ['style', 'lines', 'stanzas'] ``` Then you might use the `ObjectMultipleModelAPIView` as follows: @@ -83,8 +94,8 @@ from drf_multiple_model.views import ObjectMultipleModelAPIView class TextAPIView(ObjectMultipleModelAPIView): querylist = [ - {'queryset': Play.objects.all(), 'serializer_class': PlaySerializer}, - {'queryset': Poem.objects.filter(style='Sonnet'), 'serializer_class': PoemSerializer}, + {'queryset': Play.objects.all(), 'serializer_class': PlaySerializer, 'filterset_class':PlayFilter}, + {'queryset': Poem.objects.filter(style='Sonnet'), 'serializer_class': PoemSerializer,'filterset_class':PoemFilter}, .... ] ``` diff --git a/drf_multiple_model/mixins.py b/drf_multiple_model/mixins.py index 10df9c5..0b523c8 100644 --- a/drf_multiple_model/mixins.py +++ b/drf_multiple_model/mixins.py @@ -1,4 +1,5 @@ import warnings +from copy import deepcopy from django.core.exceptions import ValidationError from django.db.models.query import QuerySet @@ -9,20 +10,20 @@ class BaseMultipleModelMixin(object): """ Base class that holds functions need for all MultipleModelMixins/Views """ + querylist = None # Keys required for every item in a querylist - required_keys = ['queryset', 'serializer_class'] + required_keys = ["queryset", "serializer_class"] # default pagination state. Gets overridden if pagination is active is_paginated = False + default_filterset_class = None def get_querylist(self): assert self.querylist is not None, ( - '{} should either include a `querylist` attribute, ' - 'or override the `get_querylist()` method.'.format( - self.__class__.__name__ - ) + "{} should either include a `querylist` attribute, " + "or override the `get_querylist()` method.".format(self.__class__.__name__) ) return self.querylist @@ -36,8 +37,8 @@ def check_query_data(self, query_data): for key in self.required_keys: if key not in query_data: raise ValidationError( - 'All items in the {} querylist attribute should contain a ' - '`{}` key'.format(self.__class__.__name__, key) + "All items in the {} querylist attribute should contain a " + "`{}` key".format(self.__class__.__name__, key) ) def load_queryset(self, query_data, request, *args, **kwargs): @@ -46,17 +47,17 @@ def load_queryset(self, query_data, request, *args, **kwargs): built-in rest_framework filters and custom filters passed into the querylist """ - queryset = query_data.get('queryset', []) + queryset = query_data.get("queryset", []) + filterset_class = query_data.get("filterset_class", None) if isinstance(queryset, QuerySet): # Ensure queryset is re-evaluated on each request. queryset = queryset.all() - # run rest_framework filters - queryset = self.filter_queryset(queryset) + queryset = self.filter_queryset_custom(queryset, filterset_class) # run custom filters - filter_fn = query_data.get('filter_fn', None) + filter_fn = query_data.get("filter_fn", None) if filter_fn is not None: queryset = filter_fn(queryset, request, *args, **kwargs) @@ -65,6 +66,23 @@ def load_queryset(self, query_data, request, *args, **kwargs): return page if page is not None else queryset + def filter_queryset_custom(self, queryset, filterset_class=None): + + old_filterset_class = getattr(self, "filterset_class", None) + for backend in list(self.filter_backends): + + try: + from django_filters.rest_framework import DjangoFilterBackend + + if issubclass(backend, DjangoFilterBackend): + self.filterset_class = filterset_class + except ImportError: + pass + + queryset = backend().filter_queryset(self.request, queryset, self) + self.filterset_class = old_filterset_class + return queryset + def get_empty_results(self): """ Because the base result type is different depending on the return structure @@ -72,8 +90,8 @@ def get_empty_results(self): `results` variable to the proper type """ assert self.result_type is not None, ( - '{} must specify a `result_type` value or overwrite the ' - '`get_empty_result` method.'.format(self.__class__.__name__) + "{} must specify a `result_type` value or overwrite the " + "`get_empty_result` method.".format(self.__class__.__name__) ) return self.result_type() @@ -84,10 +102,8 @@ def add_to_results(self, data, label, results): data from this queryset/serializer combo """ raise NotImplementedError( - '{} must specify how to add data to the running results tally ' - 'by overriding the `add_to_results` method.'.format( - self.__class__.__name__ - ) + "{} must specify how to add data to the running results tally " + "by overriding the `add_to_results` method.".format(self.__class__.__name__) ) def format_results(self, results, request): @@ -109,7 +125,9 @@ def list(self, request, *args, **kwargs): # Run the paired serializer context = self.get_serializer_context() - data = query_data['serializer_class'](queryset, many=True, context=context).data + data = query_data["serializer_class"]( + queryset, many=True, context=context + ).data label = self.get_label(queryset, query_data) @@ -156,6 +174,7 @@ class FlatMultipleModelMixin(BaseMultipleModelMixin): ... ] """ + # Optional keyword to sort flat lasts by given attribute # note that the attribute must by shared by ALL models sorting_field = None @@ -166,14 +185,16 @@ class FlatMultipleModelMixin(BaseMultipleModelMixin): # Django-like model lookups are supported via '__', but you have to be sure that all querysets will return results # with corresponding structure. sorting_fields_map = {} - sorting_parameter_name = 'o' + sorting_parameter_name = "o" # Flag to append the particular django model being used to the data add_model_type = True result_type = list - _list_attribute_error = 'Invalid sorting field. Corresponding data item is a list: {}' + _list_attribute_error = ( + "Invalid sorting field. Corresponding data item is a list: {}" + ) def initial(self, request, *args, **kwargs): """ @@ -182,13 +203,15 @@ def initial(self, request, *args, **kwargs): after original `initial` has been ran in order to make sure that view has all its properties set up. """ super(FlatMultipleModelMixin, self).initial(request, *args, **kwargs) - assert not (self.sorting_field and self.sorting_fields), \ - '{} should either define ``sorting_field`` or ``sorting_fields`` property, not both.' \ - .format(self.__class__.__name__) + assert not ( + self.sorting_field and self.sorting_fields + ), "{} should either define ``sorting_field`` or ``sorting_fields`` property, not both.".format( + self.__class__.__name__ + ) if self.sorting_field: warnings.warn( - '``sorting_field`` property is pending its deprecation. Use ``sorting_fields`` instead.', - DeprecationWarning + "``sorting_field`` property is pending its deprecation. Use ``sorting_fields`` instead.", + DeprecationWarning, ) self.sorting_fields = [self.sorting_field] self._sorting_fields = self.sorting_fields @@ -198,13 +221,13 @@ def get_label(self, queryset, query_data): Gets option label for each datum. Can be used for type identification of individual serialized objects """ - if query_data.get('label', False): - return query_data['label'] + if query_data.get("label", False): + return query_data["label"] elif self.add_model_type: try: return queryset.model.__name__ except AttributeError: - return query_data['queryset'].model.__name__ + return query_data["queryset"].model.__name__ def add_to_results(self, data, label, results): """ @@ -213,7 +236,7 @@ def add_to_results(self, data, label, results): """ for datum in data: if label is not None: - datum.update({'type': label}) + datum.update({"type": label}) results.append(datum) @@ -227,9 +250,9 @@ def format_results(self, results, request): if self._sorting_fields: results = self.sort_results(results) - if request.accepted_renderer.format == 'html': + if request.accepted_renderer.format == "html": # Makes the the results available to the template context by transforming to a dict - results = {'data': results} + results = {"data": results} return results @@ -240,8 +263,8 @@ def _sort_by(self, datum, param, path=None): if not path: path = [] try: - if '__' in param: - root, new_param = param.split('__') + if "__" in param: + root, new_param = param.split("__") path.append(root) return self._sort_by(datum[root], param=new_param, path=path) else: @@ -252,9 +275,9 @@ def _sort_by(self, datum, param, path=None): raise ValidationError(self._list_attribute_error.format(param)) return data except TypeError: - raise ValidationError(self._list_attribute_error.format('.'.join(path))) + raise ValidationError(self._list_attribute_error.format(".".join(path))) except KeyError: - raise ValidationError('Invalid sorting field: {}'.format('.'.join(path))) + raise ValidationError("Invalid sorting field: {}".format(".".join(path))) def prepare_sorting_fields(self): """ @@ -264,22 +287,26 @@ def prepare_sorting_fields(self): if self.sorting_parameter_name in self.request.query_params: # Extract sorting parameter from query string self._sorting_fields = [ - _.strip() for _ in self.request.query_params.get(self.sorting_parameter_name).split(',') + _.strip() + for _ in self.request.query_params.get( + self.sorting_parameter_name + ).split(",") ] if self._sorting_fields: # Create a list of sorting parameters. Each parameter is a tuple: (field:str, descending:bool) self._sorting_fields = [ - (self.sorting_fields_map.get(field.lstrip('-'), field.lstrip('-')), field[0] == '-') + ( + self.sorting_fields_map.get(field.lstrip("-"), field.lstrip("-")), + field[0] == "-", + ) for field in self._sorting_fields ] def sort_results(self, results): for field, descending in reversed(self._sorting_fields): results = sorted( - results, - reverse=descending, - key=lambda x: self._sort_by(x, field) + results, reverse=descending, key=lambda x: self._sort_by(x, field) ) return results @@ -314,6 +341,7 @@ class ObjectMultipleModelMixin(BaseMultipleModelMixin): ... } """ + result_type = dict def add_to_results(self, data, label, results): @@ -326,10 +354,10 @@ def get_label(self, queryset, query_data): Gets option label for each datum. Can be used for type identification of individual serialized objects """ - if query_data.get('label', False): - return query_data['label'] + if query_data.get("label", False): + return query_data["label"] try: return queryset.model.__name__ except AttributeError: - return query_data['queryset'].model.__name__ + return query_data["queryset"].model.__name__