Browse Source

Enable stacking patch_user_acl helper

rafalp 6 years ago
parent
commit
37bf66be88
2 changed files with 34 additions and 30 deletions
  1. 22 22
      misago/acl/test.py
  2. 12 8
      misago/acl/tests/test_patching_user_acl.py

+ 22 - 22
misago/acl/test.py

@@ -1,4 +1,4 @@
-from contextlib import ExitStack
+from contextlib import ContextDecorator, ExitStack, contextmanager
 from functools import wraps
 from unittest.mock import patch
 
@@ -7,7 +7,7 @@ from .useracl import get_user_acl
 __all__ = ["patch_user_acl"]
 
 
-class patch_user_acl(ExitStack):
+class patch_user_acl(ContextDecorator, ExitStack):
     """Testing utility that patches get_user_acl results
 
     Can be used as decorator or context manager.
@@ -19,9 +19,11 @@ class patch_user_acl(ExitStack):
     Patch should be a dict or callable.
     """
 
-    def __init__(self, *patches):
+    _acl_patches = []
+
+    def __init__(self, acl_patch):
         super().__init__()
-        self._patches = patches
+        self.acl_patch = acl_patch
 
     def patched_get_user_acl(self, user, cache_versions):
         user_acl = get_user_acl(user, cache_versions)
@@ -29,7 +31,7 @@ class patch_user_acl(ExitStack):
         return user_acl
 
     def apply_acl_patches(self, user, user_acl):
-        for acl_patch in self._patches:
+        for acl_patch in self._acl_patches:
             self.apply_acl_patch(user, user_acl, acl_patch)
 
     def apply_acl_patch(self, user, user_acl, acl_patch):
@@ -40,21 +42,19 @@ class patch_user_acl(ExitStack):
 
     def __enter__(self):
         super().__enter__()
-        self.enter_context(
-            patch(
-                "misago.acl.useracl.get_user_acl",
-                side_effect=self.patched_get_user_acl,
-            )
+        self.enter_context(self.enable_acl_patch())
+        self.enter_context(self.patch_user_acl())
+
+    @contextmanager
+    def enable_acl_patch(self):
+        try:
+            self._acl_patches.append(self.acl_patch)
+            yield
+        finally:
+            self._acl_patches.pop(-1)
+
+    def patch_user_acl(self):
+        return patch(
+            "misago.acl.useracl.get_user_acl",
+            side_effect=self.patched_get_user_acl,
         )
-
-    def __call__(self, f):
-        @wraps(f)
-        def inner(*args, **kwargs):
-            with self:
-                with patch(
-                    "misago.acl.useracl.get_user_acl",
-                    side_effect=self.patched_get_user_acl,
-                ):
-                    return f(*args, **kwargs)
-        
-        return inner

+ 12 - 8
misago/acl/tests/test_patching_user_acl.py

@@ -59,16 +59,20 @@ class PatchingUserACLTests(TestCase):
             user_acl = useracl.get_user_acl(user, cache_versions)
             assert user_acl["patched_for_user_id"] == user.id
 
-    @patch_user_acl(callable_acl_patch, {"other_acl_path": True})
-    def test_multiple_acl_patches_are_applied_by_decorator(self):
+    @patch_user_acl({"acl_patch": 1})
+    @patch_user_acl({"acl_patch": 2})
+    def test_multiple_acl_patches_applied_by_decorator_stack(self):
         user = User.objects.create_user("User", "user@example.com")
         user_acl = useracl.get_user_acl(user, cache_versions)
-        assert user_acl["patched_for_user_id"] == user.id
-        assert user_acl["other_acl_path"]
+        assert user_acl["acl_patch"] == 2
 
-    def test_multiple_acl_patches_are_applied_by_context_manager(self):
+    def test_multiple_acl_patches_applied_by_context_manager_stack(self):
         user = User.objects.create_user("User", "user@example.com")
-        with patch_user_acl(callable_acl_patch, {"other_acl_path": True}):
+        with patch_user_acl({"acl_patch": 1}):
+            with patch_user_acl({"acl_patch": 2}):
+                user_acl = useracl.get_user_acl(user, cache_versions)
+                assert user_acl["acl_patch"] == 2
             user_acl = useracl.get_user_acl(user, cache_versions)
-            assert user_acl["patched_for_user_id"] == user.id
-            assert user_acl["other_acl_path"]
+            assert user_acl["acl_patch"] == 1
+        user_acl = useracl.get_user_acl(user, cache_versions)
+        assert "acl_patch" not in user_acl