Browse Source

also exclude disabled users from resend activation/reset password endpoints

Rafał Pitoń 8 years ago
parent
commit
803c51b4bd

+ 6 - 3
misago/users/forms/auth.py

@@ -31,9 +31,10 @@ class MisagoAuthMixin(object):
                 self.error_messages['inactive_user'], code='inactive_user')
                 self.error_messages['inactive_user'], code='inactive_user')
 
 
     def confirm_user_not_banned(self, user):
     def confirm_user_not_banned(self, user):
-        self.user_ban = get_user_ban(user)
-        if self.user_ban:
-            raise ValidationError('', code='banned')
+        if not user.is_staff:
+            self.user_ban = get_user_ban(user)
+            if self.user_ban:
+                raise ValidationError('', code='banned')
 
 
     def get_errors_dict(self):
     def get_errors_dict(self):
         error = self.errors.as_data()['__all__'][0]
         error = self.errors.as_data()['__all__'][0]
@@ -124,6 +125,8 @@ class GetUserForm(MisagoAuthMixin, forms.Form):
         try:
         try:
             User = get_user_model()
             User = get_user_model()
             user = User.objects.get_by_email(data['email'])
             user = User.objects.get_by_email(data['email'])
+            if not user.is_active:
+                raise User.DoesNotExist()
             self.user_cache = user
             self.user_cache = user
         except User.DoesNotExist:
         except User.DoesNotExist:
             raise forms.ValidationError(
             raise forms.ValidationError(

+ 1 - 1
misago/users/middleware.py

@@ -32,7 +32,7 @@ class UserMiddleware(object):
     def process_request(self, request):
     def process_request(self, request):
         if request.user.is_anonymous():
         if request.user.is_anonymous():
             request.user = AnonymousUser()
             request.user = AnonymousUser()
-        elif not request.user.is_superuser:
+        elif not request.user.is_staff:
             if get_request_ip_ban(request) or get_user_ban(request.user):
             if get_request_ip_ban(request) or get_user_ban(request.user):
                 logout(request)
                 logout(request)
 
 

+ 76 - 10
misago/users/tests/test_auth_api.py

@@ -1,5 +1,3 @@
-import json
-
 from django.contrib.auth import get_user_model
 from django.contrib.auth import get_user_model
 from django.core import mail
 from django.core import mail
 from django.test import TestCase
 from django.test import TestCase
@@ -21,7 +19,7 @@ class GatewayTests(TestCase):
         response = self.client.get('/api/auth/')
         response = self.client.get('/api/auth/')
         self.assertEqual(response.status_code, 200)
         self.assertEqual(response.status_code, 200)
 
 
-        user_json = json.loads(smart_str(response.content))
+        user_json = response.json()
         self.assertIsNone(user_json['id'])
         self.assertIsNone(user_json['id'])
 
 
     def test_login(self):
     def test_login(self):
@@ -39,7 +37,7 @@ class GatewayTests(TestCase):
         response = self.client.get('/api/auth/')
         response = self.client.get('/api/auth/')
         self.assertEqual(response.status_code, 200)
         self.assertEqual(response.status_code, 200)
 
 
-        user_json = json.loads(smart_str(response.content))
+        user_json = response.json()
         self.assertEqual(user_json['id'], user.id)
         self.assertEqual(user_json['id'], user.id)
         self.assertEqual(user_json['username'], user.username)
         self.assertEqual(user_json['username'], user.username)
 
 
@@ -65,7 +63,7 @@ class GatewayTests(TestCase):
         })
         })
         self.assertEqual(response.status_code, 400)
         self.assertEqual(response.status_code, 400)
 
 
-        response_json = json.loads(smart_str(response.content))
+        response_json = response.json()
         self.assertEqual(response_json['code'], 'banned')
         self.assertEqual(response_json['code'], 'banned')
         self.assertEqual(response_json['detail']['message']['plain'],
         self.assertEqual(response_json['detail']['message']['plain'],
                          ban.user_message)
                          ban.user_message)
@@ -75,9 +73,36 @@ class GatewayTests(TestCase):
         response = self.client.get('/api/auth/')
         response = self.client.get('/api/auth/')
         self.assertEqual(response.status_code, 200)
         self.assertEqual(response.status_code, 200)
 
 
-        user_json = json.loads(smart_str(response.content))
+        user_json = response.json()
         self.assertIsNone(user_json['id'])
         self.assertIsNone(user_json['id'])
 
 
+    def test_login_banned_staff(self):
+        """login api signs banned staff member in"""
+        User = get_user_model()
+        user = User.objects.create_user('Bob', 'bob@test.com', 'Pass.123')
+
+        user.is_staff = True
+        user.save()
+
+        ban = Ban.objects.create(
+            check_type=BAN_USERNAME,
+            banned_value='bob',
+            user_message='You are tragically banned.',
+        )
+
+        response = self.client.post('/api/auth/', data={
+            'username': 'Bob',
+            'password': 'Pass.123',
+        })
+        self.assertEqual(response.status_code, 200)
+
+        response = self.client.get('/api/auth/')
+        self.assertEqual(response.status_code, 200)
+
+        user_json = response.json()
+        self.assertEqual(user_json['id'], user.id)
+        self.assertEqual(user_json['username'], user.username)
+
     def test_login_inactive_admin(self):
     def test_login_inactive_admin(self):
         """login api fails to sign admin-activated user in"""
         """login api fails to sign admin-activated user in"""
         User = get_user_model()
         User = get_user_model()
@@ -90,13 +115,13 @@ class GatewayTests(TestCase):
         })
         })
         self.assertEqual(response.status_code, 400)
         self.assertEqual(response.status_code, 400)
 
 
-        response_json = json.loads(smart_str(response.content))
+        response_json = response.json()
         self.assertEqual(response_json['code'], 'inactive_user')
         self.assertEqual(response_json['code'], 'inactive_user')
 
 
         response = self.client.get('/api/auth/')
         response = self.client.get('/api/auth/')
         self.assertEqual(response.status_code, 200)
         self.assertEqual(response.status_code, 200)
 
 
-        user_json = json.loads(smart_str(response.content))
+        user_json = response.json()
         self.assertIsNone(user_json['id'])
         self.assertIsNone(user_json['id'])
 
 
     def test_login_inactive_user(self):
     def test_login_inactive_user(self):
@@ -111,13 +136,34 @@ class GatewayTests(TestCase):
         })
         })
         self.assertEqual(response.status_code, 400)
         self.assertEqual(response.status_code, 400)
 
 
-        response_json = json.loads(smart_str(response.content))
+        response_json = response.json()
         self.assertEqual(response_json['code'], 'inactive_admin')
         self.assertEqual(response_json['code'], 'inactive_admin')
 
 
         response = self.client.get('/api/auth/')
         response = self.client.get('/api/auth/')
         self.assertEqual(response.status_code, 200)
         self.assertEqual(response.status_code, 200)
 
 
-        user_json = json.loads(smart_str(response.content))
+        user_json = response.json()
+        self.assertIsNone(user_json['id'])
+
+    def test_login_disabled_user(self):
+        """its impossible to sign in to disabled account"""
+        User = get_user_model()
+        user = User.objects.create_user(
+            'Bob', 'bob@test.com', 'Pass.123', is_active=False)
+
+        user.is_staff = True
+        user.save()
+
+        response = self.client.post('/api/auth/', data={
+            'username': 'Bob',
+            'password': 'Pass.123',
+        })
+        self.assertContains(response, "Login or password is incorrect.", status_code=400)
+
+        response = self.client.get('/api/auth/')
+        self.assertEqual(response.status_code, 200)
+
+        user_json = response.json()
         self.assertIsNone(user_json['id'])
         self.assertIsNone(user_json['id'])
 
 
 
 
@@ -150,6 +196,16 @@ class SendActivationAPITests(TestCase):
 
 
         self.assertIn('Activate Bob', mail.outbox[0].subject)
         self.assertIn('Activate Bob', mail.outbox[0].subject)
 
 
+    def test_submit_disabled(self):
+        """request activation link api fails disabled users"""
+        self.user.is_active = False
+        self.user.save()
+
+        response = self.client.post(self.link, data={'email': self.user.email})
+        self.assertContains(response, 'not_found', status_code=400)
+
+        self.assertTrue(not mail.outbox)
+
     def test_submit_empty(self):
     def test_submit_empty(self):
         """request activation link api errors for no body"""
         """request activation link api errors for no body"""
         response = self.client.post(self.link)
         response = self.client.post(self.link)
@@ -219,6 +275,16 @@ class SendPasswordFormAPITests(TestCase):
 
 
         self.assertIn('Change Bob password', mail.outbox[0].subject)
         self.assertIn('Change Bob password', mail.outbox[0].subject)
 
 
+    def test_submit_disabled(self):
+        """request change password form api fails disabled users"""
+        self.user.is_active = False
+        self.user.save()
+
+        response = self.client.post(self.link, data={'email': self.user.email})
+        self.assertContains(response, 'not_found', status_code=400)
+
+        self.assertTrue(not mail.outbox)
+
     def test_submit_empty(self):
     def test_submit_empty(self):
         """request change password form link api errors for no body"""
         """request change password form link api errors for no body"""
         response = self.client.post(self.link)
         response = self.client.post(self.link)

+ 70 - 41
misago/users/tests/test_bans.py

@@ -14,21 +14,26 @@ class GetBanTests(TestCase):
         nonexistent_ban = get_username_ban('nonexistent')
         nonexistent_ban = get_username_ban('nonexistent')
         self.assertIsNone(nonexistent_ban)
         self.assertIsNone(nonexistent_ban)
 
 
-        Ban.objects.create(banned_value='expired',
-                           expires_on=timezone.now() - timedelta(days=7))
+        Ban.objects.create(
+            banned_value='expired',
+            expires_on=timezone.now() - timedelta(days=7)
+        )
 
 
         expired_ban = get_username_ban('expired')
         expired_ban = get_username_ban('expired')
         self.assertIsNone(expired_ban)
         self.assertIsNone(expired_ban)
 
 
-        Ban.objects.create(banned_value='wrongtype',
-                           check_type=BAN_EMAIL)
+        Ban.objects.create(
+            banned_value='wrongtype',
+            check_type=BAN_EMAIL
+        )
 
 
         wrong_type_ban = get_username_ban('wrongtype')
         wrong_type_ban = get_username_ban('wrongtype')
         self.assertIsNone(wrong_type_ban)
         self.assertIsNone(wrong_type_ban)
 
 
         valid_ban = Ban.objects.create(
         valid_ban = Ban.objects.create(
             banned_value='admi*',
             banned_value='admi*',
-            expires_on=timezone.now() + timedelta(days=7))
+            expires_on=timezone.now() + timedelta(days=7)
+        )
         self.assertEqual(get_username_ban('admiral').pk, valid_ban.pk)
         self.assertEqual(get_username_ban('admiral').pk, valid_ban.pk)
 
 
     def test_get_email_ban(self):
     def test_get_email_ban(self):
@@ -36,15 +41,19 @@ class GetBanTests(TestCase):
         nonexistent_ban = get_email_ban('non@existent.com')
         nonexistent_ban = get_email_ban('non@existent.com')
         self.assertIsNone(nonexistent_ban)
         self.assertIsNone(nonexistent_ban)
 
 
-        Ban.objects.create(banned_value='ex@pired.com',
-                           check_type=BAN_EMAIL,
-                           expires_on=timezone.now() - timedelta(days=7))
+        Ban.objects.create(
+            banned_value='ex@pired.com',
+            check_type=BAN_EMAIL,
+            expires_on=timezone.now() - timedelta(days=7)
+        )
 
 
         expired_ban = get_email_ban('ex@pired.com')
         expired_ban = get_email_ban('ex@pired.com')
         self.assertIsNone(expired_ban)
         self.assertIsNone(expired_ban)
 
 
-        Ban.objects.create(banned_value='wrong@type.com',
-                           check_type=BAN_IP)
+        Ban.objects.create(
+            banned_value='wrong@type.com',
+            check_type=BAN_IP
+        )
 
 
         wrong_type_ban = get_email_ban('wrong@type.com')
         wrong_type_ban = get_email_ban('wrong@type.com')
         self.assertIsNone(wrong_type_ban)
         self.assertIsNone(wrong_type_ban)
@@ -52,7 +61,8 @@ class GetBanTests(TestCase):
         valid_ban = Ban.objects.create(
         valid_ban = Ban.objects.create(
             banned_value='*.ru',
             banned_value='*.ru',
             check_type=BAN_EMAIL,
             check_type=BAN_EMAIL,
-            expires_on=timezone.now() + timedelta(days=7))
+            expires_on=timezone.now() + timedelta(days=7)
+        )
         self.assertEqual(get_email_ban('banned@mail.ru').pk, valid_ban.pk)
         self.assertEqual(get_email_ban('banned@mail.ru').pk, valid_ban.pk)
 
 
     def test_get_ip_ban(self):
     def test_get_ip_ban(self):
@@ -60,15 +70,19 @@ class GetBanTests(TestCase):
         nonexistent_ban = get_ip_ban('123.0.0.1')
         nonexistent_ban = get_ip_ban('123.0.0.1')
         self.assertIsNone(nonexistent_ban)
         self.assertIsNone(nonexistent_ban)
 
 
-        Ban.objects.create(banned_value='124.0.0.1',
-                           check_type=BAN_IP,
-                           expires_on=timezone.now() - timedelta(days=7))
+        Ban.objects.create(
+            banned_value='124.0.0.1',
+            check_type=BAN_IP,
+            expires_on=timezone.now() - timedelta(days=7)
+        )
 
 
         expired_ban = get_ip_ban('124.0.0.1')
         expired_ban = get_ip_ban('124.0.0.1')
         self.assertIsNone(expired_ban)
         self.assertIsNone(expired_ban)
 
 
-        Ban.objects.create(banned_value='wrongtype',
-                           check_type=BAN_EMAIL)
+        Ban.objects.create(
+            banned_value='wrongtype',
+            check_type=BAN_EMAIL
+        )
 
 
         wrong_type_ban = get_ip_ban('wrongtype')
         wrong_type_ban = get_ip_ban('wrongtype')
         self.assertIsNone(wrong_type_ban)
         self.assertIsNone(wrong_type_ban)
@@ -76,7 +90,8 @@ class GetBanTests(TestCase):
         valid_ban = Ban.objects.create(
         valid_ban = Ban.objects.create(
             banned_value='125.0.0.*',
             banned_value='125.0.0.*',
             check_type=BAN_IP,
             check_type=BAN_IP,
-            expires_on=timezone.now() + timedelta(days=7))
+            expires_on=timezone.now() + timedelta(days=7)
+        )
         self.assertEqual(get_ip_ban('125.0.0.1').pk, valid_ban.pk)
         self.assertEqual(get_ip_ban('125.0.0.1').pk, valid_ban.pk)
 
 
 
 
@@ -94,9 +109,11 @@ class UserBansTests(TestCase):
 
 
     def test_permanent_ban(self):
     def test_permanent_ban(self):
         """user is caught by permanent ban"""
         """user is caught by permanent ban"""
-        Ban.objects.create(banned_value='bob',
-                           user_message='User reason',
-                           staff_message='Staff reason')
+        Ban.objects.create(
+            banned_value='bob',
+            user_message='User reason',
+            staff_message='Staff reason'
+        )
 
 
         user_ban = get_user_ban(self.user)
         user_ban = get_user_ban(self.user)
         self.assertIsNotNone(user_ban)
         self.assertIsNotNone(user_ban)
@@ -106,10 +123,12 @@ class UserBansTests(TestCase):
 
 
     def test_temporary_ban(self):
     def test_temporary_ban(self):
         """user is caught by temporary ban"""
         """user is caught by temporary ban"""
-        Ban.objects.create(banned_value='bo*',
-                           user_message='User reason',
-                           staff_message='Staff reason',
-                           expires_on=timezone.now() + timedelta(days=7))
+        Ban.objects.create(
+            banned_value='bo*',
+            user_message='User reason',
+            staff_message='Staff reason',
+            expires_on=timezone.now() + timedelta(days=7)
+        )
 
 
         user_ban = get_user_ban(self.user)
         user_ban = get_user_ban(self.user)
         self.assertIsNotNone(user_ban)
         self.assertIsNotNone(user_ban)
@@ -119,16 +138,20 @@ class UserBansTests(TestCase):
 
 
     def test_expired_ban(self):
     def test_expired_ban(self):
         """user is not caught by expired ban"""
         """user is not caught by expired ban"""
-        Ban.objects.create(banned_value='bo*',
-                           expires_on=timezone.now() - timedelta(days=7))
+        Ban.objects.create(
+            banned_value='bo*',
+            expires_on=timezone.now() - timedelta(days=7)
+        )
 
 
         self.assertIsNone(get_user_ban(self.user))
         self.assertIsNone(get_user_ban(self.user))
         self.assertFalse(self.user.ban_cache.is_banned)
         self.assertFalse(self.user.ban_cache.is_banned)
 
 
     def test_expired_non_flagged_ban(self):
     def test_expired_non_flagged_ban(self):
         """user is not caught by expired but checked ban"""
         """user is not caught by expired but checked ban"""
-        Ban.objects.create(banned_value='bo*',
-                           expires_on=timezone.now() - timedelta(days=7))
+        Ban.objects.create(
+            banned_value='bo*',
+            expires_on=timezone.now() - timedelta(days=7)
+        )
         Ban.objects.update(is_checked=True)
         Ban.objects.update(is_checked=True)
 
 
         self.assertIsNone(get_user_ban(self.user))
         self.assertIsNone(get_user_ban(self.user))
@@ -149,9 +172,11 @@ class RequestIPBansTests(TestCase):
 
 
     def test_permanent_ban(self):
     def test_permanent_ban(self):
         """ip is caught by permanent ban"""
         """ip is caught by permanent ban"""
-        Ban.objects.create(check_type=BAN_IP,
-                           banned_value='127.0.0.1',
-                           user_message='User reason')
+        Ban.objects.create(
+            check_type=BAN_IP,
+            banned_value='127.0.0.1',
+            user_message='User reason'
+        )
 
 
         ip_ban = get_request_ip_ban(FakeRequest())
         ip_ban = get_request_ip_ban(FakeRequest())
         self.assertTrue(ip_ban['is_banned'])
         self.assertTrue(ip_ban['is_banned'])
@@ -163,10 +188,12 @@ class RequestIPBansTests(TestCase):
 
 
     def test_temporary_ban(self):
     def test_temporary_ban(self):
         """ip is caught by temporary ban"""
         """ip is caught by temporary ban"""
-        Ban.objects.create(check_type=BAN_IP,
-                           banned_value='127.0.0.1',
-                           user_message='User reason',
-                           expires_on=timezone.now() + timedelta(days=7))
+        Ban.objects.create(
+            check_type=BAN_IP,
+            banned_value='127.0.0.1',
+            user_message='User reason',
+            expires_on=timezone.now() + timedelta(days=7)
+        )
 
 
         ip_ban = get_request_ip_ban(FakeRequest())
         ip_ban = get_request_ip_ban(FakeRequest())
         self.assertTrue(ip_ban['is_banned'])
         self.assertTrue(ip_ban['is_banned'])
@@ -178,10 +205,12 @@ class RequestIPBansTests(TestCase):
 
 
     def test_expired_ban(self):
     def test_expired_ban(self):
         """ip is not caught by expired ban"""
         """ip is not caught by expired ban"""
-        Ban.objects.create(check_type=BAN_IP,
-                           banned_value='127.0.0.1',
-                           user_message='User reason',
-                           expires_on=timezone.now() - timedelta(days=7))
+        Ban.objects.create(
+            check_type=BAN_IP,
+            banned_value='127.0.0.1',
+            user_message='User reason',
+            expires_on=timezone.now() - timedelta(days=7)
+        )
 
 
         ip_ban = get_request_ip_ban(FakeRequest())
         ip_ban = get_request_ip_ban(FakeRequest())
         self.assertIsNone(ip_ban)
         self.assertIsNone(ip_ban)
@@ -192,7 +221,7 @@ class RequestIPBansTests(TestCase):
 
 
 class BanUserTests(TestCase):
 class BanUserTests(TestCase):
     def test_ban_user(self):
     def test_ban_user(self):
-        """ban_user bans user"""
+        """ban_user utility bans user"""
         User = get_user_model()
         User = get_user_model()
         user = User.objects.create_user('Bob', 'bob@boberson.com', 'pass123')
         user = User.objects.create_user('Bob', 'bob@boberson.com', 'pass123')
 
 
@@ -206,7 +235,7 @@ class BanUserTests(TestCase):
 
 
 class BanIpTests(TestCase):
 class BanIpTests(TestCase):
     def test_ban_ip(self):
     def test_ban_ip(self):
-        """ban_ip bans IP address"""
+        """ban_ip utility bans IP address"""
         ban = ban_ip('127.0.0.1', 'User reason', 'Staff reason')
         ban = ban_ip('127.0.0.1', 'User reason', 'Staff reason')
         self.assertEqual(ban.user_message, 'User reason')
         self.assertEqual(ban.user_message, 'User reason')
         self.assertEqual(ban.staff_message, 'Staff reason')
         self.assertEqual(ban.staff_message, 'Staff reason')