from rest_framework import serializers

from . import PostingEndpoint, PostingMiddleware


class ProtectMiddleware(PostingMiddleware):
    """Posting middleware that lets privileged users toggle post protection.

    It runs only when an existing post is being edited. If the thread's
    category ACL grants ``can_protect_posts``, the post's ``is_protected`` flag
    is set from the request's ``protect`` value and queued for saving.
    """

    def use_this_middleware(self):
        """Return True only when the posting mode is ``PostingEndpoint.EDIT``."""
        return self.mode == PostingEndpoint.EDIT

    def get_serializer(self):
        """Build a ``ProtectSerializer`` bound to the incoming request data."""
        return ProtectSerializer(data=self.request.data)

    def post_save(self, serializer):
        """Apply the validated ``protect`` flag if the user may protect posts.

        The flag falls back to False when it is absent. As a result, an edit
        by a user with ``can_protect_posts`` that omits ``protect`` leaves the
        post unprotected.

        Args:
            serializer: the validated serializer returned by ``get_serializer``.
        """
        if not self.thread.category.acl['can_protect_posts']:
            return

        # No try/except here. Reading validated_data with ``.get``, assigning
        # an attribute and appending to a list cannot raise TypeError or
        # ValueError, so catching them would only hide unrelated bugs.
        self.post.is_protected = serializer.validated_data.get('protect', False)
        # Record the changed column so the later save persists it
        # (presumably passed as ``update_fields`` to ``Model.save`` — confirm).
        self.post.update_fields.append('is_protected')


class ProtectSerializer(serializers.Serializer):
    """Validate the optional ``protect`` flag sent when editing a post.

    If the client omits the field, it validates as ``False``.
    """

    protect = serializers.BooleanField(default=False, required=False)