Code factorisation in models.py
This commit is contained in:
		@@ -28,13 +28,38 @@ from datetime import timedelta
 | 
			
		||||
from concurrent.futures import ThreadPoolExecutor
 | 
			
		||||
from requests_futures.sessions import FuturesSession
 | 
			
		||||
 | 
			
		||||
import cas_server.utils as utils
 | 
			
		||||
from cas_server import utils
 | 
			
		||||
from . import VERSION
 | 
			
		||||
 | 
			
		||||
#: logger facility
 | 
			
		||||
logger = logging.getLogger(__name__)
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class JsonAttributes(models.Model):
 | 
			
		||||
    """
 | 
			
		||||
        Bases: :class:`django.db.models.Model`
 | 
			
		||||
 | 
			
		||||
        A base class for models storing attributes as a json
 | 
			
		||||
    """
 | 
			
		||||
 | 
			
		||||
    class Meta:
 | 
			
		||||
        abstract = True
 | 
			
		||||
 | 
			
		||||
    #: The attributes json encoded
 | 
			
		||||
    _attributs = models.TextField(default=None, null=True, blank=True)
 | 
			
		||||
 | 
			
		||||
    @property
 | 
			
		||||
    def attributs(self):
 | 
			
		||||
        """The attributes"""
 | 
			
		||||
        if self._attributs is not None:
 | 
			
		||||
            return utils.json.loads(self._attributs)
 | 
			
		||||
 | 
			
		||||
    @attributs.setter
 | 
			
		||||
    def attributs(self, value):
 | 
			
		||||
        """attributs property setter"""
 | 
			
		||||
        self._attributs = utils.json_encode(value)
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
@python_2_unicode_compatible
 | 
			
		||||
class FederatedIendityProvider(models.Model):
 | 
			
		||||
    """
 | 
			
		||||
@@ -130,9 +155,9 @@ class FederatedIendityProvider(models.Model):
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
@python_2_unicode_compatible
 | 
			
		||||
class FederatedUser(models.Model):
 | 
			
		||||
class FederatedUser(JsonAttributes):
 | 
			
		||||
    """
 | 
			
		||||
        Bases: :class:`django.db.models.Model`
 | 
			
		||||
        Bases: :class:`JsonAttributes`
 | 
			
		||||
 | 
			
		||||
        A federated user as returner by a CAS provider (username and attributes)
 | 
			
		||||
    """
 | 
			
		||||
@@ -142,8 +167,6 @@ class FederatedUser(models.Model):
 | 
			
		||||
    username = models.CharField(max_length=124)
 | 
			
		||||
    #: A foreign key to :class:`FederatedIendityProvider`
 | 
			
		||||
    provider = models.ForeignKey(FederatedIendityProvider, on_delete=models.CASCADE)
 | 
			
		||||
    #: The user attributes json encoded
 | 
			
		||||
    _attributs = models.TextField(default=None, null=True, blank=True)
 | 
			
		||||
    #: The last ticket used to authenticate :attr:`username` against :attr:`provider`
 | 
			
		||||
    ticket = models.CharField(max_length=255)
 | 
			
		||||
    #: Last update timespampt. Usually, the last time :attr:`ticket` has been set.
 | 
			
		||||
@@ -152,17 +175,6 @@ class FederatedUser(models.Model):
 | 
			
		||||
    def __str__(self):
 | 
			
		||||
        return self.federated_username
 | 
			
		||||
 | 
			
		||||
    @property
 | 
			
		||||
    def attributs(self):
 | 
			
		||||
        """The user attributes returned by the CAS backend on successful ticket validation"""
 | 
			
		||||
        if self._attributs is not None:
 | 
			
		||||
            return utils.json.loads(self._attributs)
 | 
			
		||||
 | 
			
		||||
    @attributs.setter
 | 
			
		||||
    def attributs(self, value):
 | 
			
		||||
        """attributs property setter"""
 | 
			
		||||
        self._attributs = utils.json_encode(value)
 | 
			
		||||
 | 
			
		||||
    @property
 | 
			
		||||
    def federated_username(self):
 | 
			
		||||
        """The federated username with a suffix for the current :class:`FederatedUser`."""
 | 
			
		||||
@@ -290,35 +302,23 @@ class User(models.Model):
 | 
			
		||||
            :param request: The current django HttpRequest to display possible failure to the user.
 | 
			
		||||
            :type request: :class:`django.http.HttpRequest` or :obj:`NoneType<types.NoneType>`
 | 
			
		||||
        """
 | 
			
		||||
        async_list = []
 | 
			
		||||
        session = FuturesSession(
 | 
			
		||||
            executor=ThreadPoolExecutor(max_workers=settings.CAS_SLO_MAX_PARALLEL_REQUESTS)
 | 
			
		||||
        )
 | 
			
		||||
        # first invalidate all Tickets
 | 
			
		||||
        ticket_classes = [ProxyGrantingTicket, ServiceTicket, ProxyTicket]
 | 
			
		||||
        for ticket_class in ticket_classes:
 | 
			
		||||
            queryset = ticket_class.objects.filter(user=self)
 | 
			
		||||
            for ticket in queryset:
 | 
			
		||||
                ticket.logout(session, async_list)
 | 
			
		||||
            queryset.delete()
 | 
			
		||||
        for future in async_list:
 | 
			
		||||
            if future:  # pragma: no branch (should always be true)
 | 
			
		||||
                try:
 | 
			
		||||
                    future.result()
 | 
			
		||||
                except Exception as error:
 | 
			
		||||
                    logger.warning(
 | 
			
		||||
                        "Error during SLO for user %s: %s" % (
 | 
			
		||||
                            self.username,
 | 
			
		||||
                            error
 | 
			
		||||
                        )
 | 
			
		||||
                    )
 | 
			
		||||
                    if request is not None:
 | 
			
		||||
                        error = utils.unpack_nested_exception(error)
 | 
			
		||||
                        messages.add_message(
 | 
			
		||||
                            request,
 | 
			
		||||
                            messages.WARNING,
 | 
			
		||||
                            _(u'Error during service logout %s') % error
 | 
			
		||||
                        )
 | 
			
		||||
        for error in Ticket.send_slos(
 | 
			
		||||
            [ticket_class.objects.filter(user=self) for ticket_class in ticket_classes]
 | 
			
		||||
        ):
 | 
			
		||||
            logger.warning(
 | 
			
		||||
                "Error during SLO for user %s: %s" % (
 | 
			
		||||
                    self.username,
 | 
			
		||||
                    error
 | 
			
		||||
                )
 | 
			
		||||
            )
 | 
			
		||||
            if request is not None:
 | 
			
		||||
                error = utils.unpack_nested_exception(error)
 | 
			
		||||
                messages.add_message(
 | 
			
		||||
                    request,
 | 
			
		||||
                    messages.WARNING,
 | 
			
		||||
                    _(u'Error during service logout %s') % error
 | 
			
		||||
                )
 | 
			
		||||
 | 
			
		||||
    def get_ticket(self, ticket_class, service, service_pattern, renew):
 | 
			
		||||
        """
 | 
			
		||||
@@ -544,20 +544,13 @@ class ServicePattern(models.Model):
 | 
			
		||||
                if re.match(filtre.pattern, str(value)):
 | 
			
		||||
                    break
 | 
			
		||||
            else:
 | 
			
		||||
                bad_filter = (filtre.pattern, filtre.attribut, user.attributs.get(filtre.attribut))
 | 
			
		||||
                logger.warning(
 | 
			
		||||
                    "User constraint failed for %s, service %s: %s do not match %s %s." % (
 | 
			
		||||
                        user.username,
 | 
			
		||||
                        self.name,
 | 
			
		||||
                        filtre.pattern,
 | 
			
		||||
                        filtre.attribut,
 | 
			
		||||
                        user.attributs.get(filtre.attribut)
 | 
			
		||||
                        (user.username, self.name) + bad_filter
 | 
			
		||||
                    )
 | 
			
		||||
                )
 | 
			
		||||
                raise BadFilter('%s do not match %s %s' % (
 | 
			
		||||
                    filtre.pattern,
 | 
			
		||||
                    filtre.attribut,
 | 
			
		||||
                    user.attributs.get(filtre.attribut)
 | 
			
		||||
                ))
 | 
			
		||||
                raise BadFilter('%s do not match %s %s' % bad_filter)
 | 
			
		||||
        if self.user_field and not user.attributs.get(self.user_field):
 | 
			
		||||
            logger.warning(
 | 
			
		||||
                "Cannot use %s a loggin for user %s on service %s because it is absent" % (
 | 
			
		||||
@@ -715,9 +708,9 @@ class ReplaceAttributValue(models.Model):
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
@python_2_unicode_compatible
 | 
			
		||||
class Ticket(models.Model):
 | 
			
		||||
class Ticket(JsonAttributes):
 | 
			
		||||
    """
 | 
			
		||||
        Bases: :class:`django.db.models.Model`
 | 
			
		||||
        Bases: :class:`JsonAttributes`
 | 
			
		||||
 | 
			
		||||
        Generic class for a Ticket
 | 
			
		||||
    """
 | 
			
		||||
@@ -725,8 +718,6 @@ class Ticket(models.Model):
 | 
			
		||||
        abstract = True
 | 
			
		||||
    #: ForeignKey to a :class:`User`.
 | 
			
		||||
    user = models.ForeignKey(User, related_name="%(class)s")
 | 
			
		||||
    #: The user attributes to transmit to the service json encoded
 | 
			
		||||
    _attributs = models.TextField(default=None, null=True, blank=True)
 | 
			
		||||
    #: A boolean. ``True`` if the ticket has been validated
 | 
			
		||||
    validate = models.BooleanField(default=False)
 | 
			
		||||
    #: The service url for the ticket
 | 
			
		||||
@@ -749,17 +740,6 @@ class Ticket(models.Model):
 | 
			
		||||
    #: requests.
 | 
			
		||||
    TIMEOUT = settings.CAS_TICKET_TIMEOUT
 | 
			
		||||
 | 
			
		||||
    @property
 | 
			
		||||
    def attributs(self):
 | 
			
		||||
        """The user attributes to be transmited to the service on successful validation"""
 | 
			
		||||
        if self._attributs is not None:
 | 
			
		||||
            return utils.json.loads(self._attributs)
 | 
			
		||||
 | 
			
		||||
    @attributs.setter
 | 
			
		||||
    def attributs(self, value):
 | 
			
		||||
        """attributs property setter"""
 | 
			
		||||
        self._attributs = utils.json_encode(value)
 | 
			
		||||
 | 
			
		||||
    class DoesNotExist(Exception):
 | 
			
		||||
        """raised in :meth:`Ticket.get` then ticket prefix and ticket classes mismatch"""
 | 
			
		||||
        pass
 | 
			
		||||
@@ -767,6 +747,33 @@ class Ticket(models.Model):
 | 
			
		||||
    def __str__(self):
 | 
			
		||||
        return u"Ticket-%s" % self.pk
 | 
			
		||||
 | 
			
		||||
    @staticmethod
 | 
			
		||||
    def send_slos(queryset_list):
 | 
			
		||||
        """
 | 
			
		||||
            Send SLO requests to each ticket of each queryset of ``queryset_list``
 | 
			
		||||
 | 
			
		||||
            :param list queryset_list: A list a :class:`Ticket` queryset
 | 
			
		||||
            :return: A list of possibly encoutered :class:`Exception`
 | 
			
		||||
            :rtype: list
 | 
			
		||||
        """
 | 
			
		||||
        # sending SLO to timed-out validated tickets
 | 
			
		||||
        async_list = []
 | 
			
		||||
        session = FuturesSession(
 | 
			
		||||
            executor=ThreadPoolExecutor(max_workers=settings.CAS_SLO_MAX_PARALLEL_REQUESTS)
 | 
			
		||||
        )
 | 
			
		||||
        errors = []
 | 
			
		||||
        for queryset in queryset_list:
 | 
			
		||||
            for ticket in queryset:
 | 
			
		||||
                ticket.logout(session, async_list)
 | 
			
		||||
            queryset.delete()
 | 
			
		||||
        for future in async_list:
 | 
			
		||||
            if future:  # pragma: no branch (should always be true)
 | 
			
		||||
                try:
 | 
			
		||||
                    future.result()
 | 
			
		||||
                except Exception as error:
 | 
			
		||||
                    errors.append(error)
 | 
			
		||||
        return errors
 | 
			
		||||
 | 
			
		||||
    @classmethod
 | 
			
		||||
    def clean_old_entries(cls):
 | 
			
		||||
        """Remove old ticket and send SLO to timed-out services"""
 | 
			
		||||
@@ -779,25 +786,12 @@ class Ticket(models.Model):
 | 
			
		||||
                Q(creation__lt=(timezone.now() - timedelta(seconds=cls.VALIDITY)))
 | 
			
		||||
            )
 | 
			
		||||
        ).delete()
 | 
			
		||||
 | 
			
		||||
        # sending SLO to timed-out validated tickets
 | 
			
		||||
        async_list = []
 | 
			
		||||
        session = FuturesSession(
 | 
			
		||||
            executor=ThreadPoolExecutor(max_workers=settings.CAS_SLO_MAX_PARALLEL_REQUESTS)
 | 
			
		||||
        )
 | 
			
		||||
        queryset = cls.objects.filter(
 | 
			
		||||
            creation__lt=(timezone.now() - timedelta(seconds=cls.TIMEOUT))
 | 
			
		||||
        )
 | 
			
		||||
        for ticket in queryset:
 | 
			
		||||
            ticket.logout(session, async_list)
 | 
			
		||||
        queryset.delete()
 | 
			
		||||
        for future in async_list:
 | 
			
		||||
            if future:  # pragma: no branch (should always be true)
 | 
			
		||||
                try:
 | 
			
		||||
                    future.result()
 | 
			
		||||
                except Exception as error:
 | 
			
		||||
                    logger.warning("Error durring SLO %s" % error)
 | 
			
		||||
                    sys.stderr.write("%r\n" % error)
 | 
			
		||||
        for error in cls.send_slos([queryset]):
 | 
			
		||||
            logger.warning("Error durring SLO %s" % error)
 | 
			
		||||
            sys.stderr.write("%r\n" % error)
 | 
			
		||||
 | 
			
		||||
    def logout(self, session, async_list=None):
 | 
			
		||||
        """Send a SLO request to the ticket service"""
 | 
			
		||||
@@ -811,16 +805,7 @@ class Ticket(models.Model):
 | 
			
		||||
                    self.user.username
 | 
			
		||||
                )
 | 
			
		||||
            )
 | 
			
		||||
            xml = u"""<samlp:LogoutRequest xmlns:samlp="urn:oasis:names:tc:SAML:2.0:protocol"
 | 
			
		||||
 ID="%(id)s" Version="2.0" IssueInstant="%(datetime)s">
 | 
			
		||||
<saml:NameID xmlns:saml="urn:oasis:names:tc:SAML:2.0:assertion"></saml:NameID>
 | 
			
		||||
<samlp:SessionIndex>%(ticket)s</samlp:SessionIndex>
 | 
			
		||||
</samlp:LogoutRequest>""" % \
 | 
			
		||||
                {
 | 
			
		||||
                    'id': utils.gen_saml_id(),
 | 
			
		||||
                    'datetime': timezone.now().isoformat(),
 | 
			
		||||
                    'ticket':  self.value
 | 
			
		||||
                }
 | 
			
		||||
            xml = utils.logout_request(self.value)
 | 
			
		||||
            if self.service_pattern.single_log_out_callback:
 | 
			
		||||
                url = self.service_pattern.single_log_out_callback
 | 
			
		||||
            else:
 | 
			
		||||
 
 | 
			
		||||
@@ -261,7 +261,7 @@ class FederateAuthLoginLogoutTestCase(
 | 
			
		||||
            # SLO for an unkown ticket should do nothing
 | 
			
		||||
            response = client.post(
 | 
			
		||||
                "/federate/%s" % provider.suffix,
 | 
			
		||||
                {'logoutRequest': tests_utils.logout_request(utils.gen_st())}
 | 
			
		||||
                {'logoutRequest': utils.logout_request(utils.gen_st())}
 | 
			
		||||
            )
 | 
			
		||||
            self.assertEqual(response.status_code, 200)
 | 
			
		||||
            self.assertEqual(response.content, b"ok")
 | 
			
		||||
@@ -288,7 +288,7 @@ class FederateAuthLoginLogoutTestCase(
 | 
			
		||||
            # 3 or 'CAS_2_SAML_1_0'
 | 
			
		||||
            response = client.post(
 | 
			
		||||
                "/federate/%s" % provider.suffix,
 | 
			
		||||
                {'logoutRequest': tests_utils.logout_request(ticket)}
 | 
			
		||||
                {'logoutRequest': utils.logout_request(ticket)}
 | 
			
		||||
            )
 | 
			
		||||
            self.assertEqual(response.status_code, 200)
 | 
			
		||||
            self.assertEqual(response.content, b"ok")
 | 
			
		||||
 
 | 
			
		||||
@@ -340,17 +340,3 @@ class DummyCAS(BaseHTTPServer.BaseHTTPRequestHandler):
 | 
			
		||||
        httpd_thread.daemon = True
 | 
			
		||||
        httpd_thread.start()
 | 
			
		||||
        return (httpd, host, port)
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def logout_request(ticket):
 | 
			
		||||
    """build a SLO request XML, ready to be send"""
 | 
			
		||||
    return u"""<samlp:LogoutRequest xmlns:samlp="urn:oasis:names:tc:SAML:2.0:protocol"
 | 
			
		||||
 ID="%(id)s" Version="2.0" IssueInstant="%(datetime)s">
 | 
			
		||||
<saml:NameID xmlns:saml="urn:oasis:names:tc:SAML:2.0:assertion"></saml:NameID>
 | 
			
		||||
<samlp:SessionIndex>%(ticket)s</samlp:SessionIndex>
 | 
			
		||||
</samlp:LogoutRequest>""" % \
 | 
			
		||||
        {
 | 
			
		||||
            'id': utils.gen_saml_id(),
 | 
			
		||||
            'datetime': timezone.now().isoformat(),
 | 
			
		||||
            'ticket':  ticket
 | 
			
		||||
        }
 | 
			
		||||
 
 | 
			
		||||
@@ -17,6 +17,7 @@ from django.http import HttpResponseRedirect, HttpResponse
 | 
			
		||||
from django.contrib import messages
 | 
			
		||||
from django.contrib.messages import constants as DEFAULT_MESSAGE_LEVELS
 | 
			
		||||
from django.core.serializers.json import DjangoJSONEncoder
 | 
			
		||||
from django.utils import timezone
 | 
			
		||||
 | 
			
		||||
import random
 | 
			
		||||
import string
 | 
			
		||||
@@ -680,3 +681,22 @@ def dictfetchall(cursor):
 | 
			
		||||
        dict(zip(columns, row))
 | 
			
		||||
        for row in cursor.fetchall()
 | 
			
		||||
    ]
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def logout_request(ticket):
 | 
			
		||||
    """
 | 
			
		||||
        Forge a SLO logout request
 | 
			
		||||
 | 
			
		||||
        :param unicode ticket: A ticket value
 | 
			
		||||
        :return: A SLO XML body request
 | 
			
		||||
        :rtype: unicode
 | 
			
		||||
    """
 | 
			
		||||
    return u"""<samlp:LogoutRequest xmlns:samlp="urn:oasis:names:tc:SAML:2.0:protocol"
 | 
			
		||||
 ID="%(id)s" Version="2.0" IssueInstant="%(datetime)s">
 | 
			
		||||
<saml:NameID xmlns:saml="urn:oasis:names:tc:SAML:2.0:assertion"></saml:NameID>
 | 
			
		||||
<samlp:SessionIndex>%(ticket)s</samlp:SessionIndex>
 | 
			
		||||
</samlp:LogoutRequest>""" % {
 | 
			
		||||
        'id': gen_saml_id(),
 | 
			
		||||
        'datetime': timezone.now().isoformat(),
 | 
			
		||||
        'ticket':  ticket
 | 
			
		||||
    }
 | 
			
		||||
 
 | 
			
		||||
		Reference in New Issue
	
	Block a user