Add some tests using tox
This commit is contained in:
		
							
								
								
									
										3
									
								
								.gitignore
									
									
									
									
										vendored
									
									
								
							
							
						
						
									
										3
									
								
								.gitignore
									
									
									
									
										vendored
									
									
								
							@@ -1,6 +1,9 @@
 | 
				
			|||||||
*.pyc
 | 
					*.pyc
 | 
				
			||||||
 | 
					*.egg-info
 | 
				
			||||||
 | 
					
 | 
				
			||||||
bootstrap3
 | 
					bootstrap3
 | 
				
			||||||
cas/
 | 
					cas/
 | 
				
			||||||
db.sqlite3
 | 
					db.sqlite3
 | 
				
			||||||
manage.py
 | 
					manage.py
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					.tox
 | 
				
			||||||
 
 | 
				
			|||||||
							
								
								
									
										21
									
								
								.travis.yml
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										21
									
								
								.travis.yml
									
									
									
									
									
										Normal file
									
								
							@@ -0,0 +1,21 @@
 | 
				
			|||||||
 | 
					language: python
 | 
				
			||||||
 | 
					python:
 | 
				
			||||||
 | 
					  - "2.7"
 | 
				
			||||||
 | 
					env:
 | 
				
			||||||
 | 
					  global:
 | 
				
			||||||
 | 
					    - PIP_DOWNLOAD_CACHE=$HOME/.pip_cache
 | 
				
			||||||
 | 
					  matrix:
 | 
				
			||||||
 | 
					    - TOX_ENV=py27-django17
 | 
				
			||||||
 | 
					    - TOX_ENV=py27-django18
 | 
				
			||||||
 | 
					    - TOX_ENV=flake8
 | 
				
			||||||
 | 
					cache:
 | 
				
			||||||
 | 
					  directories:
 | 
				
			||||||
 | 
					    - $HOME/.pip-cache/
 | 
				
			||||||
 | 
					install:
 | 
				
			||||||
 | 
					  - "travis_retry pip install setuptools --upgrade"
 | 
				
			||||||
 | 
					  - "pip install tox"
 | 
				
			||||||
 | 
					script:
 | 
				
			||||||
 | 
					  - tox -e $TOX_ENV
 | 
				
			||||||
 | 
					after_script:
 | 
				
			||||||
 | 
					  - cat .tox/$TOX_ENV/log/*.log
 | 
				
			||||||
 | 
					
 | 
				
			||||||
@@ -27,26 +27,14 @@ class UserCredential(forms.Form):
 | 
				
			|||||||
    method = forms.CharField(widget=forms.HiddenInput(), required=False)
 | 
					    method = forms.CharField(widget=forms.HiddenInput(), required=False)
 | 
				
			||||||
    warn = forms.BooleanField(label=_('warn'), required=False)
 | 
					    warn = forms.BooleanField(label=_('warn'), required=False)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    def __init__(self, request, *args, **kwargs):
 | 
					    def __init__(self, *args, **kwargs):
 | 
				
			||||||
        self.request = request
 | 
					 | 
				
			||||||
        super(UserCredential, self).__init__(*args, **kwargs)
 | 
					        super(UserCredential, self).__init__(*args, **kwargs)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    def clean(self):
 | 
					    def clean(self):
 | 
				
			||||||
        cleaned_data = super(UserCredential, self).clean()
 | 
					        cleaned_data = super(UserCredential, self).clean()
 | 
				
			||||||
        auth = utils.import_attr(settings.CAS_AUTH_CLASS)(cleaned_data.get("username"))
 | 
					        auth = utils.import_attr(settings.CAS_AUTH_CLASS)(cleaned_data.get("username"))
 | 
				
			||||||
        if auth.test_password(cleaned_data.get("password")):
 | 
					        if auth.test_password(cleaned_data.get("password")):
 | 
				
			||||||
            try:
 | 
					            cleaned_data["username"] = auth.username
 | 
				
			||||||
                user = models.User.objects.get(
 | 
					 | 
				
			||||||
                    username=auth.username,
 | 
					 | 
				
			||||||
                    session_key=self.request.session.session_key
 | 
					 | 
				
			||||||
                )
 | 
					 | 
				
			||||||
                user.save()
 | 
					 | 
				
			||||||
            except models.User.DoesNotExist:
 | 
					 | 
				
			||||||
                user = models.User.objects.create(
 | 
					 | 
				
			||||||
                    username=auth.username,
 | 
					 | 
				
			||||||
                    session_key=self.request.session.session_key
 | 
					 | 
				
			||||||
                )
 | 
					 | 
				
			||||||
                user.save()
 | 
					 | 
				
			||||||
        else:
 | 
					        else:
 | 
				
			||||||
            raise forms.ValidationError(_(u"Bad user"))
 | 
					            raise forms.ValidationError(_(u"Bad user"))
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 
 | 
				
			|||||||
@@ -89,11 +89,14 @@ class LogoutView(View, LogoutMixin):
 | 
				
			|||||||
    request = None
 | 
					    request = None
 | 
				
			||||||
    service = None
 | 
					    service = None
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    def get(self, request, *args, **kwargs):
 | 
					    def init_get(self, request):
 | 
				
			||||||
        """methode called on GET request on this view"""
 | 
					 | 
				
			||||||
        self.request = request
 | 
					        self.request = request
 | 
				
			||||||
        self.service = request.GET.get('service')
 | 
					        self.service = request.GET.get('service')
 | 
				
			||||||
        self.url = request.GET.get('url')
 | 
					        self.url = request.GET.get('url')
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    def get(self, request, *args, **kwargs):
 | 
				
			||||||
 | 
					        """methode called on GET request on this view"""
 | 
				
			||||||
 | 
					        self.init_get(request)
 | 
				
			||||||
        self.logout()
 | 
					        self.logout()
 | 
				
			||||||
        # if service is set, redirect to service after logout
 | 
					        # if service is set, redirect to service after logout
 | 
				
			||||||
        if self.service:
 | 
					        if self.service:
 | 
				
			||||||
@@ -105,6 +108,7 @@ class LogoutView(View, LogoutMixin):
 | 
				
			|||||||
        # else redirect to login page
 | 
					        # else redirect to login page
 | 
				
			||||||
        else:
 | 
					        else:
 | 
				
			||||||
            if settings.CAS_REDIRECT_TO_LOGIN_AFTER_LOGOUT:
 | 
					            if settings.CAS_REDIRECT_TO_LOGIN_AFTER_LOGOUT:
 | 
				
			||||||
 | 
					
 | 
				
			||||||
                messages.add_message(request, messages.SUCCESS, _(u'Successfully logout'))
 | 
					                messages.add_message(request, messages.SUCCESS, _(u'Successfully logout'))
 | 
				
			||||||
                return redirect("cas_server:login")
 | 
					                return redirect("cas_server:login")
 | 
				
			||||||
            else:
 | 
					            else:
 | 
				
			||||||
@@ -129,67 +133,110 @@ class LoginView(View, LogoutMixin):
 | 
				
			|||||||
    renewed = False
 | 
					    renewed = False
 | 
				
			||||||
    warned = False
 | 
					    warned = False
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    def post(self, request, *args, **kwargs):
 | 
					    INVALID_LOGIN_TICKET = 1
 | 
				
			||||||
        """methode called on POST request on this view"""
 | 
					    USER_LOGIN_OK = 2
 | 
				
			||||||
 | 
					    USER_LOGIN_FAILURE = 3
 | 
				
			||||||
 | 
					    USER_ALREADY_LOGGED = 4
 | 
				
			||||||
 | 
					    USER_AUTHENTICATED = 5
 | 
				
			||||||
 | 
					    USER_NOT_AUTHENTICATED = 6
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    def init_post(self, request):
 | 
				
			||||||
        self.request = request
 | 
					        self.request = request
 | 
				
			||||||
        self.service = request.POST.get('service')
 | 
					        self.service = request.POST.get('service')
 | 
				
			||||||
        self.renew = True if request.POST.get('renew') else False
 | 
					        self.renew = True if request.POST.get('renew') else False
 | 
				
			||||||
        self.gateway = request.POST.get('gateway')
 | 
					        self.gateway = request.POST.get('gateway')
 | 
				
			||||||
        self.method = request.POST.get('method')
 | 
					        self.method = request.POST.get('method')
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    def check_lt(self):
 | 
				
			||||||
        # save LT for later check
 | 
					        # save LT for later check
 | 
				
			||||||
        lt_valid = request.session.get('lt')
 | 
					        lt_valid = self.request.session.get('lt')
 | 
				
			||||||
        lt_send = request.POST.get('lt')
 | 
					        lt_send = self.request.POST.get('lt')
 | 
				
			||||||
        # generate a new LT (by posting the LT has been consumed)
 | 
					        # generate a new LT (by posting the LT has been consumed)
 | 
				
			||||||
        request.session['lt'] = utils.gen_lt()
 | 
					        self.request.session['lt'] = utils.gen_lt()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
        # check if send LT is valid
 | 
					        # check if send LT is valid
 | 
				
			||||||
        if lt_valid is None or lt_valid != lt_send:
 | 
					        if lt_valid is None or lt_valid != lt_send:
 | 
				
			||||||
 | 
					            return False
 | 
				
			||||||
 | 
					        else:
 | 
				
			||||||
 | 
					            return True
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    def post(self, request, *args, **kwargs):
 | 
				
			||||||
 | 
					        """methode called on POST request on this view"""
 | 
				
			||||||
 | 
					        self.init_post(request)
 | 
				
			||||||
 | 
					        ret = self.process_post()
 | 
				
			||||||
 | 
					        if ret == self.INVALID_LOGIN_TICKET:
 | 
				
			||||||
            messages.add_message(
 | 
					            messages.add_message(
 | 
				
			||||||
                self.request,
 | 
					                self.request,
 | 
				
			||||||
                messages.ERROR,
 | 
					                messages.ERROR,
 | 
				
			||||||
                _(u"Invalid login ticket")
 | 
					                _(u"Invalid login ticket")
 | 
				
			||||||
            )
 | 
					            )
 | 
				
			||||||
            values = request.POST.copy()
 | 
					        elif ret == self.USER_LOGIN_OK:
 | 
				
			||||||
            # if not set a new LT and fail
 | 
					            try:
 | 
				
			||||||
            values['lt'] = request.session['lt']
 | 
					 | 
				
			||||||
            self.init_form(values)
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
        elif not request.session.get("authenticated") or self.renew:
 | 
					 | 
				
			||||||
            self.init_form(request.POST)
 | 
					 | 
				
			||||||
            if self.form.is_valid():
 | 
					 | 
				
			||||||
                self.user = models.User.objects.get(
 | 
					                self.user = models.User.objects.get(
 | 
				
			||||||
                    username=self.form.cleaned_data['username'],
 | 
					                    username=self.request.session['username'],
 | 
				
			||||||
                    session_key=self.request.session.session_key
 | 
					                    session_key=self.request.session.session_key
 | 
				
			||||||
                )
 | 
					                )
 | 
				
			||||||
                request.session.set_expiry(0)
 | 
					                self.user.save()
 | 
				
			||||||
                request.session["username"] = self.form.cleaned_data['username']
 | 
					            except models.User.DoesNotExist:
 | 
				
			||||||
                request.session["warn"] = True if self.form.cleaned_data.get("warn") else False
 | 
					                self.user = models.User.objects.create(
 | 
				
			||||||
                request.session["authenticated"] = True
 | 
					                    username=self.request.session['username'],
 | 
				
			||||||
                self.renewed = True
 | 
					                    session_key=self.request.session.session_key
 | 
				
			||||||
                self.warned = True
 | 
					                )
 | 
				
			||||||
            else:
 | 
					                self.user.save()
 | 
				
			||||||
 | 
					        elif ret == self.USER_LOGIN_FAILURE:  # bad user login
 | 
				
			||||||
            self.logout()
 | 
					            self.logout()
 | 
				
			||||||
 | 
					        elif ret == self.USER_ALREADY_LOGGED:
 | 
				
			||||||
 | 
					            pass
 | 
				
			||||||
 | 
					        else:
 | 
				
			||||||
 | 
					            raise EnvironmentError("invalid output for LoginView.process_post")
 | 
				
			||||||
        return self.common()
 | 
					        return self.common()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    def get(self, request, *args, **kwargs):
 | 
					    def process_post(self, pytest=False):
 | 
				
			||||||
        """methode called on GET request on this view"""
 | 
					        if not self.check_lt():
 | 
				
			||||||
 | 
					            values = self.request.POST.copy()
 | 
				
			||||||
 | 
					            # if not set a new LT and fail
 | 
				
			||||||
 | 
					            values['lt'] = self.request.session['lt']
 | 
				
			||||||
 | 
					            self.init_form(values)
 | 
				
			||||||
 | 
					            return self.INVALID_LOGIN_TICKET
 | 
				
			||||||
 | 
					        elif not self.request.session.get("authenticated") or self.renew:
 | 
				
			||||||
 | 
					            self.init_form(self.request.POST)
 | 
				
			||||||
 | 
					            if self.form.is_valid():
 | 
				
			||||||
 | 
					                self.request.session.set_expiry(0)
 | 
				
			||||||
 | 
					                self.request.session["username"] = self.form.cleaned_data['username']
 | 
				
			||||||
 | 
					                self.request.session["warn"] = True if self.form.cleaned_data.get("warn") else False
 | 
				
			||||||
 | 
					                self.request.session["authenticated"] = True
 | 
				
			||||||
 | 
					                self.renewed = True
 | 
				
			||||||
 | 
					                self.warned = True
 | 
				
			||||||
 | 
					                return self.USER_LOGIN_OK
 | 
				
			||||||
 | 
					            else:
 | 
				
			||||||
 | 
					                return self.USER_LOGIN_FAILURE
 | 
				
			||||||
 | 
					        else:
 | 
				
			||||||
 | 
					            return self.USER_ALREADY_LOGGED
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    def init_get(self, request):
 | 
				
			||||||
        self.request = request
 | 
					        self.request = request
 | 
				
			||||||
        self.service = request.GET.get('service')
 | 
					        self.service = request.GET.get('service')
 | 
				
			||||||
        self.renew = True if request.GET.get('renew') else False
 | 
					        self.renew = True if request.GET.get('renew') else False
 | 
				
			||||||
        self.gateway = request.GET.get('gateway')
 | 
					        self.gateway = request.GET.get('gateway')
 | 
				
			||||||
        self.method = request.GET.get('method')
 | 
					        self.method = request.GET.get('method')
 | 
				
			||||||
 | 
					
 | 
				
			||||||
        # generate a new LT if none is present
 | 
					    def get(self, request, *args, **kwargs):
 | 
				
			||||||
        request.session['lt'] = request.session.get('lt', utils.gen_lt())
 | 
					        """methode called on GET request on this view"""
 | 
				
			||||||
 | 
					        self.init_get(request)
 | 
				
			||||||
        if not request.session.get("authenticated") or self.renew:
 | 
					        self.process_get()
 | 
				
			||||||
            self.init_form()
 | 
					 | 
				
			||||||
        return self.common()
 | 
					        return self.common()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    def process_get(self):
 | 
				
			||||||
 | 
					        # generate a new LT if none is present
 | 
				
			||||||
 | 
					        self.request.session['lt'] = self.request.session.get('lt', utils.gen_lt())
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        if not self.request.session.get("authenticated") or self.renew:
 | 
				
			||||||
 | 
					            self.init_form()
 | 
				
			||||||
 | 
					            return self.USER_NOT_AUTHENTICATED
 | 
				
			||||||
 | 
					        return self.USER_AUTHENTICATED
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    def init_form(self, values=None):
 | 
					    def init_form(self, values=None):
 | 
				
			||||||
        self.form = forms.UserCredential(
 | 
					        self.form = forms.UserCredential(
 | 
				
			||||||
            self.request,
 | 
					 | 
				
			||||||
            values,
 | 
					            values,
 | 
				
			||||||
            initial={
 | 
					            initial={
 | 
				
			||||||
                'service': self.service,
 | 
					                'service': self.service,
 | 
				
			||||||
@@ -345,7 +392,6 @@ class Auth(View):
 | 
				
			|||||||
        if not username or not password or not service:
 | 
					        if not username or not password or not service:
 | 
				
			||||||
            return HttpResponse("no\n", content_type="text/plain")
 | 
					            return HttpResponse("no\n", content_type="text/plain")
 | 
				
			||||||
        form = forms.UserCredential(
 | 
					        form = forms.UserCredential(
 | 
				
			||||||
            request,
 | 
					 | 
				
			||||||
            request.POST,
 | 
					            request.POST,
 | 
				
			||||||
            initial={
 | 
					            initial={
 | 
				
			||||||
                'service': service,
 | 
					                'service': service,
 | 
				
			||||||
@@ -354,11 +400,18 @@ class Auth(View):
 | 
				
			|||||||
            }
 | 
					            }
 | 
				
			||||||
        )
 | 
					        )
 | 
				
			||||||
        if form.is_valid():
 | 
					        if form.is_valid():
 | 
				
			||||||
 | 
					            try:
 | 
				
			||||||
                try:
 | 
					                try:
 | 
				
			||||||
                    user = models.User.objects.get(
 | 
					                    user = models.User.objects.get(
 | 
				
			||||||
                        username=form.cleaned_data['username'],
 | 
					                        username=form.cleaned_data['username'],
 | 
				
			||||||
                        session_key=request.session.session_key
 | 
					                        session_key=request.session.session_key
 | 
				
			||||||
                    )
 | 
					                    )
 | 
				
			||||||
 | 
					                except models.User.DoesNotExist:
 | 
				
			||||||
 | 
					                    user = models.User.objects.create(
 | 
				
			||||||
 | 
					                        username=form.cleaned_data['username'],
 | 
				
			||||||
 | 
					                        session_key=request.session.session_key
 | 
				
			||||||
 | 
					                    )
 | 
				
			||||||
 | 
					                user.save()
 | 
				
			||||||
                # is the service allowed
 | 
					                # is the service allowed
 | 
				
			||||||
                service_pattern = ServicePattern.validate(service)
 | 
					                service_pattern = ServicePattern.validate(service)
 | 
				
			||||||
                # is the current user allowed on this service
 | 
					                # is the current user allowed on this service
 | 
				
			||||||
 
 | 
				
			|||||||
							
								
								
									
										9
									
								
								requirements-dev.txt
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										9
									
								
								requirements-dev.txt
									
									
									
									
									
										Normal file
									
								
							@@ -0,0 +1,9 @@
 | 
				
			|||||||
 | 
					tox==1.8.1
 | 
				
			||||||
 | 
					pytest==2.6.4
 | 
				
			||||||
 | 
					pytest-django==2.7.0
 | 
				
			||||||
 | 
					pytest-pythonpath==0.3
 | 
				
			||||||
 | 
					requests>=2.4
 | 
				
			||||||
 | 
					django-picklefield>=0.3.1
 | 
				
			||||||
 | 
					requests_futures>=0.9.5
 | 
				
			||||||
 | 
					django-bootstrap3>=5.4
 | 
				
			||||||
 | 
					lxml>=3.4
 | 
				
			||||||
							
								
								
									
										7
									
								
								requirements.txt
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										7
									
								
								requirements.txt
									
									
									
									
									
										Normal file
									
								
							@@ -0,0 +1,7 @@
 | 
				
			|||||||
 | 
					setuptools>=5.5
 | 
				
			||||||
 | 
					requests>=2.4
 | 
				
			||||||
 | 
					requests_futures>=0.9.5
 | 
				
			||||||
 | 
					django-picklefield>=0.3.1
 | 
				
			||||||
 | 
					django-bootstrap3>=5.4
 | 
				
			||||||
 | 
					lxml>=3.4
 | 
				
			||||||
 | 
					
 | 
				
			||||||
							
								
								
									
										0
									
								
								tests/__init__.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										0
									
								
								tests/__init__.py
									
									
									
									
									
										Normal file
									
								
							
							
								
								
									
										61
									
								
								tests/dummy.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										61
									
								
								tests/dummy.py
									
									
									
									
									
										Normal file
									
								
							@@ -0,0 +1,61 @@
 | 
				
			|||||||
 | 
					from cas_server import models
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					class DummyUserManager(object):
 | 
				
			||||||
 | 
					    def __init__(self, username, session_key):
 | 
				
			||||||
 | 
					        self.username = username
 | 
				
			||||||
 | 
					        self.session_key = session_key
 | 
				
			||||||
 | 
					    def get(self, username=None, session_key=None):
 | 
				
			||||||
 | 
					        if username == self.username and session_key == self.session_key:
 | 
				
			||||||
 | 
					            return models.User(username=username, session_key=session_key)
 | 
				
			||||||
 | 
					        else:
 | 
				
			||||||
 | 
					            raise models.User.DoesNotExist()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					class DummyTicketManager(object):
 | 
				
			||||||
 | 
					    def __init__(self, ticket_class, service, ticket):
 | 
				
			||||||
 | 
					        self.ticket_class = ticket_class
 | 
				
			||||||
 | 
					        self.service = service
 | 
				
			||||||
 | 
					        self.ticket = ticket
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    def create(self, **kwargs):
 | 
				
			||||||
 | 
					        for field in models.ServiceTicket._meta.fields:
 | 
				
			||||||
 | 
					            field.allow_unsaved_instance_assignment = True
 | 
				
			||||||
 | 
					        return self.ticket_class(**kwargs)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    def filter(self, *args, **kwargs):
 | 
				
			||||||
 | 
					        return DummyQuerySet()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    def get(self, **kwargs):
 | 
				
			||||||
 | 
					        if 'value' in kwargs:
 | 
				
			||||||
 | 
					            if kwargs['value'] != self.ticket:
 | 
				
			||||||
 | 
					                raise self.ticket_class.DoesNotExist()
 | 
				
			||||||
 | 
					        else:
 | 
				
			||||||
 | 
					            kwargs['value'] = self.ticket
 | 
				
			||||||
 | 
					        
 | 
				
			||||||
 | 
					        if 'service' in kwargs:
 | 
				
			||||||
 | 
					            if kwargs['service'] != self.service:
 | 
				
			||||||
 | 
					                raise self.ticket_class.DoesNotExist()
 | 
				
			||||||
 | 
					        else:
 | 
				
			||||||
 | 
					            kwargs['service'] = self.service
 | 
				
			||||||
 | 
					        if not 'user' in kwargs:
 | 
				
			||||||
 | 
					            kwargs['user'] = models.User(username="test")
 | 
				
			||||||
 | 
					        
 | 
				
			||||||
 | 
					        for field in models.ServiceTicket._meta.fields:
 | 
				
			||||||
 | 
					            field.allow_unsaved_instance_assignment = True
 | 
				
			||||||
 | 
					        for key in kwargs.keys():
 | 
				
			||||||
 | 
					            if '__' in key:
 | 
				
			||||||
 | 
					                del kwargs[key]
 | 
				
			||||||
 | 
					        kwargs['attributs'] = {'mail': 'test@example.com'}
 | 
				
			||||||
 | 
					        kwargs['service_pattern'] = models.ServicePattern()
 | 
				
			||||||
 | 
					        return self.ticket_class(**kwargs)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					class DummySession(dict):
 | 
				
			||||||
 | 
					    session_key = "test_session"
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    def set_expiry(self, int):
 | 
				
			||||||
 | 
					        pass
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					class DummyQuerySet(set):
 | 
				
			||||||
 | 
					    pass
 | 
				
			||||||
							
								
								
									
										32
									
								
								tests/init.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										32
									
								
								tests/init.py
									
									
									
									
									
										Normal file
									
								
							@@ -0,0 +1,32 @@
 | 
				
			|||||||
 | 
					import django
 | 
				
			||||||
 | 
					from django.conf import settings
 | 
				
			||||||
 | 
					from django.contrib import messages
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					settings.configure()
 | 
				
			||||||
 | 
					settings.STATIC_URL = "/static/"
 | 
				
			||||||
 | 
					settings.DATABASES = {
 | 
				
			||||||
 | 
					    'default': {
 | 
				
			||||||
 | 
					        'ENGINE': 'django.db.backends.sqlite3',
 | 
				
			||||||
 | 
					        'NAME': '/dev/null',
 | 
				
			||||||
 | 
					    }
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					settings.INSTALLED_APPS = (
 | 
				
			||||||
 | 
					    'django.contrib.admin',
 | 
				
			||||||
 | 
					    'django.contrib.auth',
 | 
				
			||||||
 | 
					    'django.contrib.contenttypes',
 | 
				
			||||||
 | 
					    'django.contrib.sessions',
 | 
				
			||||||
 | 
					    'django.contrib.messages',
 | 
				
			||||||
 | 
					    'django.contrib.staticfiles',
 | 
				
			||||||
 | 
					    'bootstrap3',
 | 
				
			||||||
 | 
					    'cas_server',
 | 
				
			||||||
 | 
					)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					settings.ROOT_URLCONF = "/"
 | 
				
			||||||
 | 
					settings.CAS_AUTH_CLASS = 'cas_server.auth.TestAuthUser'
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					try:
 | 
				
			||||||
 | 
					    django.setup()
 | 
				
			||||||
 | 
					except AttributeError:
 | 
				
			||||||
 | 
					    pass
 | 
				
			||||||
 | 
					messages.add_message = lambda x,y,z:None
 | 
				
			||||||
 | 
					
 | 
				
			||||||
							
								
								
									
										93
									
								
								tests/test_validate_service.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										93
									
								
								tests/test_validate_service.py
									
									
									
									
									
										Normal file
									
								
							@@ -0,0 +1,93 @@
 | 
				
			|||||||
 | 
					from __future__ import absolute_import
 | 
				
			||||||
 | 
					from .init import *
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					from django.test import RequestFactory
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					import os
 | 
				
			||||||
 | 
					import pytest
 | 
				
			||||||
 | 
					from lxml import etree
 | 
				
			||||||
 | 
					from cas_server.views import ValidateService
 | 
				
			||||||
 | 
					from cas_server import models
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					from .dummy import *
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					@pytest.mark.django_db
 | 
				
			||||||
 | 
					def test_validate_service_view_ok():
 | 
				
			||||||
 | 
					    factory = RequestFactory()
 | 
				
			||||||
 | 
					    request = factory.get('/serviceValidate?ticket=ST-random&service=https://www.example.com')
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    request.session = DummySession()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    models.ServiceTicket.objects = DummyTicketManager(models.ServiceTicket, 'https://www.example.com', "ST-random")
 | 
				
			||||||
 | 
					    models.ServiceTicket.save = lambda x:None
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    validate = ValidateService()
 | 
				
			||||||
 | 
					    validate.allow_proxy_ticket = False
 | 
				
			||||||
 | 
					    response = validate.get(request)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    assert response.status_code == 200
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    root = etree.fromstring(response.content)
 | 
				
			||||||
 | 
					    users = root.xpath("//cas:user", namespaces={'cas': "http://www.yale.edu/tp/cas"})
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    assert len(users) == 1
 | 
				
			||||||
 | 
					    assert users[0].text == "test"
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    attributes = root.xpath("//cas:attributes", namespaces={'cas': "http://www.yale.edu/tp/cas"})
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    assert len(attributes) == 1
 | 
				
			||||||
 | 
					    
 | 
				
			||||||
 | 
					    attrs = {}
 | 
				
			||||||
 | 
					    for attr in attributes[0]:
 | 
				
			||||||
 | 
					        attrs[attr.tag[len("http://www.yale.edu/tp/cas")+2:]]=attr.text
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    assert 'mail' in attrs
 | 
				
			||||||
 | 
					    assert attrs['mail'] == 'test@example.com'
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					@pytest.mark.django_db
 | 
				
			||||||
 | 
					def test_validate_service_view_badservice():
 | 
				
			||||||
 | 
					    factory = RequestFactory()
 | 
				
			||||||
 | 
					    request = factory.get('/serviceValidate?ticket=ST-random&service=https://www.example1.com')
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    request.session = DummySession()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    models.ServiceTicket.objects = DummyTicketManager(models.ServiceTicket, 'https://www.example2.com', "ST-random")
 | 
				
			||||||
 | 
					    models.ServiceTicket.save = lambda x:None
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    validate = ValidateService()
 | 
				
			||||||
 | 
					    validate.allow_proxy_ticket = False
 | 
				
			||||||
 | 
					    response = validate.get(request)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    assert response.status_code == 200
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    root = etree.fromstring(response.content)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    error = root.xpath("//cas:authenticationFailure", namespaces={'cas': "http://www.yale.edu/tp/cas"})
 | 
				
			||||||
 | 
					    
 | 
				
			||||||
 | 
					    assert len(error) == 1
 | 
				
			||||||
 | 
					    assert error[0].attrib['code'] == 'INVALID_SERVICE'
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					@pytest.mark.django_db
 | 
				
			||||||
 | 
					def test_validate_service_view_badticket():
 | 
				
			||||||
 | 
					    factory = RequestFactory()
 | 
				
			||||||
 | 
					    request = factory.get('/serviceValidate?ticket=ST-random1&service=https://www.example.com')
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    request.session = DummySession()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    models.ServiceTicket.objects = DummyTicketManager(models.ServiceTicket, 'https://www.example.com', "ST-random2")
 | 
				
			||||||
 | 
					    models.ServiceTicket.save = lambda x:None
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    validate = ValidateService()
 | 
				
			||||||
 | 
					    validate.allow_proxy_ticket = False
 | 
				
			||||||
 | 
					    response = validate.get(request)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    assert response.status_code == 200
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    root = etree.fromstring(response.content)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    error = root.xpath("//cas:authenticationFailure", namespaces={'cas': "http://www.yale.edu/tp/cas"})
 | 
				
			||||||
 | 
					    
 | 
				
			||||||
 | 
					    assert len(error) == 1
 | 
				
			||||||
 | 
					    assert error[0].attrib['code'] == 'INVALID_TICKET'
 | 
				
			||||||
							
								
								
									
										49
									
								
								tests/test_views_auth.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										49
									
								
								tests/test_views_auth.py
									
									
									
									
									
										Normal file
									
								
							@@ -0,0 +1,49 @@
 | 
				
			|||||||
 | 
					from __future__ import absolute_import
 | 
				
			||||||
 | 
					from .init import *
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					from django.test import RequestFactory
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					import os
 | 
				
			||||||
 | 
					import pytest
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					from cas_server.views import Auth
 | 
				
			||||||
 | 
					from cas_server import models
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					from .dummy import *
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					settings.CAS_AUTH_SHARED_SECRET = "test"
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					@pytest.mark.django_db
 | 
				
			||||||
 | 
					def test_auth_view_goodpass():
 | 
				
			||||||
 | 
					    factory = RequestFactory()
 | 
				
			||||||
 | 
					    request = factory.post('/auth', {'username':'test', 'password':'test', 'service':'https://www.example.com', 'secret':'test'})
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    request.session = DummySession()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    models.User.objects = DummyUserManager(username="test", session_key=request.session.session_key)
 | 
				
			||||||
 | 
					    models.ServiceTicket.objects = DummyTicketManager(models.ServiceTicket, 'https://www.example.com', "ST-random")
 | 
				
			||||||
 | 
					    models.ServicePattern.validate = classmethod(lambda x,y: models.ServicePattern())
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    auth = Auth()
 | 
				
			||||||
 | 
					    response = auth.post(request)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    assert response.status_code == 200
 | 
				
			||||||
 | 
					    assert response.content == "yes\n"
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					def test_auth_view_badpass():
 | 
				
			||||||
 | 
					    factory = RequestFactory()
 | 
				
			||||||
 | 
					    request = factory.post('/auth', {'username':'test', 'password':'badpass', 'service':'https://www.example.com', 'secret':'test'})
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    request.session = DummySession()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    models.User.objects = DummyUserManager(username="test", session_key=request.session.session_key)
 | 
				
			||||||
 | 
					    models.ServiceTicket.objects = DummyTicketManager(models.ServiceTicket, 'https://www.example.com', "ST-random")
 | 
				
			||||||
 | 
					    models.ServicePattern.validate = classmethod(lambda x,y: models.ServicePattern())
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    auth = Auth()
 | 
				
			||||||
 | 
					    response = auth.post(request)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    assert response.status_code == 200
 | 
				
			||||||
 | 
					    assert response.content == "no\n"
 | 
				
			||||||
 | 
					
 | 
				
			||||||
							
								
								
									
										170
									
								
								tests/test_views_login.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										170
									
								
								tests/test_views_login.py
									
									
									
									
									
										Normal file
									
								
							@@ -0,0 +1,170 @@
 | 
				
			|||||||
 | 
					from __future__ import absolute_import
 | 
				
			||||||
 | 
					from .init import *
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					from django.test import RequestFactory
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					import os
 | 
				
			||||||
 | 
					import pytest
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					from cas_server.views import LoginView
 | 
				
			||||||
 | 
					from cas_server import models
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					from .dummy import *
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					def test_login_view_post_goodpass_goodlt():
 | 
				
			||||||
 | 
					    factory = RequestFactory()
 | 
				
			||||||
 | 
					    request = factory.post('/login', {'username':'test', 'password':'test', 'lt':'LT-random'})
 | 
				
			||||||
 | 
					    request.session = DummySession()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    request.session['lt'] = 'LT-random'
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    request.session["username"] = os.urandom(20)
 | 
				
			||||||
 | 
					    request.session["warn"] = os.urandom(20)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    login = LoginView()
 | 
				
			||||||
 | 
					    login.init_post(request)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    ret = login.process_post(pytest=True)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    assert ret == LoginView.USER_LOGIN_OK
 | 
				
			||||||
 | 
					    assert request.session.get("authenticated") == True
 | 
				
			||||||
 | 
					    assert request.session.get("username") == "test"
 | 
				
			||||||
 | 
					    assert request.session.get("warn") == False
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					def test_login_view_post_badlt():
 | 
				
			||||||
 | 
					    factory = RequestFactory()
 | 
				
			||||||
 | 
					    request = factory.post('/login', {'username':'test', 'password':'test', 'lt':'LT-random1'})
 | 
				
			||||||
 | 
					    request.session = DummySession()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    request.session['lt'] = 'LT-random2'
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    authenticated = os.urandom(20)
 | 
				
			||||||
 | 
					    username = os.urandom(20)
 | 
				
			||||||
 | 
					    warn = os.urandom(20)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    request.session["authenticated"] = authenticated
 | 
				
			||||||
 | 
					    request.session["username"] = username
 | 
				
			||||||
 | 
					    request.session["warn"] = warn
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    login = LoginView()
 | 
				
			||||||
 | 
					    login.init_post(request)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    ret = login.process_post(pytest=True)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    assert ret == LoginView.INVALID_LOGIN_TICKET
 | 
				
			||||||
 | 
					    assert request.session.get("authenticated") == authenticated
 | 
				
			||||||
 | 
					    assert request.session.get("username") == username
 | 
				
			||||||
 | 
					    assert request.session.get("warn") == warn
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					def test_login_view_post_badpass_good_lt():
 | 
				
			||||||
 | 
					    factory = RequestFactory()
 | 
				
			||||||
 | 
					    request = factory.post('/login', {'username':'test', 'password':'badpassword', 'lt':'LT-random'})
 | 
				
			||||||
 | 
					    request.session = DummySession()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    request.session['lt'] = 'LT-random'
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    login = LoginView()
 | 
				
			||||||
 | 
					    login.init_post(request)
 | 
				
			||||||
 | 
					    ret = login.process_post()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    assert ret == LoginView.USER_LOGIN_FAILURE
 | 
				
			||||||
 | 
					    assert not request.session.get("authenticated")
 | 
				
			||||||
 | 
					    assert not request.session.get("username")
 | 
				
			||||||
 | 
					    assert not request.session.get("warn")
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					def test_view_login_get_unauth():
 | 
				
			||||||
 | 
					    factory = RequestFactory()
 | 
				
			||||||
 | 
					    request = factory.post('/login')
 | 
				
			||||||
 | 
					    request.session = DummySession()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    login = LoginView()
 | 
				
			||||||
 | 
					    login.init_get(request)
 | 
				
			||||||
 | 
					    ret = login.process_get()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    assert ret == LoginView.USER_NOT_AUTHENTICATED
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    login = LoginView()
 | 
				
			||||||
 | 
					    response = login.get(request)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    assert response.status_code == 200
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					@pytest.mark.django_db
 | 
				
			||||||
 | 
					def test_view_login_get_auth():
 | 
				
			||||||
 | 
					    factory = RequestFactory()
 | 
				
			||||||
 | 
					    request = factory.post('/login')
 | 
				
			||||||
 | 
					    request.session = DummySession()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    request.session["authenticated"] = True
 | 
				
			||||||
 | 
					    request.session["username"] = "test"
 | 
				
			||||||
 | 
					    request.session["warn"] = False
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    login = LoginView()
 | 
				
			||||||
 | 
					    login.init_get(request)
 | 
				
			||||||
 | 
					    ret = login.process_get()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    assert ret == LoginView.USER_AUTHENTICATED
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    models.User.objects = DummyUserManager(username="test", session_key=request.session.session_key)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    login = LoginView()
 | 
				
			||||||
 | 
					    response = login.get(request)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    assert response.status_code == 200
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					@pytest.mark.django_db
 | 
				
			||||||
 | 
					def test_view_login_get_auth_service():
 | 
				
			||||||
 | 
					    factory = RequestFactory()
 | 
				
			||||||
 | 
					    request = factory.post('/login?service=https://www.example.com')
 | 
				
			||||||
 | 
					    request.session = DummySession()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    request.session["authenticated"] = True
 | 
				
			||||||
 | 
					    request.session["username"] = "test"
 | 
				
			||||||
 | 
					    request.session["warn"] = False
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    login = LoginView()
 | 
				
			||||||
 | 
					    login.init_get(request)
 | 
				
			||||||
 | 
					    ret = login.process_get()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    assert ret == LoginView.USER_AUTHENTICATED
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    models.User.objects = DummyUserManager(username="test", session_key=request.session.session_key)
 | 
				
			||||||
 | 
					    models.User.save = lambda x:None
 | 
				
			||||||
 | 
					    models.ServiceTicket.objects = DummyTicketManager(models.ServiceTicket, 'https://www.example.com', "ST-random")
 | 
				
			||||||
 | 
					    models.ServicePattern.validate = classmethod(lambda x,y: models.ServicePattern())
 | 
				
			||||||
 | 
					    models.ServiceTicket.save = lambda x:None
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    login = LoginView()
 | 
				
			||||||
 | 
					    response = login.get(request)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    assert response.status_code == 302
 | 
				
			||||||
 | 
					    assert response['Location'].startswith('https://www.example.com?ticket=ST-')
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					@pytest.mark.django_db
 | 
				
			||||||
 | 
					def test_view_login_get_auth_service_warn():
 | 
				
			||||||
 | 
					    factory = RequestFactory()
 | 
				
			||||||
 | 
					    request = factory.post('/login?service=https://www.example.com')
 | 
				
			||||||
 | 
					    request.session = DummySession()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    request.session["authenticated"] = True
 | 
				
			||||||
 | 
					    request.session["username"] = "test"
 | 
				
			||||||
 | 
					    request.session["warn"] = True
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    login = LoginView()
 | 
				
			||||||
 | 
					    login.init_get(request)
 | 
				
			||||||
 | 
					    ret = login.process_get()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    assert ret == LoginView.USER_AUTHENTICATED
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    models.User.objects = DummyUserManager(username="test", session_key=request.session.session_key)
 | 
				
			||||||
 | 
					    models.User.save = lambda x:None
 | 
				
			||||||
 | 
					    models.ServiceTicket.objects = DummyTicketManager(models.ServiceTicket, 'https://www.example.com', "ST-random")
 | 
				
			||||||
 | 
					    models.ServicePattern.validate = classmethod(lambda x,y: models.ServicePattern())
 | 
				
			||||||
 | 
					    models.ServiceTicket.save = lambda x:None
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    login = LoginView()
 | 
				
			||||||
 | 
					    response = login.get(request)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    assert response.status_code == 200
 | 
				
			||||||
							
								
								
									
										92
									
								
								tests/test_views_logout.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										92
									
								
								tests/test_views_logout.py
									
									
									
									
									
										Normal file
									
								
							@@ -0,0 +1,92 @@
 | 
				
			|||||||
 | 
					from __future__ import absolute_import
 | 
				
			||||||
 | 
					from .init import *
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					from django.test import RequestFactory
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					import os
 | 
				
			||||||
 | 
					import pytest
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					from cas_server.views import LogoutView
 | 
				
			||||||
 | 
					from cas_server import models
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					from .dummy import *
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					@pytest.mark.django_db
 | 
				
			||||||
 | 
					def test_logout_view():
 | 
				
			||||||
 | 
					    factory = RequestFactory()
 | 
				
			||||||
 | 
					    request = factory.get('/logout')
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    request.session = DummySession()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    request.session["authenticated"] = True
 | 
				
			||||||
 | 
					    request.session["username"] = "test"
 | 
				
			||||||
 | 
					    request.session["warn"] = False
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    models.User.objects = DummyUserManager(username="test", session_key=request.session.session_key)
 | 
				
			||||||
 | 
					    dlist = [None]
 | 
				
			||||||
 | 
					    models.User.delete = lambda x:dlist.pop()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    logout = LogoutView()
 | 
				
			||||||
 | 
					    response = logout.get(request)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    assert response.status_code == 200
 | 
				
			||||||
 | 
					    assert dlist == []
 | 
				
			||||||
 | 
					    assert not request.session.get("authenticated")
 | 
				
			||||||
 | 
					    assert not request.session.get("username")
 | 
				
			||||||
 | 
					    assert not request.session.get("warn")
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					@pytest.mark.django_db
 | 
				
			||||||
 | 
					def test_logout_view_url():
 | 
				
			||||||
 | 
					    factory = RequestFactory()
 | 
				
			||||||
 | 
					    request = factory.get('/logout?url=https://www.example.com')
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    request.session = DummySession()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    request.session["authenticated"] = True
 | 
				
			||||||
 | 
					    request.session["username"] = "test"
 | 
				
			||||||
 | 
					    request.session["warn"] = False
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    models.User.objects = DummyUserManager(username="test", session_key=request.session.session_key)
 | 
				
			||||||
 | 
					    dlist = [None]
 | 
				
			||||||
 | 
					    models.User.delete = lambda x:dlist.pop()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    logout = LogoutView()
 | 
				
			||||||
 | 
					    response = logout.get(request)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    assert response.status_code == 302
 | 
				
			||||||
 | 
					    assert response['Location'] == 'https://www.example.com'
 | 
				
			||||||
 | 
					    assert dlist == []
 | 
				
			||||||
 | 
					    assert not request.session.get("authenticated")
 | 
				
			||||||
 | 
					    assert not request.session.get("username")
 | 
				
			||||||
 | 
					    assert not request.session.get("warn")
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					@pytest.mark.django_db
 | 
				
			||||||
 | 
					def test_logout_view_service():
 | 
				
			||||||
 | 
					    factory = RequestFactory()
 | 
				
			||||||
 | 
					    request = factory.get('/logout?service=https://www.example.com')
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    request.session = DummySession()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    request.session["authenticated"] = True
 | 
				
			||||||
 | 
					    request.session["username"] = "test"
 | 
				
			||||||
 | 
					    request.session["warn"] = False
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    models.User.objects = DummyUserManager(username="test", session_key=request.session.session_key)
 | 
				
			||||||
 | 
					    dlist = [None]
 | 
				
			||||||
 | 
					    models.User.delete = lambda x:dlist.pop()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    logout = LogoutView()
 | 
				
			||||||
 | 
					    response = logout.get(request)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    assert response.status_code == 302
 | 
				
			||||||
 | 
					    assert response['Location'] == 'https://www.example.com'
 | 
				
			||||||
 | 
					    assert dlist == []
 | 
				
			||||||
 | 
					    assert not request.session.get("authenticated")
 | 
				
			||||||
 | 
					    assert not request.session.get("username")
 | 
				
			||||||
 | 
					    assert not request.session.get("warn")
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
							
								
								
									
										61
									
								
								tests/test_views_validate.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										61
									
								
								tests/test_views_validate.py
									
									
									
									
									
										Normal file
									
								
							@@ -0,0 +1,61 @@
 | 
				
			|||||||
 | 
					from __future__ import absolute_import
 | 
				
			||||||
 | 
					from .init import *
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					from django.test import RequestFactory
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					import os
 | 
				
			||||||
 | 
					import pytest
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					from cas_server.views import Validate
 | 
				
			||||||
 | 
					from cas_server import models
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					from .dummy import *
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					@pytest.mark.django_db
 | 
				
			||||||
 | 
					def test_validate_view_ok():
 | 
				
			||||||
 | 
					    factory = RequestFactory()
 | 
				
			||||||
 | 
					    request = factory.get('/validate?ticket=ST-random&service=https://www.example.com')
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    request.session = DummySession()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    models.ServiceTicket.objects = DummyTicketManager(models.ServiceTicket, 'https://www.example.com', "ST-random")
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    validate = Validate()
 | 
				
			||||||
 | 
					    response = validate.get(request)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    assert response.status_code == 200
 | 
				
			||||||
 | 
					    assert response.content == "yes\n"
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					@pytest.mark.django_db
 | 
				
			||||||
 | 
					def test_validate_view_badservice():
 | 
				
			||||||
 | 
					    factory = RequestFactory()
 | 
				
			||||||
 | 
					    request = factory.get('/validate?ticket=ST-random&service=https://www.example2.com')
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    request.session = DummySession()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    models.ServiceTicket.objects = DummyTicketManager(models.ServiceTicket, 'https://www.example.com', "ST-random")
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    validate = Validate()
 | 
				
			||||||
 | 
					    response = validate.get(request)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    assert response.status_code == 200
 | 
				
			||||||
 | 
					    assert response.content == "no\n"
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					@pytest.mark.django_db
 | 
				
			||||||
 | 
					def test_validate_view_badticket():
 | 
				
			||||||
 | 
					    factory = RequestFactory()
 | 
				
			||||||
 | 
					    request = factory.get('/validate?ticket=ST-random2&service=https://www.example.com')
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    request.session = DummySession()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    models.ServiceTicket.objects = DummyTicketManager(models.ServiceTicket, 'https://www.example.com', "ST-random1")
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    validate = Validate()
 | 
				
			||||||
 | 
					    response = validate.get(request)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    assert response.status_code == 200
 | 
				
			||||||
 | 
					    assert response.content == "no\n"
 | 
				
			||||||
							
								
								
									
										34
									
								
								tox.ini
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										34
									
								
								tox.ini
									
									
									
									
									
										Normal file
									
								
							@@ -0,0 +1,34 @@
 | 
				
			|||||||
 | 
					[tox]
 | 
				
			||||||
 | 
					envlist=
 | 
				
			||||||
 | 
					    py27-django17,
 | 
				
			||||||
 | 
					    py27-django18,
 | 
				
			||||||
 | 
					    flake8,
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					[flake8]
 | 
				
			||||||
 | 
					max-line-length=100
 | 
				
			||||||
 | 
					exclude=migrations
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					[base]
 | 
				
			||||||
 | 
					deps =
 | 
				
			||||||
 | 
					    -r{toxinidir}/requirements-dev.txt
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					[testenv]
 | 
				
			||||||
 | 
					commands=py.test --tb native {posargs:tests}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					[testenv:py27-django17]
 | 
				
			||||||
 | 
					basepython=python2.7
 | 
				
			||||||
 | 
					deps =
 | 
				
			||||||
 | 
					    Django>=1.7,<1.8
 | 
				
			||||||
 | 
					    {[base]deps}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					[testenv:py27-django18]
 | 
				
			||||||
 | 
					basepython=python2.7
 | 
				
			||||||
 | 
					deps =
 | 
				
			||||||
 | 
					    Django>=1.8,<1.9
 | 
				
			||||||
 | 
					    {[base]deps}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					[testenv:flake8]
 | 
				
			||||||
 | 
					basepython=python
 | 
				
			||||||
 | 
					deps=flake8
 | 
				
			||||||
 | 
					commands=flake8 {toxinidir}/cas_server
 | 
				
			||||||
 | 
					
 | 
				
			||||||
		Reference in New Issue
	
	Block a user