From b486aaadeb362f43ff437bb17397295a4ea6d87d Mon Sep 17 00:00:00 2001 From: blag Date: Sun, 15 Oct 2023 02:06:09 -0700 Subject: [PATCH 1/4] Use Django's code to simplify ours --- user_sessions/backends/db.py | 99 +++++++++--------------------------- user_sessions/middleware.py | 61 ++-------------------- user_sessions/models.py | 6 +-- 3 files changed, 32 insertions(+), 134 deletions(-) diff --git a/user_sessions/backends/db.py b/user_sessions/backends/db.py index 9cfa244..2a5e63f 100644 --- a/user_sessions/backends/db.py +++ b/user_sessions/backends/db.py @@ -1,14 +1,8 @@ -import logging - from django.contrib import auth -from django.contrib.sessions.backends.base import CreateError, SessionBase -from django.core.exceptions import SuspiciousOperation -from django.db import IntegrityError, router, transaction -from django.utils import timezone -from django.utils.encoding import force_str +from django.contrib.sessions.backends.db import SessionStore as DBStore -class SessionStore(SessionBase): +class SessionStore(DBStore): """ Implements database session store. """ @@ -19,89 +13,46 @@ def __init__(self, session_key=None, user_agent=None, ip=None): self.ip = ip self.user_id = None + @classmethod + def get_model_class(cls): + # Avoids a circular import and allows importing SessionStore when + # user_sessions is not in INSTALLED_APPS + from ..models import Session + + return Session + def __setitem__(self, key, value): if key == auth.SESSION_KEY: self.user_id = value super().__setitem__(key, value) - def load(self): - try: - s = Session.objects.get( - session_key=self.session_key, - expire_date__gt=timezone.now() - ) - self.user_id = s.user_id - # do not overwrite user_agent/ip, as those might have been updated - if self.user_agent != s.user_agent or self.ip != s.ip: - self.modified = True - return self.decode(s.session_data) - except (Session.DoesNotExist, SuspiciousOperation) as e: - if isinstance(e, SuspiciousOperation): - logger = logging.getLogger('django.security.%s' % - e.__class__.__name__) - logger.warning(force_str(e)) - self.create() - return {} - - def exists(self, session_key): - return Session.objects.filter(session_key=session_key).exists() + def _get_session_from_db(self): + s = super()._get_session_from_db() + self.user_id = s.user_id + # do not overwrite user_agent/ip, as those might have been updated + if self.user_agent != s.user_agent or self.ip != s.ip: + self.modified = True + return s def create(self): - while True: - self._session_key = self._get_new_session_key() - try: - # Save immediately to ensure we have a unique entry in the - # database. - self.save(must_create=True) - except CreateError: - # Key wasn't unique. Try again. - continue - self.modified = True - self._session_cache = {} - return + super().create() + self._session_cache = {} - def save(self, must_create=False): + def create_model_instance(self, data): """ - Saves the current session data to the database. If 'must_create' is - True, a database error will be raised if the saving operation doesn't - create a *new* entry (as opposed to possibly updating an existing - entry). + Return a new instance of the session model object, which represents the + current session state. Intended to be used for saving the session data + to the database. """ - obj = Session( + return self.model( session_key=self._get_or_create_session_key(), - session_data=self.encode(self._get_session(no_load=must_create)), + session_data=self.encode(data), expire_date=self.get_expiry_date(), user_agent=self.user_agent, user_id=self.user_id, ip=self.ip, ) - using = router.db_for_write(Session, instance=obj) - try: - with transaction.atomic(using): - obj.save(force_insert=must_create, using=using) - except IntegrityError as e: - if must_create and 'session_key' in str(e): - raise CreateError - raise def clear(self): super().clear() self.user_id = None - - def delete(self, session_key=None): - if session_key is None: - if self.session_key is None: - return - session_key = self.session_key - try: - Session.objects.get(session_key=session_key).delete() - except Session.DoesNotExist: - pass - - @classmethod - def clear_expired(cls): - Session.objects.filter(expire_date__lt=timezone.now()).delete() - - -# At bottom to avoid circular import -from ..models import Session # noqa: E402 isort:skip diff --git a/user_sessions/middleware.py b/user_sessions/middleware.py index 7df3b16..823c15b 100644 --- a/user_sessions/middleware.py +++ b/user_sessions/middleware.py @@ -1,68 +1,17 @@ -import time - from django.conf import settings -from django.utils.cache import patch_vary_headers -from django.utils.http import http_date - -try: - from importlib import import_module -except ImportError: - from django.utils.importlib import import_module +from django.contrib.sessions.middleware import ( + SessionMiddleware as DjangoSessionMiddleware, +) -try: - from django.utils.deprecation import MiddlewareMixin -except ImportError: - class MiddlewareMixin: - pass - -class SessionMiddleware(MiddlewareMixin): +class SessionMiddleware(DjangoSessionMiddleware): """ Middleware that provides ip and user_agent to the session store. """ def process_request(self, request): - engine = import_module(settings.SESSION_ENGINE) session_key = request.COOKIES.get(settings.SESSION_COOKIE_NAME, None) - request.session = engine.SessionStore( + request.session = self.SessionStore( ip=request.META.get('REMOTE_ADDR', ''), user_agent=request.META.get('HTTP_USER_AGENT', ''), session_key=session_key ) - - def process_response(self, request, response): - """ - If request.session was modified, or if the configuration is to save the - session every time, save the changes and set a session cookie. - """ - try: - accessed = request.session.accessed - modified = request.session.modified - except AttributeError: - pass - else: - if accessed: - patch_vary_headers(response, ('Cookie',)) - if modified or settings.SESSION_SAVE_EVERY_REQUEST: - if request.session.get_expire_at_browser_close(): - max_age = None - expires = None - else: - max_age = request.session.get_expiry_age() - expires_time = time.time() + max_age - expires = http_date(expires_time) - # Save the session data and refresh the client cookie. - # Skip session save for 500 responses, refs #3881. - if response.status_code != 500: - request.session.save() - response.set_cookie( - settings.SESSION_COOKIE_NAME, - request.session.session_key, - max_age=max_age, - expires=expires, - domain=settings.SESSION_COOKIE_DOMAIN, - path=settings.SESSION_COOKIE_PATH, - secure=settings.SESSION_COOKIE_SECURE or None, - httponly=settings.SESSION_COOKIE_HTTPONLY or None, - samesite=settings.SESSION_COOKIE_SAMESITE, - ) - return response diff --git a/user_sessions/models.py b/user_sessions/models.py index 32cb3f4..0f16077 100644 --- a/user_sessions/models.py +++ b/user_sessions/models.py @@ -2,6 +2,8 @@ from django.db import models from django.utils.translation import gettext_lazy as _ +from .backends.db import SessionStore + class SessionManager(models.Manager): use_in_migrations = True @@ -52,7 +54,3 @@ def get_decoded(self): user_agent = models.CharField(null=True, blank=True, max_length=200) last_activity = models.DateTimeField(auto_now=True) ip = models.GenericIPAddressField(null=True, blank=True, verbose_name='IP') - - -# At bottom to avoid circular import -from .backends.db import SessionStore # noqa: E402 isort:skip From 45276008984667fd9b0265a61bf3e73d42672cda Mon Sep 17 00:00:00 2001 From: blag Date: Sun, 15 Oct 2023 02:07:12 -0700 Subject: [PATCH 2/4] Simplify URLs --- user_sessions/urls.py | 10 ++++------ 1 file changed, 4 insertions(+), 6 deletions(-) diff --git a/user_sessions/urls.py b/user_sessions/urls.py index e824910..5536b73 100644 --- a/user_sessions/urls.py +++ b/user_sessions/urls.py @@ -1,8 +1,6 @@ -from django.urls import path, re_path +from django.urls import path -from user_sessions.views import SessionDeleteOtherView - -from .views import SessionDeleteView, SessionListView +from .views import SessionDeleteOtherView, SessionDeleteView, SessionListView app_name = 'user_sessions' urlpatterns = [ @@ -16,8 +14,8 @@ view=SessionDeleteOtherView.as_view(), name='session_delete_other', ), - re_path( - r'^account/sessions/(?P\w+)/delete/$', + path( + 'account/sessions//delete/', view=SessionDeleteView.as_view(), name='session_delete', ), From 20839a0060b730c98fc8e81b4b7cf4df97d4b591 Mon Sep 17 00:00:00 2001 From: blag Date: Sun, 15 Oct 2023 02:08:19 -0700 Subject: [PATCH 3/4] Group database columns together in code --- user_sessions/models.py | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/user_sessions/models.py b/user_sessions/models.py index 0f16077..9d98628 100644 --- a/user_sessions/models.py +++ b/user_sessions/models.py @@ -40,6 +40,12 @@ class Session(models.Model): primary_key=True) session_data = models.TextField(_('session data')) expire_date = models.DateTimeField(_('expiry date'), db_index=True) + user = models.ForeignKey(getattr(settings, 'AUTH_USER_MODEL', 'auth.User'), + null=True, on_delete=models.CASCADE) + user_agent = models.CharField(null=True, blank=True, max_length=200) + last_activity = models.DateTimeField(auto_now=True) + ip = models.GenericIPAddressField(null=True, blank=True, verbose_name='IP') + objects = SessionManager() class Meta: @@ -48,9 +54,3 @@ class Meta: def get_decoded(self): return SessionStore(None, None).decode(self.session_data) - - user = models.ForeignKey(getattr(settings, 'AUTH_USER_MODEL', 'auth.User'), - null=True, on_delete=models.CASCADE) - user_agent = models.CharField(null=True, blank=True, max_length=200) - last_activity = models.DateTimeField(auto_now=True) - ip = models.GenericIPAddressField(null=True, blank=True, verbose_name='IP') From ef28ea0425c3ecd6b2c3f028603aa903ef7e9935 Mon Sep 17 00:00:00 2001 From: blag Date: Sun, 15 Oct 2023 12:28:40 -0700 Subject: [PATCH 4/4] Add some comments to give a sense where methods are used by superclass --- user_sessions/backends/db.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/user_sessions/backends/db.py b/user_sessions/backends/db.py index 2a5e63f..3eb0315 100644 --- a/user_sessions/backends/db.py +++ b/user_sessions/backends/db.py @@ -13,6 +13,7 @@ def __init__(self, session_key=None, user_agent=None, ip=None): self.ip = ip self.user_id = None + # Used by superclass to get self.model, which is used elsewhere @classmethod def get_model_class(cls): # Avoids a circular import and allows importing SessionStore when @@ -26,6 +27,7 @@ def __setitem__(self, key, value): self.user_id = value super().__setitem__(key, value) + # Used in DBStore.load() def _get_session_from_db(self): s = super()._get_session_from_db() self.user_id = s.user_id @@ -38,6 +40,7 @@ def create(self): super().create() self._session_cache = {} + # Used in DBStore.save() def create_model_instance(self, data): """ Return a new instance of the session model object, which represents the