Add SLO support from federated CAS
This commit is contained in:
		@@ -12,7 +12,11 @@
 | 
				
			|||||||
from .default_settings import settings
 | 
					from .default_settings import settings
 | 
				
			||||||
 | 
					
 | 
				
			||||||
from .cas import CASClient
 | 
					from .cas import CASClient
 | 
				
			||||||
from .models import FederatedUser
 | 
					from .models import FederatedUser, FederateSLO, User
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					from importlib import import_module
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					SessionStore = import_module(settings.SESSION_ENGINE).SessionStore
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
class CASFederateValidateUser(object):
 | 
					class CASFederateValidateUser(object):
 | 
				
			||||||
@@ -68,3 +72,33 @@ class CASFederateValidateUser(object):
 | 
				
			|||||||
            return True
 | 
					            return True
 | 
				
			||||||
        else:
 | 
					        else:
 | 
				
			||||||
            return False
 | 
					            return False
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    def register_slo(self, username, session_key, ticket):
 | 
				
			||||||
 | 
					        FederateSLO.objects.create(
 | 
				
			||||||
 | 
					            username=username,
 | 
				
			||||||
 | 
					            session_key=session_key,
 | 
				
			||||||
 | 
					            ticket=ticket
 | 
				
			||||||
 | 
					        )
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    def clean_sessions(self, logout_request):
 | 
				
			||||||
 | 
					        try:
 | 
				
			||||||
 | 
					            SLOs = self.client.get_saml_slos(logout_request)
 | 
				
			||||||
 | 
					        except NameError:
 | 
				
			||||||
 | 
					            SLOs = []
 | 
				
			||||||
 | 
					        for slo in SLOs:
 | 
				
			||||||
 | 
					            try:
 | 
				
			||||||
 | 
					                for federate_slo in FederateSLO.objects.filter(ticket=slo.text):
 | 
				
			||||||
 | 
					                    session = SessionStore(session_key=federate_slo.session_key)
 | 
				
			||||||
 | 
					                    session.flush()
 | 
				
			||||||
 | 
					                    try:
 | 
				
			||||||
 | 
					                        user = User.objects.get(
 | 
				
			||||||
 | 
					                            username=federate_slo.username,
 | 
				
			||||||
 | 
					                            session_key=federate_slo.session_key
 | 
				
			||||||
 | 
					                        )
 | 
				
			||||||
 | 
					                        user.logout()
 | 
				
			||||||
 | 
					                        user.delete()
 | 
				
			||||||
 | 
					                    except User.DoesNotExist:
 | 
				
			||||||
 | 
					                        pass
 | 
				
			||||||
 | 
					                    federate_slo.delete()
 | 
				
			||||||
 | 
					            except FederateSLO.DoesNotExist:
 | 
				
			||||||
 | 
					                pass
 | 
				
			||||||
 
 | 
				
			|||||||
@@ -16,6 +16,8 @@ class Command(BaseCommand):
 | 
				
			|||||||
        federated_users = models.FederatedUser.objects.filter(
 | 
					        federated_users = models.FederatedUser.objects.filter(
 | 
				
			||||||
            last_update__lt=(timezone.now() - timedelta(seconds=settings.CAS_TICKET_TIMEOUT))
 | 
					            last_update__lt=(timezone.now() - timedelta(seconds=settings.CAS_TICKET_TIMEOUT))
 | 
				
			||||||
        )
 | 
					        )
 | 
				
			||||||
 | 
					        known_users = {user.username for user in models.User.objects.all()}
 | 
				
			||||||
        for user in federated_users:
 | 
					        for user in federated_users:
 | 
				
			||||||
            if not models.User.objects.filter(username='%s@%s' % (user.username, user.provider)):
 | 
					            if not ('%s@%s' % (user.username, user.provider)) in known_users:
 | 
				
			||||||
                user.delete()
 | 
					                user.delete()
 | 
				
			||||||
 | 
					        models.FederateSLO.clean_deleted_sessions()
 | 
				
			||||||
 
 | 
				
			|||||||
							
								
								
									
										28
									
								
								cas_server/migrations/0006_auto_20160623_1516.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										28
									
								
								cas_server/migrations/0006_auto_20160623_1516.py
									
									
									
									
									
										Normal file
									
								
							@@ -0,0 +1,28 @@
 | 
				
			|||||||
 | 
					# -*- coding: utf-8 -*-
 | 
				
			||||||
 | 
					# Generated by Django 1.9.7 on 2016-06-23 15:16
 | 
				
			||||||
 | 
					from __future__ import unicode_literals
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					from django.db import migrations, models
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					class Migration(migrations.Migration):
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    dependencies = [
 | 
				
			||||||
 | 
					        ('cas_server', '0005_auto_20160616_1018'),
 | 
				
			||||||
 | 
					    ]
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    operations = [
 | 
				
			||||||
 | 
					        migrations.CreateModel(
 | 
				
			||||||
 | 
					            name='FederateSLO',
 | 
				
			||||||
 | 
					            fields=[
 | 
				
			||||||
 | 
					                ('id', models.AutoField(auto_created=True, primary_key=True, serialize=False, verbose_name='ID')),
 | 
				
			||||||
 | 
					                ('username', models.CharField(max_length=30)),
 | 
				
			||||||
 | 
					                ('session_key', models.CharField(blank=True, max_length=40, null=True)),
 | 
				
			||||||
 | 
					                ('ticket', models.CharField(max_length=255)),
 | 
				
			||||||
 | 
					            ],
 | 
				
			||||||
 | 
					        ),
 | 
				
			||||||
 | 
					        migrations.AlterUniqueTogether(
 | 
				
			||||||
 | 
					            name='federateslo',
 | 
				
			||||||
 | 
					            unique_together=set([('username', 'session_key')]),
 | 
				
			||||||
 | 
					        ),
 | 
				
			||||||
 | 
					    ]
 | 
				
			||||||
@@ -48,6 +48,25 @@ class FederatedUser(models.Model):
 | 
				
			|||||||
        return u"%s@%s" % (self.username, self.provider)
 | 
					        return u"%s@%s" % (self.username, self.provider)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					class FederateSLO(models.Model):
 | 
				
			||||||
 | 
					    class Meta:
 | 
				
			||||||
 | 
					        unique_together = ("username", "session_key")
 | 
				
			||||||
 | 
					    username = models.CharField(max_length=30)
 | 
				
			||||||
 | 
					    session_key = models.CharField(max_length=40, blank=True, null=True)
 | 
				
			||||||
 | 
					    ticket = models.CharField(max_length=255)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    @property
 | 
				
			||||||
 | 
					    def provider(self):
 | 
				
			||||||
 | 
					        component = self.username.split("@")
 | 
				
			||||||
 | 
					        return component[-1]
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    @classmethod
 | 
				
			||||||
 | 
					    def clean_deleted_sessions(cls):
 | 
				
			||||||
 | 
					        for federate_slo in cls.objects.all():
 | 
				
			||||||
 | 
					            if not SessionStore(session_key=federate_slo.session_key).get('authenticated'):
 | 
				
			||||||
 | 
					                federate_slo.delete()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
class User(models.Model):
 | 
					class User(models.Model):
 | 
				
			||||||
    """A user logged into the CAS"""
 | 
					    """A user logged into the CAS"""
 | 
				
			||||||
    class Meta:
 | 
					    class Meta:
 | 
				
			||||||
 
 | 
				
			|||||||
@@ -184,6 +184,8 @@ def gen_saml_id():
 | 
				
			|||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
def get_tuple(tuple, index, default=None):
 | 
					def get_tuple(tuple, index, default=None):
 | 
				
			||||||
 | 
					    if tuple is None:
 | 
				
			||||||
 | 
					        return default
 | 
				
			||||||
    try:
 | 
					    try:
 | 
				
			||||||
        return tuple[index]
 | 
					        return tuple[index]
 | 
				
			||||||
    except IndexError:
 | 
					    except IndexError:
 | 
				
			||||||
 
 | 
				
			|||||||
@@ -20,7 +20,7 @@ from django.utils.decorators import method_decorator
 | 
				
			|||||||
from django.utils.translation import ugettext as _
 | 
					from django.utils.translation import ugettext as _
 | 
				
			||||||
from django.utils import timezone
 | 
					from django.utils import timezone
 | 
				
			||||||
from django.views.decorators.csrf import csrf_exempt
 | 
					from django.views.decorators.csrf import csrf_exempt
 | 
				
			||||||
 | 
					from django.middleware.csrf import CsrfViewMiddleware
 | 
				
			||||||
from django.views.generic import View
 | 
					from django.views.generic import View
 | 
				
			||||||
 | 
					
 | 
				
			||||||
import logging
 | 
					import logging
 | 
				
			||||||
@@ -78,6 +78,11 @@ class LogoutMixin(object):
 | 
				
			|||||||
                username=username,
 | 
					                username=username,
 | 
				
			||||||
                session_key=self.request.session.session_key
 | 
					                session_key=self.request.session.session_key
 | 
				
			||||||
            )
 | 
					            )
 | 
				
			||||||
 | 
					            if settings.CAS_FEDERATE:
 | 
				
			||||||
 | 
					                models.FederateSLO.objects.filter(
 | 
				
			||||||
 | 
					                    username=username,
 | 
				
			||||||
 | 
					                    session_key=self.request.session.session_key
 | 
				
			||||||
 | 
					                ).delete()
 | 
				
			||||||
            self.request.session.flush()
 | 
					            self.request.session.flush()
 | 
				
			||||||
            user.logout(self.request)
 | 
					            user.logout(self.request)
 | 
				
			||||||
            user.delete()
 | 
					            user.delete()
 | 
				
			||||||
@@ -181,43 +186,73 @@ class LogoutView(View, LogoutMixin):
 | 
				
			|||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
class FederateAuth(View):
 | 
					class FederateAuth(View):
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    @method_decorator(csrf_exempt)
 | 
				
			||||||
 | 
					    def dispatch(self, request, *args, **kwargs):
 | 
				
			||||||
 | 
					        return super(FederateAuth, self).dispatch(request, *args, **kwargs)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    def get_cas_client(self, request, provider):
 | 
				
			||||||
 | 
					        if provider in settings.CAS_FEDERATE_PROVIDERS:
 | 
				
			||||||
 | 
					            service_url = utils.get_current_url(request, {"ticket", "provider"})
 | 
				
			||||||
 | 
					            return CASFederateValidateUser(provider, service_url)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    def post(self, request, provider=None):
 | 
					    def post(self, request, provider=None):
 | 
				
			||||||
        if not settings.CAS_FEDERATE:
 | 
					        if not settings.CAS_FEDERATE:
 | 
				
			||||||
            return redirect("cas_server:login")
 | 
					            return redirect("cas_server:login")
 | 
				
			||||||
        form = forms.FederateSelect(request.POST)
 | 
					        # POST with a provider, this is probably an SLO request
 | 
				
			||||||
        if form.is_valid():
 | 
					        if provider in settings.CAS_FEDERATE_PROVIDERS:
 | 
				
			||||||
            params = utils.copy_params(
 | 
					            auth = self.get_cas_client(request, provider)
 | 
				
			||||||
                request.POST,
 | 
					            try:
 | 
				
			||||||
                ignore={"provider", "csrfmiddlewaretoken", "ticket"}
 | 
					                auth.clean_sessions(request.POST['logoutRequest'])
 | 
				
			||||||
            )
 | 
					            except KeyError:
 | 
				
			||||||
            url = utils.reverse_params(
 | 
					                pass
 | 
				
			||||||
                "cas_server:federateAuth",
 | 
					            return HttpResponse("ok")
 | 
				
			||||||
                kwargs=dict(provider=form.cleaned_data["provider"]),
 | 
					        # else, a User is trying to log in using an identity provider
 | 
				
			||||||
                params=params
 | 
					 | 
				
			||||||
            )
 | 
					 | 
				
			||||||
            response = HttpResponseRedirect(url)
 | 
					 | 
				
			||||||
            if form.cleaned_data["remember"]:
 | 
					 | 
				
			||||||
                max_age = settings.CAS_FEDERATE_REMEMBER_TIMEOUT
 | 
					 | 
				
			||||||
                utils.set_cookie(response, "_remember_provider", request.POST["provider"], max_age)
 | 
					 | 
				
			||||||
            return response
 | 
					 | 
				
			||||||
        else:
 | 
					        else:
 | 
				
			||||||
            return redirect("cas_server:login")
 | 
					            # Manually checking for csrf to protect the code below
 | 
				
			||||||
 | 
					            reason = CsrfViewMiddleware().process_view(request, None, (), {})
 | 
				
			||||||
 | 
					            if reason is not None:
 | 
				
			||||||
 | 
					                return reason  # Failed the test, stop here.
 | 
				
			||||||
 | 
					            form = forms.FederateSelect(request.POST)
 | 
				
			||||||
 | 
					            if form.is_valid():
 | 
				
			||||||
 | 
					                params = utils.copy_params(
 | 
				
			||||||
 | 
					                    request.POST,
 | 
				
			||||||
 | 
					                    ignore={"provider", "csrfmiddlewaretoken", "ticket"}
 | 
				
			||||||
 | 
					                )
 | 
				
			||||||
 | 
					                url = utils.reverse_params(
 | 
				
			||||||
 | 
					                    "cas_server:federateAuth",
 | 
				
			||||||
 | 
					                    kwargs=dict(provider=form.cleaned_data["provider"]),
 | 
				
			||||||
 | 
					                    params=params
 | 
				
			||||||
 | 
					                )
 | 
				
			||||||
 | 
					                response = HttpResponseRedirect(url)
 | 
				
			||||||
 | 
					                if form.cleaned_data["remember"]:
 | 
				
			||||||
 | 
					                    max_age = settings.CAS_FEDERATE_REMEMBER_TIMEOUT
 | 
				
			||||||
 | 
					                    utils.set_cookie(
 | 
				
			||||||
 | 
					                        response,
 | 
				
			||||||
 | 
					                        "_remember_provider",
 | 
				
			||||||
 | 
					                        request.POST["provider"],
 | 
				
			||||||
 | 
					                        max_age
 | 
				
			||||||
 | 
					                    )
 | 
				
			||||||
 | 
					                return response
 | 
				
			||||||
 | 
					            else:
 | 
				
			||||||
 | 
					                return redirect("cas_server:login")
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    def get(self, request, provider=None):
 | 
					    def get(self, request, provider=None):
 | 
				
			||||||
        if not settings.CAS_FEDERATE:
 | 
					        if not settings.CAS_FEDERATE:
 | 
				
			||||||
            return redirect("cas_server:login")
 | 
					            return redirect("cas_server:login")
 | 
				
			||||||
        if provider not in settings.CAS_FEDERATE_PROVIDERS:
 | 
					        if provider not in settings.CAS_FEDERATE_PROVIDERS:
 | 
				
			||||||
            return redirect("cas_server:login")
 | 
					            return redirect("cas_server:login")
 | 
				
			||||||
        service_url = utils.get_current_url(request, {"ticket", "provider"})
 | 
					        auth = self.get_cas_client(request, provider)
 | 
				
			||||||
        auth = CASFederateValidateUser(provider, service_url)
 | 
					 | 
				
			||||||
        if 'ticket' not in request.GET:
 | 
					        if 'ticket' not in request.GET:
 | 
				
			||||||
            return HttpResponseRedirect(auth.get_login_url())
 | 
					            return HttpResponseRedirect(auth.get_login_url())
 | 
				
			||||||
        else:
 | 
					        else:
 | 
				
			||||||
            ticket = request.GET['ticket']
 | 
					            ticket = request.GET['ticket']
 | 
				
			||||||
            if auth.verify_ticket(ticket):
 | 
					            if auth.verify_ticket(ticket):
 | 
				
			||||||
                params = utils.copy_params(request.GET, ignore={"ticket"})
 | 
					                params = utils.copy_params(request.GET, ignore={"ticket"})
 | 
				
			||||||
                request.session["federate_username"] = "%s@%s" % (auth.username, auth.provider)
 | 
					                username = "%s@%s" % (auth.username, auth.provider)
 | 
				
			||||||
 | 
					                request.session["federate_username"] = username
 | 
				
			||||||
                request.session["federate_ticket"] = ticket
 | 
					                request.session["federate_ticket"] = ticket
 | 
				
			||||||
 | 
					                auth.register_slo(username, request.session.session_key, ticket)
 | 
				
			||||||
                url = utils.reverse_params("cas_server:login", params)
 | 
					                url = utils.reverse_params("cas_server:login", params)
 | 
				
			||||||
                return HttpResponseRedirect(url)
 | 
					                return HttpResponseRedirect(url)
 | 
				
			||||||
            else:
 | 
					            else:
 | 
				
			||||||
 
 | 
				
			|||||||
		Reference in New Issue
	
	Block a user