test_bans.py 8.5 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263
  1. from datetime import timedelta
  2. from django.test import TestCase
  3. from django.utils import timezone
  4. from ...conftest import get_cache_versions
  5. from ..bans import (
  6. ban_ip,
  7. ban_user,
  8. get_email_ban,
  9. get_ip_ban,
  10. get_request_ip_ban,
  11. get_user_ban,
  12. get_username_ban,
  13. )
  14. from ..models import Ban
  15. from ..test import create_test_user
  # Shared cache versions used by the tests below: passed as the second
  # argument to get_user_ban and stored on MockRequest.cache_versions.
  16. cache_versions = get_cache_versions()
  17. class GetBanTests(TestCase):
  18. def test_get_username_ban(self):
  19. """get_username_ban returns valid ban"""
  20. nonexistent_ban = get_username_ban("nonexistent")
  21. self.assertIsNone(nonexistent_ban)
  22. Ban.objects.create(
  23. banned_value="expired", expires_on=timezone.now() - timedelta(days=7)
  24. )
  25. expired_ban = get_username_ban("expired")
  26. self.assertIsNone(expired_ban)
  27. Ban.objects.create(banned_value="wrongtype", check_type=Ban.EMAIL)
  28. wrong_type_ban = get_username_ban("wrongtype")
  29. self.assertIsNone(wrong_type_ban)
  30. valid_ban = Ban.objects.create(
  31. banned_value="admi*", expires_on=timezone.now() + timedelta(days=7)
  32. )
  33. self.assertEqual(get_username_ban("admiral").pk, valid_ban.pk)
  34. registration_ban = Ban.objects.create(
  35. banned_value="mod*",
  36. expires_on=timezone.now() + timedelta(days=7),
  37. registration_only=True,
  38. )
  39. self.assertIsNone(get_username_ban("moderator"))
  40. self.assertEqual(get_username_ban("moderator", True).pk, registration_ban.pk)
  41. def test_get_email_ban(self):
  42. """get_email_ban returns valid ban"""
  43. nonexistent_ban = get_email_ban("non@existent.com")
  44. self.assertIsNone(nonexistent_ban)
  45. Ban.objects.create(
  46. banned_value="ex@pired.com",
  47. check_type=Ban.EMAIL,
  48. expires_on=timezone.now() - timedelta(days=7),
  49. )
  50. expired_ban = get_email_ban("ex@pired.com")
  51. self.assertIsNone(expired_ban)
  52. Ban.objects.create(banned_value="wrong@type.com", check_type=Ban.IP)
  53. wrong_type_ban = get_email_ban("wrong@type.com")
  54. self.assertIsNone(wrong_type_ban)
  55. valid_ban = Ban.objects.create(
  56. banned_value="*.ru",
  57. check_type=Ban.EMAIL,
  58. expires_on=timezone.now() + timedelta(days=7),
  59. )
  60. self.assertEqual(get_email_ban("banned@mail.ru").pk, valid_ban.pk)
  61. registration_ban = Ban.objects.create(
  62. banned_value="*.ua",
  63. check_type=Ban.EMAIL,
  64. expires_on=timezone.now() + timedelta(days=7),
  65. registration_only=True,
  66. )
  67. self.assertIsNone(get_email_ban("banned@mail.ua"))
  68. self.assertEqual(get_email_ban("banned@mail.ua", True).pk, registration_ban.pk)
  69. def test_get_ip_ban(self):
  70. """get_ip_ban returns valid ban"""
  71. nonexistent_ban = get_ip_ban("123.0.0.1")
  72. self.assertIsNone(nonexistent_ban)
  73. Ban.objects.create(
  74. banned_value="124.0.0.1",
  75. check_type=Ban.IP,
  76. expires_on=timezone.now() - timedelta(days=7),
  77. )
  78. expired_ban = get_ip_ban("124.0.0.1")
  79. self.assertIsNone(expired_ban)
  80. Ban.objects.create(banned_value="wrongtype", check_type=Ban.EMAIL)
  81. wrong_type_ban = get_ip_ban("wrongtype")
  82. self.assertIsNone(wrong_type_ban)
  83. valid_ban = Ban.objects.create(
  84. banned_value="125.0.0.*",
  85. check_type=Ban.IP,
  86. expires_on=timezone.now() + timedelta(days=7),
  87. )
  88. self.assertEqual(get_ip_ban("125.0.0.1").pk, valid_ban.pk)
  89. registration_ban = Ban.objects.create(
  90. banned_value="188.*",
  91. check_type=Ban.IP,
  92. expires_on=timezone.now() + timedelta(days=7),
  93. registration_only=True,
  94. )
  95. self.assertIsNone(get_ip_ban("188.12.12.41"))
  96. self.assertEqual(get_ip_ban("188.12.12.41", True).pk, registration_ban.pk)
  97. class UserBansTests(TestCase):
  98. def setUp(self):
  99. self.user = create_test_user("User", "user@example.com")
  100. def test_no_ban(self):
  101. """user is not caught by ban"""
  102. self.assertIsNone(get_user_ban(self.user, cache_versions))
  103. self.assertFalse(self.user.ban_cache.is_banned)
  104. def test_permanent_ban(self):
  105. """user is caught by permanent ban"""
  106. Ban.objects.create(
  107. banned_value="User",
  108. user_message="User reason",
  109. staff_message="Staff reason",
  110. )
  111. user_ban = get_user_ban(self.user, cache_versions)
  112. self.assertIsNotNone(user_ban)
  113. self.assertEqual(user_ban.user_message, "User reason")
  114. self.assertEqual(user_ban.staff_message, "Staff reason")
  115. self.assertTrue(self.user.ban_cache.is_banned)
  116. def test_temporary_ban(self):
  117. """user is caught by temporary ban"""
  118. Ban.objects.create(
  119. banned_value="us*",
  120. user_message="User reason",
  121. staff_message="Staff reason",
  122. expires_on=timezone.now() + timedelta(days=7),
  123. )
  124. user_ban = get_user_ban(self.user, cache_versions)
  125. self.assertIsNotNone(user_ban)
  126. self.assertEqual(user_ban.user_message, "User reason")
  127. self.assertEqual(user_ban.staff_message, "Staff reason")
  128. self.assertTrue(self.user.ban_cache.is_banned)
  129. def test_expired_ban(self):
  130. """user is not caught by expired ban"""
  131. Ban.objects.create(
  132. banned_value="us*", expires_on=timezone.now() - timedelta(days=7)
  133. )
  134. self.assertIsNone(get_user_ban(self.user, cache_versions))
  135. self.assertFalse(self.user.ban_cache.is_banned)
  136. def test_expired_non_flagged_ban(self):
  137. """user is not caught by expired but checked ban"""
  138. Ban.objects.create(
  139. banned_value="us*", expires_on=timezone.now() - timedelta(days=7)
  140. )
  141. Ban.objects.update(is_checked=True)
  142. self.assertIsNone(get_user_ban(self.user, cache_versions))
  143. self.assertFalse(self.user.ban_cache.is_banned)
  144. class MockRequest:
  145. def __init__(self):
  146. self.user_ip = "127.0.0.1"
  147. self.session = {}
  148. self.cache_versions = cache_versions
  149. class RequestIPBansTests(TestCase):
  150. def test_no_ban(self):
  151. """no ban found"""
  152. ip_ban = get_request_ip_ban(MockRequest())
  153. self.assertIsNone(ip_ban)
  154. def test_permanent_ban(self):
  155. """ip is caught by permanent ban"""
  156. Ban.objects.create(
  157. check_type=Ban.IP, banned_value="127.0.0.1", user_message="User reason"
  158. )
  159. ip_ban = get_request_ip_ban(MockRequest())
  160. self.assertTrue(ip_ban["is_banned"])
  161. self.assertEqual(ip_ban["ip"], "127.0.0.1")
  162. self.assertEqual(ip_ban["message"], "User reason")
  163. # repeated call uses cache
  164. get_request_ip_ban(MockRequest())
  165. def test_temporary_ban(self):
  166. """ip is caught by temporary ban"""
  167. Ban.objects.create(
  168. check_type=Ban.IP,
  169. banned_value="127.0.0.1",
  170. user_message="User reason",
  171. expires_on=timezone.now() + timedelta(days=7),
  172. )
  173. ip_ban = get_request_ip_ban(MockRequest())
  174. self.assertTrue(ip_ban["is_banned"])
  175. self.assertEqual(ip_ban["ip"], "127.0.0.1")
  176. self.assertEqual(ip_ban["message"], "User reason")
  177. # repeated call uses cache
  178. get_request_ip_ban(MockRequest())
  179. def test_expired_ban(self):
  180. """ip is not caught by expired ban"""
  181. Ban.objects.create(
  182. check_type=Ban.IP,
  183. banned_value="127.0.0.1",
  184. user_message="User reason",
  185. expires_on=timezone.now() - timedelta(days=7),
  186. )
  187. ip_ban = get_request_ip_ban(MockRequest())
  188. self.assertIsNone(ip_ban)
  189. # repeated call uses cache
  190. get_request_ip_ban(MockRequest())
  191. class BanUserTests(TestCase):
  192. def test_ban_user(self):
  193. """ban_user utility bans user"""
  194. user = create_test_user("User", "user@example.com")
  195. ban = ban_user(user, "User reason", "Staff reason")
  196. self.assertEqual(ban.user_message, "User reason")
  197. self.assertEqual(ban.staff_message, "Staff reason")
  198. db_ban = get_user_ban(user, cache_versions)
  199. self.assertEqual(ban.pk, db_ban.ban_id)
  200. class BanIpTests(TestCase):
  201. def test_ban_ip(self):
  202. """ban_ip utility bans IP address"""
  203. ban = ban_ip("127.0.0.1", "User reason", "Staff reason")
  204. self.assertEqual(ban.user_message, "User reason")
  205. self.assertEqual(ban.staff_message, "Staff reason")
  206. db_ban = get_ip_ban("127.0.0.1")
  207. self.assertEqual(ban.pk, db_ban.pk)