Merge branch 'master' into federate
This commit is contained in:
		
							
								
								
									
										0
									
								
								cas_server/tests/__init__.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										0
									
								
								cas_server/tests/__init__.py
									
									
									
									
									
										Normal file
									
								
							
							
								
								
									
										193
									
								
								cas_server/tests/mixin.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										193
									
								
								cas_server/tests/mixin.py
									
									
									
									
									
										Normal file
									
								
							@@ -0,0 +1,193 @@
 | 
			
		||||
# ⁻*- coding: utf-8 -*-
 | 
			
		||||
# This program is distributed in the hope that it will be useful, but WITHOUT
 | 
			
		||||
# ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS
 | 
			
		||||
# FOR A PARTICULAR PURPOSE. See the GNU General Public License version 3 for
 | 
			
		||||
# more details.
 | 
			
		||||
#
 | 
			
		||||
# You should have received a copy of the GNU General Public License version 3
 | 
			
		||||
# along with this program; if not, write to the Free Software Foundation, Inc., 51
 | 
			
		||||
# Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA.
 | 
			
		||||
#
 | 
			
		||||
# (c) 2016 Valentin Samir
 | 
			
		||||
"""Some mixin classes for tests"""
 | 
			
		||||
from cas_server.default_settings import settings
 | 
			
		||||
from django.utils import timezone
 | 
			
		||||
 | 
			
		||||
import re
 | 
			
		||||
from lxml import etree
 | 
			
		||||
from datetime import timedelta
 | 
			
		||||
 | 
			
		||||
from cas_server import models
 | 
			
		||||
from cas_server.tests.utils import get_auth_client
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class BaseServicePattern(object):
 | 
			
		||||
    """Mixing for setting up service pattern for testing"""
 | 
			
		||||
    def setup_service_patterns(self, proxy=False):
 | 
			
		||||
        """setting up service pattern"""
 | 
			
		||||
        # For general purpose testing
 | 
			
		||||
        self.service = "https://www.example.com"
 | 
			
		||||
        self.service_pattern = models.ServicePattern.objects.create(
 | 
			
		||||
            name="example",
 | 
			
		||||
            pattern="^https://www\.example\.com(/.*)?$",
 | 
			
		||||
            proxy=proxy,
 | 
			
		||||
        )
 | 
			
		||||
        models.ReplaceAttributName.objects.create(name="*", service_pattern=self.service_pattern)
 | 
			
		||||
 | 
			
		||||
        # For testing the restrict_users attributes
 | 
			
		||||
        self.service_restrict_user_fail = "https://restrict_user_fail.example.com"
 | 
			
		||||
        self.service_pattern_restrict_user_fail = models.ServicePattern.objects.create(
 | 
			
		||||
            name="restrict_user_fail",
 | 
			
		||||
            pattern="^https://restrict_user_fail\.example\.com(/.*)?$",
 | 
			
		||||
            restrict_users=True,
 | 
			
		||||
            proxy=proxy,
 | 
			
		||||
        )
 | 
			
		||||
        self.service_restrict_user_success = "https://restrict_user_success.example.com"
 | 
			
		||||
        self.service_pattern_restrict_user_success = models.ServicePattern.objects.create(
 | 
			
		||||
            name="restrict_user_success",
 | 
			
		||||
            pattern="^https://restrict_user_success\.example\.com(/.*)?$",
 | 
			
		||||
            restrict_users=True,
 | 
			
		||||
            proxy=proxy,
 | 
			
		||||
        )
 | 
			
		||||
        models.Username.objects.create(
 | 
			
		||||
            value=settings.CAS_TEST_USER,
 | 
			
		||||
            service_pattern=self.service_pattern_restrict_user_success
 | 
			
		||||
        )
 | 
			
		||||
 | 
			
		||||
        # For testing the user attributes filtering conditions
 | 
			
		||||
        self.service_filter_fail = "https://filter_fail.example.com"
 | 
			
		||||
        self.service_pattern_filter_fail = models.ServicePattern.objects.create(
 | 
			
		||||
            name="filter_fail",
 | 
			
		||||
            pattern="^https://filter_fail\.example\.com(/.*)?$",
 | 
			
		||||
            proxy=proxy,
 | 
			
		||||
        )
 | 
			
		||||
        models.FilterAttributValue.objects.create(
 | 
			
		||||
            attribut="right",
 | 
			
		||||
            pattern="^admin$",
 | 
			
		||||
            service_pattern=self.service_pattern_filter_fail
 | 
			
		||||
        )
 | 
			
		||||
        self.service_filter_fail_alt = "https://filter_fail_alt.example.com"
 | 
			
		||||
        self.service_pattern_filter_fail_alt = models.ServicePattern.objects.create(
 | 
			
		||||
            name="filter_fail_alt",
 | 
			
		||||
            pattern="^https://filter_fail_alt\.example\.com(/.*)?$",
 | 
			
		||||
            proxy=proxy,
 | 
			
		||||
        )
 | 
			
		||||
        models.FilterAttributValue.objects.create(
 | 
			
		||||
            attribut="nom",
 | 
			
		||||
            pattern="^toto$",
 | 
			
		||||
            service_pattern=self.service_pattern_filter_fail_alt
 | 
			
		||||
        )
 | 
			
		||||
        self.service_filter_success = "https://filter_success.example.com"
 | 
			
		||||
        self.service_pattern_filter_success = models.ServicePattern.objects.create(
 | 
			
		||||
            name="filter_success",
 | 
			
		||||
            pattern="^https://filter_success\.example\.com(/.*)?$",
 | 
			
		||||
            proxy=proxy,
 | 
			
		||||
        )
 | 
			
		||||
        models.FilterAttributValue.objects.create(
 | 
			
		||||
            attribut="email",
 | 
			
		||||
            pattern="^%s$" % re.escape(settings.CAS_TEST_ATTRIBUTES['email']),
 | 
			
		||||
            service_pattern=self.service_pattern_filter_success
 | 
			
		||||
        )
 | 
			
		||||
 | 
			
		||||
        # For testing the user_field attributes
 | 
			
		||||
        self.service_field_needed_fail = "https://field_needed_fail.example.com"
 | 
			
		||||
        self.service_pattern_field_needed_fail = models.ServicePattern.objects.create(
 | 
			
		||||
            name="field_needed_fail",
 | 
			
		||||
            pattern="^https://field_needed_fail\.example\.com(/.*)?$",
 | 
			
		||||
            user_field="uid",
 | 
			
		||||
            proxy=proxy,
 | 
			
		||||
        )
 | 
			
		||||
        self.service_field_needed_success = "https://field_needed_success.example.com"
 | 
			
		||||
        self.service_pattern_field_needed_success = models.ServicePattern.objects.create(
 | 
			
		||||
            name="field_needed_success",
 | 
			
		||||
            pattern="^https://field_needed_success\.example\.com(/.*)?$",
 | 
			
		||||
            user_field="alias",
 | 
			
		||||
            proxy=proxy,
 | 
			
		||||
        )
 | 
			
		||||
        self.service_field_needed_success_alt = "https://field_needed_success_alt.example.com"
 | 
			
		||||
        self.service_pattern_field_needed_success = models.ServicePattern.objects.create(
 | 
			
		||||
            name="field_needed_success_alt",
 | 
			
		||||
            pattern="^https://field_needed_success_alt\.example\.com(/.*)?$",
 | 
			
		||||
            user_field="nom",
 | 
			
		||||
            proxy=proxy,
 | 
			
		||||
        )
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class XmlContent(object):
 | 
			
		||||
    """Mixin for test on CAS XML responses"""
 | 
			
		||||
    def assert_error(self, response, code, text=None):
 | 
			
		||||
        """Assert a validation error"""
 | 
			
		||||
        self.assertEqual(response.status_code, 200)
 | 
			
		||||
        root = etree.fromstring(response.content)
 | 
			
		||||
        error = root.xpath(
 | 
			
		||||
            "//cas:authenticationFailure",
 | 
			
		||||
            namespaces={'cas': "http://www.yale.edu/tp/cas"}
 | 
			
		||||
        )
 | 
			
		||||
        self.assertEqual(len(error), 1)
 | 
			
		||||
        self.assertEqual(error[0].attrib['code'], code)
 | 
			
		||||
        if text is not None:
 | 
			
		||||
            self.assertEqual(error[0].text, text)
 | 
			
		||||
 | 
			
		||||
    def assert_success(self, response, username, original_attributes):
 | 
			
		||||
        """assert a ticket validation success"""
 | 
			
		||||
        self.assertEqual(response.status_code, 200)
 | 
			
		||||
 | 
			
		||||
        root = etree.fromstring(response.content)
 | 
			
		||||
        sucess = root.xpath(
 | 
			
		||||
            "//cas:authenticationSuccess",
 | 
			
		||||
            namespaces={'cas': "http://www.yale.edu/tp/cas"}
 | 
			
		||||
        )
 | 
			
		||||
        self.assertTrue(sucess)
 | 
			
		||||
 | 
			
		||||
        users = root.xpath("//cas:user", namespaces={'cas': "http://www.yale.edu/tp/cas"})
 | 
			
		||||
        self.assertEqual(len(users), 1)
 | 
			
		||||
        self.assertEqual(users[0].text, username)
 | 
			
		||||
 | 
			
		||||
        attributes = root.xpath(
 | 
			
		||||
            "//cas:attributes",
 | 
			
		||||
            namespaces={'cas': "http://www.yale.edu/tp/cas"}
 | 
			
		||||
        )
 | 
			
		||||
        self.assertEqual(len(attributes), 1)
 | 
			
		||||
        attrs1 = set()
 | 
			
		||||
        for attr in attributes[0]:
 | 
			
		||||
            attrs1.add((attr.tag[len("http://www.yale.edu/tp/cas")+2:], attr.text))
 | 
			
		||||
 | 
			
		||||
        attributes = root.xpath("//cas:attribute", namespaces={'cas': "http://www.yale.edu/tp/cas"})
 | 
			
		||||
        self.assertEqual(len(attributes), len(attrs1))
 | 
			
		||||
        attrs2 = set()
 | 
			
		||||
        for attr in attributes:
 | 
			
		||||
            attrs2.add((attr.attrib['name'], attr.attrib['value']))
 | 
			
		||||
        original = set()
 | 
			
		||||
        for key, value in original_attributes.items():
 | 
			
		||||
            if isinstance(value, list):
 | 
			
		||||
                for sub_value in value:
 | 
			
		||||
                    original.add((key, sub_value))
 | 
			
		||||
            else:
 | 
			
		||||
                original.add((key, value))
 | 
			
		||||
        self.assertEqual(attrs1, attrs2)
 | 
			
		||||
        self.assertEqual(attrs1, original)
 | 
			
		||||
 | 
			
		||||
        return root
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class UserModels(object):
 | 
			
		||||
    """Mixin for test on CAS user models"""
 | 
			
		||||
    @staticmethod
 | 
			
		||||
    def expire_user():
 | 
			
		||||
        """return an expired user"""
 | 
			
		||||
        client = get_auth_client()
 | 
			
		||||
 | 
			
		||||
        new_date = timezone.now() - timedelta(seconds=(settings.SESSION_COOKIE_AGE + 600))
 | 
			
		||||
        models.User.objects.filter(
 | 
			
		||||
            username=settings.CAS_TEST_USER,
 | 
			
		||||
            session_key=client.session.session_key
 | 
			
		||||
        ).update(date=new_date)
 | 
			
		||||
        return client
 | 
			
		||||
 | 
			
		||||
    @staticmethod
 | 
			
		||||
    def get_user(client):
 | 
			
		||||
        """return the user associated with an authenticated client"""
 | 
			
		||||
        return models.User.objects.get(
 | 
			
		||||
            username=settings.CAS_TEST_USER,
 | 
			
		||||
            session_key=client.session.session_key
 | 
			
		||||
        )
 | 
			
		||||
							
								
								
									
										84
									
								
								cas_server/tests/settings.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										84
									
								
								cas_server/tests/settings.py
									
									
									
									
									
										Normal file
									
								
							@@ -0,0 +1,84 @@
 | 
			
		||||
"""
 | 
			
		||||
Django test settings for cas_server application.
 | 
			
		||||
 | 
			
		||||
Generated by 'django-admin startproject' using Django 1.9.7.
 | 
			
		||||
 | 
			
		||||
For more information on this file, see
 | 
			
		||||
https://docs.djangoproject.com/en/1.9/topics/settings/
 | 
			
		||||
 | 
			
		||||
For the full list of settings and their values, see
 | 
			
		||||
https://docs.djangoproject.com/en/1.9/ref/settings/
 | 
			
		||||
"""
 | 
			
		||||
 | 
			
		||||
import os
 | 
			
		||||
 | 
			
		||||
# Build paths inside the project like this: os.path.join(BASE_DIR, ...)
 | 
			
		||||
BASE_DIR = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
# Quick-start development settings - unsuitable for production
 | 
			
		||||
# See https://docs.djangoproject.com/en/1.9/howto/deployment/checklist/
 | 
			
		||||
 | 
			
		||||
# SECURITY WARNING: keep the secret key used in production secret!
 | 
			
		||||
SECRET_KEY = 'changeme'
 | 
			
		||||
 | 
			
		||||
# SECURITY WARNING: don't run with debug turned on in production!
 | 
			
		||||
DEBUG = True
 | 
			
		||||
 | 
			
		||||
ALLOWED_HOSTS = []
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
# Application definition
 | 
			
		||||
 | 
			
		||||
INSTALLED_APPS = [
 | 
			
		||||
    'django.contrib.admin',
 | 
			
		||||
    'django.contrib.auth',
 | 
			
		||||
    'django.contrib.contenttypes',
 | 
			
		||||
    'django.contrib.sessions',
 | 
			
		||||
    'django.contrib.messages',
 | 
			
		||||
    'django.contrib.staticfiles',
 | 
			
		||||
    'bootstrap3',
 | 
			
		||||
    'cas_server',
 | 
			
		||||
]
 | 
			
		||||
 | 
			
		||||
MIDDLEWARE_CLASSES = [
 | 
			
		||||
    'django.contrib.sessions.middleware.SessionMiddleware',
 | 
			
		||||
    'django.middleware.common.CommonMiddleware',
 | 
			
		||||
    'django.middleware.csrf.CsrfViewMiddleware',
 | 
			
		||||
    'django.contrib.auth.middleware.AuthenticationMiddleware',
 | 
			
		||||
    'django.contrib.auth.middleware.SessionAuthenticationMiddleware',
 | 
			
		||||
    'django.contrib.messages.middleware.MessageMiddleware',
 | 
			
		||||
    'django.middleware.clickjacking.XFrameOptionsMiddleware',
 | 
			
		||||
    'django.middleware.locale.LocaleMiddleware',
 | 
			
		||||
]
 | 
			
		||||
 | 
			
		||||
ROOT_URLCONF = 'cas_server.tests.urls'
 | 
			
		||||
 | 
			
		||||
# Database
 | 
			
		||||
# https://docs.djangoproject.com/en/1.9/ref/settings/#databases
 | 
			
		||||
 | 
			
		||||
DATABASES = {
 | 
			
		||||
    'default': {
 | 
			
		||||
        'ENGINE': 'django.db.backends.sqlite3',
 | 
			
		||||
        'NAME': ':memory:',
 | 
			
		||||
    }
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
# Internationalization
 | 
			
		||||
# https://docs.djangoproject.com/en/1.9/topics/i18n/
 | 
			
		||||
 | 
			
		||||
LANGUAGE_CODE = 'en-us'
 | 
			
		||||
 | 
			
		||||
TIME_ZONE = 'UTC'
 | 
			
		||||
 | 
			
		||||
USE_I18N = True
 | 
			
		||||
 | 
			
		||||
USE_L10N = True
 | 
			
		||||
 | 
			
		||||
USE_TZ = True
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
# Static files (CSS, JavaScript, Images)
 | 
			
		||||
# https://docs.djangoproject.com/en/1.9/howto/static-files/
 | 
			
		||||
 | 
			
		||||
STATIC_URL = '/static/'
 | 
			
		||||
							
								
								
									
										166
									
								
								cas_server/tests/test_models.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										166
									
								
								cas_server/tests/test_models.py
									
									
									
									
									
										Normal file
									
								
							@@ -0,0 +1,166 @@
 | 
			
		||||
# ⁻*- coding: utf-8 -*-
 | 
			
		||||
# This program is distributed in the hope that it will be useful, but WITHOUT
 | 
			
		||||
# ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS
 | 
			
		||||
# FOR A PARTICULAR PURPOSE. See the GNU General Public License version 3 for
 | 
			
		||||
# more details.
 | 
			
		||||
#
 | 
			
		||||
# You should have received a copy of the GNU General Public License version 3
 | 
			
		||||
# along with this program; if not, write to the Free Software Foundation, Inc., 51
 | 
			
		||||
# Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA.
 | 
			
		||||
#
 | 
			
		||||
# (c) 2016 Valentin Samir
 | 
			
		||||
"""Tests module for models"""
 | 
			
		||||
from cas_server.default_settings import settings
 | 
			
		||||
 | 
			
		||||
from django.test import TestCase
 | 
			
		||||
from django.test.utils import override_settings
 | 
			
		||||
from django.utils import timezone
 | 
			
		||||
 | 
			
		||||
from datetime import timedelta
 | 
			
		||||
from importlib import import_module
 | 
			
		||||
 | 
			
		||||
from cas_server import models
 | 
			
		||||
from cas_server.tests.utils import get_auth_client, HttpParamsHandler
 | 
			
		||||
from cas_server.tests.mixin import UserModels, BaseServicePattern
 | 
			
		||||
 | 
			
		||||
SessionStore = import_module(settings.SESSION_ENGINE).SessionStore
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
@override_settings(CAS_AUTH_CLASS='cas_server.auth.TestAuthUser')
 | 
			
		||||
class UserTestCase(TestCase, UserModels):
 | 
			
		||||
    """tests for the user models"""
 | 
			
		||||
    def setUp(self):
 | 
			
		||||
        """Prepare the test context"""
 | 
			
		||||
        self.service = 'http://127.0.0.1:45678'
 | 
			
		||||
        self.service_pattern = models.ServicePattern.objects.create(
 | 
			
		||||
            name="localhost",
 | 
			
		||||
            pattern="^https?://127\.0\.0\.1(:[0-9]+)?(/.*)?$",
 | 
			
		||||
            single_log_out=True
 | 
			
		||||
        )
 | 
			
		||||
        models.ReplaceAttributName.objects.create(name="*", service_pattern=self.service_pattern)
 | 
			
		||||
 | 
			
		||||
    def test_clean_old_entries(self):
 | 
			
		||||
        """test clean_old_entries"""
 | 
			
		||||
        # get an authenticated client
 | 
			
		||||
        client = self.expire_user()
 | 
			
		||||
        # assert the user exists before being cleaned
 | 
			
		||||
        self.assertEqual(len(models.User.objects.all()), 1)
 | 
			
		||||
        # assert the last activity date is before the expiry date
 | 
			
		||||
        self.assertTrue(
 | 
			
		||||
            self.get_user(client).date < (
 | 
			
		||||
                timezone.now() - timedelta(seconds=settings.SESSION_COOKIE_AGE)
 | 
			
		||||
            )
 | 
			
		||||
        )
 | 
			
		||||
        # delete old inactive users
 | 
			
		||||
        models.User.clean_old_entries()
 | 
			
		||||
        # assert the user has being well delete
 | 
			
		||||
        self.assertEqual(len(models.User.objects.all()), 0)
 | 
			
		||||
 | 
			
		||||
    def test_clean_deleted_sessions(self):
 | 
			
		||||
        """test clean_deleted_sessions"""
 | 
			
		||||
        # get an authenticated client
 | 
			
		||||
        client1 = get_auth_client()
 | 
			
		||||
        client2 = get_auth_client()
 | 
			
		||||
        # generate a ticket to fire SLO during user cleaning (SLO should fail a nothing listen
 | 
			
		||||
        # on self.service)
 | 
			
		||||
        ticket = self.get_user(client1).get_ticket(
 | 
			
		||||
            models.ServiceTicket,
 | 
			
		||||
            self.service,
 | 
			
		||||
            self.service_pattern,
 | 
			
		||||
            renew=False
 | 
			
		||||
        )
 | 
			
		||||
        ticket.validate = True
 | 
			
		||||
        ticket.save()
 | 
			
		||||
        # simulated expired session being garbage collected for client1
 | 
			
		||||
        session = SessionStore(session_key=client1.session.session_key)
 | 
			
		||||
        session.flush()
 | 
			
		||||
        # assert the user exists before being cleaned
 | 
			
		||||
        self.assertTrue(self.get_user(client1))
 | 
			
		||||
        self.assertTrue(self.get_user(client2))
 | 
			
		||||
        self.assertEqual(len(models.User.objects.all()), 2)
 | 
			
		||||
        # session has being remove so the user of client1 is no longer authenticated
 | 
			
		||||
        self.assertFalse(client1.session.get("authenticated"))
 | 
			
		||||
        # the user a client2 should still be authenticated
 | 
			
		||||
        self.assertTrue(client2.session.get("authenticated"))
 | 
			
		||||
        # the user should be deleted
 | 
			
		||||
        models.User.clean_deleted_sessions()
 | 
			
		||||
        # assert the user with expired sessions has being well deleted but the other remain
 | 
			
		||||
        self.assertEqual(len(models.User.objects.all()), 1)
 | 
			
		||||
        self.assertFalse(models.ServiceTicket.objects.all())
 | 
			
		||||
        self.assertTrue(client2.session.get("authenticated"))
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
@override_settings(CAS_AUTH_CLASS='cas_server.auth.TestAuthUser')
 | 
			
		||||
class TicketTestCase(TestCase, UserModels, BaseServicePattern):
 | 
			
		||||
    """tests for the tickets models"""
 | 
			
		||||
    def setUp(self):
 | 
			
		||||
        """Prepare the test context"""
 | 
			
		||||
        self.setup_service_patterns()
 | 
			
		||||
        self.service = 'http://127.0.0.1:45678'
 | 
			
		||||
        self.service_pattern = models.ServicePattern.objects.create(
 | 
			
		||||
            name="localhost",
 | 
			
		||||
            pattern="^https?://127\.0\.0\.1(:[0-9]+)?(/.*)?$",
 | 
			
		||||
            single_log_out=True
 | 
			
		||||
        )
 | 
			
		||||
        models.ReplaceAttributName.objects.create(name="*", service_pattern=self.service_pattern)
 | 
			
		||||
 | 
			
		||||
    @staticmethod
 | 
			
		||||
    def get_ticket(
 | 
			
		||||
        user,
 | 
			
		||||
        ticket_class,
 | 
			
		||||
        service,
 | 
			
		||||
        service_pattern,
 | 
			
		||||
        renew=False,
 | 
			
		||||
        validate=False,
 | 
			
		||||
        validity_expired=False,
 | 
			
		||||
        timeout_expired=False,
 | 
			
		||||
        single_log_out=False,
 | 
			
		||||
    ):
 | 
			
		||||
        """Return a ticket"""
 | 
			
		||||
        ticket = user.get_ticket(ticket_class, service, service_pattern, renew)
 | 
			
		||||
        ticket.validate = validate
 | 
			
		||||
        ticket.single_log_out = single_log_out
 | 
			
		||||
        if validity_expired:
 | 
			
		||||
            ticket.creation = min(
 | 
			
		||||
                ticket.creation,
 | 
			
		||||
                (timezone.now() - timedelta(seconds=(ticket_class.VALIDITY + 10)))
 | 
			
		||||
            )
 | 
			
		||||
        if timeout_expired:
 | 
			
		||||
            ticket.creation = min(
 | 
			
		||||
                ticket.creation,
 | 
			
		||||
                (timezone.now() - timedelta(seconds=(ticket_class.TIMEOUT + 10)))
 | 
			
		||||
            )
 | 
			
		||||
        ticket.save()
 | 
			
		||||
        return ticket
 | 
			
		||||
 | 
			
		||||
    def test_clean_old_service_ticket(self):
 | 
			
		||||
        """test tickets clean_old_entries"""
 | 
			
		||||
        # ge an authenticated client
 | 
			
		||||
        client = get_auth_client()
 | 
			
		||||
        # get the user associated to the client
 | 
			
		||||
        user = self.get_user(client)
 | 
			
		||||
        # generate a ticket for that client, waiting for validation
 | 
			
		||||
        self.get_ticket(user, models.ServiceTicket, self.service, self.service_pattern)
 | 
			
		||||
        # generate another ticket for those validation time has expired
 | 
			
		||||
        self.get_ticket(
 | 
			
		||||
            user, models.ServiceTicket,
 | 
			
		||||
            self.service, self.service_pattern, validity_expired=True
 | 
			
		||||
        )
 | 
			
		||||
        (httpd, host, port) = HttpParamsHandler.run()[0:3]
 | 
			
		||||
        service = "http://%s:%s" % (host, port)
 | 
			
		||||
        # generate a ticket with SLO having timeout reach
 | 
			
		||||
        self.get_ticket(
 | 
			
		||||
            user, models.ServiceTicket,
 | 
			
		||||
            service, self.service_pattern, timeout_expired=True,
 | 
			
		||||
            validate=True, single_log_out=True
 | 
			
		||||
        )
 | 
			
		||||
        # there should be 3 tickets in the db
 | 
			
		||||
        self.assertEqual(len(models.ServiceTicket.objects.all()), 3)
 | 
			
		||||
        # we call the clean_old_entries method that should delete validated non SLO ticket and
 | 
			
		||||
        # expired non validated ticket and send SLO for SLO expired ticket before deleting then
 | 
			
		||||
        models.ServiceTicket.clean_old_entries()
 | 
			
		||||
        params = httpd.PARAMS
 | 
			
		||||
        # we successfully got a SLO request
 | 
			
		||||
        self.assertTrue(b'logoutRequest' in params and params[b'logoutRequest'])
 | 
			
		||||
        # only 1 ticket remain in the db
 | 
			
		||||
        self.assertEqual(len(models.ServiceTicket.objects.all()), 1)
 | 
			
		||||
							
								
								
									
										191
									
								
								cas_server/tests/test_utils.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										191
									
								
								cas_server/tests/test_utils.py
									
									
									
									
									
										Normal file
									
								
							@@ -0,0 +1,191 @@
 | 
			
		||||
# ⁻*- coding: utf-8 -*-
 | 
			
		||||
# This program is distributed in the hope that it will be useful, but WITHOUT
 | 
			
		||||
# ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS
 | 
			
		||||
# FOR A PARTICULAR PURPOSE. See the GNU General Public License version 3 for
 | 
			
		||||
# more details.
 | 
			
		||||
#
 | 
			
		||||
# You should have received a copy of the GNU General Public License version 3
 | 
			
		||||
# along with this program; if not, write to the Free Software Foundation, Inc., 51
 | 
			
		||||
# Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA.
 | 
			
		||||
#
 | 
			
		||||
# (c) 2016 Valentin Samir
 | 
			
		||||
"""Tests module for utils"""
 | 
			
		||||
from django.test import TestCase
 | 
			
		||||
 | 
			
		||||
import six
 | 
			
		||||
 | 
			
		||||
from cas_server import utils
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class CheckPasswordCase(TestCase):
 | 
			
		||||
    """Tests for the utils function `utils.check_password`"""
 | 
			
		||||
 | 
			
		||||
    def setUp(self):
 | 
			
		||||
        """Generate random bytes string that will be used ass passwords"""
 | 
			
		||||
        self.password1 = utils.gen_saml_id()
 | 
			
		||||
        self.password2 = utils.gen_saml_id()
 | 
			
		||||
        if not isinstance(self.password1, bytes):  # pragma: no cover executed only in python3
 | 
			
		||||
            self.password1 = self.password1.encode("utf8")
 | 
			
		||||
            self.password2 = self.password2.encode("utf8")
 | 
			
		||||
 | 
			
		||||
    def test_setup(self):
 | 
			
		||||
        """check that generated password are bytes"""
 | 
			
		||||
        self.assertIsInstance(self.password1, bytes)
 | 
			
		||||
        self.assertIsInstance(self.password2, bytes)
 | 
			
		||||
 | 
			
		||||
    def test_plain(self):
 | 
			
		||||
        """test the plain auth method"""
 | 
			
		||||
        self.assertTrue(utils.check_password("plain", self.password1, self.password1, "utf8"))
 | 
			
		||||
        self.assertFalse(utils.check_password("plain", self.password1, self.password2, "utf8"))
 | 
			
		||||
 | 
			
		||||
    def test_plain_unicode(self):
 | 
			
		||||
        """test the plain auth method with unicode input"""
 | 
			
		||||
        self.assertTrue(
 | 
			
		||||
            utils.check_password(
 | 
			
		||||
                "plain",
 | 
			
		||||
                self.password1.decode("utf8"),
 | 
			
		||||
                self.password1.decode("utf8"),
 | 
			
		||||
                "utf8"
 | 
			
		||||
            )
 | 
			
		||||
        )
 | 
			
		||||
        self.assertFalse(
 | 
			
		||||
            utils.check_password(
 | 
			
		||||
                "plain",
 | 
			
		||||
                self.password1.decode("utf8"),
 | 
			
		||||
                self.password2.decode("utf8"),
 | 
			
		||||
                "utf8"
 | 
			
		||||
            )
 | 
			
		||||
        )
 | 
			
		||||
 | 
			
		||||
    def test_crypt(self):
 | 
			
		||||
        """test the crypt auth method"""
 | 
			
		||||
        salts = ["$6$UVVAQvrMyXMF3FF3", "aa"]
 | 
			
		||||
        hashed_password1 = []
 | 
			
		||||
        for salt in salts:
 | 
			
		||||
            if six.PY3:
 | 
			
		||||
                hashed_password1.append(
 | 
			
		||||
                    utils.crypt.crypt(
 | 
			
		||||
                        self.password1.decode("utf8"),
 | 
			
		||||
                        salt
 | 
			
		||||
                    ).encode("utf8")
 | 
			
		||||
                )
 | 
			
		||||
            else:
 | 
			
		||||
                hashed_password1.append(utils.crypt.crypt(self.password1, salt))
 | 
			
		||||
 | 
			
		||||
        for hp1 in hashed_password1:
 | 
			
		||||
            self.assertTrue(utils.check_password("crypt", self.password1, hp1, "utf8"))
 | 
			
		||||
            self.assertFalse(utils.check_password("crypt", self.password2, hp1, "utf8"))
 | 
			
		||||
 | 
			
		||||
        with self.assertRaises(ValueError):
 | 
			
		||||
            utils.check_password("crypt", self.password1, b"$truc$s$dsdsd", "utf8")
 | 
			
		||||
 | 
			
		||||
    def test_ldap_password_valid(self):
 | 
			
		||||
        """test the ldap auth method with all the schemes"""
 | 
			
		||||
        salt = b"UVVAQvrMyXMF3FF3"
 | 
			
		||||
        schemes_salt = [b"{SMD5}", b"{SSHA}", b"{SSHA256}", b"{SSHA384}", b"{SSHA512}"]
 | 
			
		||||
        schemes_nosalt = [b"{MD5}", b"{SHA}", b"{SHA256}", b"{SHA384}", b"{SHA512}"]
 | 
			
		||||
        hashed_password1 = []
 | 
			
		||||
        for scheme in schemes_salt:
 | 
			
		||||
            hashed_password1.append(
 | 
			
		||||
                utils.LdapHashUserPassword.hash(scheme, self.password1, salt, charset="utf8")
 | 
			
		||||
            )
 | 
			
		||||
        for scheme in schemes_nosalt:
 | 
			
		||||
            hashed_password1.append(
 | 
			
		||||
                utils.LdapHashUserPassword.hash(scheme, self.password1, charset="utf8")
 | 
			
		||||
            )
 | 
			
		||||
        hashed_password1.append(
 | 
			
		||||
            utils.LdapHashUserPassword.hash(
 | 
			
		||||
                b"{CRYPT}",
 | 
			
		||||
                self.password1,
 | 
			
		||||
                b"$6$UVVAQvrMyXMF3FF3",
 | 
			
		||||
                charset="utf8"
 | 
			
		||||
            )
 | 
			
		||||
        )
 | 
			
		||||
        for hp1 in hashed_password1:
 | 
			
		||||
            self.assertIsInstance(hp1, bytes)
 | 
			
		||||
            self.assertTrue(utils.check_password("ldap", self.password1, hp1, "utf8"))
 | 
			
		||||
            self.assertFalse(utils.check_password("ldap", self.password2, hp1, "utf8"))
 | 
			
		||||
 | 
			
		||||
    def test_ldap_password_fail(self):
 | 
			
		||||
        """test the ldap auth method with malformed hash or bad schemes"""
 | 
			
		||||
        salt = b"UVVAQvrMyXMF3FF3"
 | 
			
		||||
        schemes_salt = [b"{SMD5}", b"{SSHA}", b"{SSHA256}", b"{SSHA384}", b"{SSHA512}"]
 | 
			
		||||
        schemes_nosalt = [b"{MD5}", b"{SHA}", b"{SHA256}", b"{SHA384}", b"{SHA512}"]
 | 
			
		||||
 | 
			
		||||
        # first try to hash with bad parameters
 | 
			
		||||
        with self.assertRaises(utils.LdapHashUserPassword.BadScheme):
 | 
			
		||||
            utils.LdapHashUserPassword.hash(b"TOTO", self.password1)
 | 
			
		||||
        for scheme in schemes_nosalt:
 | 
			
		||||
            with self.assertRaises(utils.LdapHashUserPassword.BadScheme):
 | 
			
		||||
                utils.LdapHashUserPassword.hash(scheme, self.password1, salt)
 | 
			
		||||
        for scheme in schemes_salt:
 | 
			
		||||
            with self.assertRaises(utils.LdapHashUserPassword.BadScheme):
 | 
			
		||||
                utils.LdapHashUserPassword.hash(scheme, self.password1)
 | 
			
		||||
        with self.assertRaises(utils.LdapHashUserPassword.BadSalt):
 | 
			
		||||
            utils.LdapHashUserPassword.hash(b'{CRYPT}', self.password1, b"$truc$toto")
 | 
			
		||||
 | 
			
		||||
        # then try to check hash with bad hashes
 | 
			
		||||
        with self.assertRaises(utils.LdapHashUserPassword.BadHash):
 | 
			
		||||
            utils.check_password("ldap", self.password1, b"TOTOssdsdsd", "utf8")
 | 
			
		||||
        for scheme in schemes_salt:
 | 
			
		||||
            with self.assertRaises(utils.LdapHashUserPassword.BadHash):
 | 
			
		||||
                utils.check_password("ldap", self.password1, scheme + b"dG90b3E8ZHNkcw==", "utf8")
 | 
			
		||||
 | 
			
		||||
    def test_hex(self):
 | 
			
		||||
        """test all the hex_HASH method: the hashed password is a simple hash of the password"""
 | 
			
		||||
        hashes = ["md5", "sha1", "sha224", "sha256", "sha384", "sha512"]
 | 
			
		||||
        hashed_password1 = []
 | 
			
		||||
        for hash in hashes:
 | 
			
		||||
            hashed_password1.append(
 | 
			
		||||
                ("hex_%s" % hash, getattr(utils.hashlib, hash)(self.password1).hexdigest())
 | 
			
		||||
            )
 | 
			
		||||
        for (method, hp1) in hashed_password1:
 | 
			
		||||
            self.assertTrue(utils.check_password(method, self.password1, hp1, "utf8"))
 | 
			
		||||
            self.assertFalse(utils.check_password(method, self.password2, hp1, "utf8"))
 | 
			
		||||
 | 
			
		||||
    def test_bad_method(self):
 | 
			
		||||
        """try to check password with a bad method, should raise a ValueError"""
 | 
			
		||||
        with self.assertRaises(ValueError):
 | 
			
		||||
            utils.check_password("test", self.password1, b"$truc$s$dsdsd", "utf8")
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class UtilsTestCase(TestCase):
 | 
			
		||||
    """tests for some little utils functions"""
 | 
			
		||||
    def test_import_attr(self):
 | 
			
		||||
        """
 | 
			
		||||
            test the import_attr function. Feeded with a dotted path string, it should
 | 
			
		||||
            import the dotted module and return that last componend of the dotted path
 | 
			
		||||
            (function, class or variable)
 | 
			
		||||
        """
 | 
			
		||||
        with self.assertRaises(ImportError):
 | 
			
		||||
            utils.import_attr('toto.titi.tutu')
 | 
			
		||||
        with self.assertRaises(AttributeError):
 | 
			
		||||
            utils.import_attr('cas_server.utils.toto')
 | 
			
		||||
        with self.assertRaises(ValueError):
 | 
			
		||||
            utils.import_attr('toto')
 | 
			
		||||
        self.assertEqual(
 | 
			
		||||
            utils.import_attr('cas_server.default_app_config'),
 | 
			
		||||
            'cas_server.apps.CasAppConfig'
 | 
			
		||||
        )
 | 
			
		||||
        self.assertEqual(utils.import_attr(utils), utils)
 | 
			
		||||
 | 
			
		||||
    def test_update_url(self):
 | 
			
		||||
        """
 | 
			
		||||
            test the update_url function. Given an url with possible GET parameter and a dict
 | 
			
		||||
            the function build a url with GET parameters updated by the dictionnary
 | 
			
		||||
        """
 | 
			
		||||
        url1 = utils.update_url(u"https://www.example.com?toto=1", {u"tata": u"2"})
 | 
			
		||||
        url2 = utils.update_url(b"https://www.example.com?toto=1", {b"tata": b"2"})
 | 
			
		||||
        self.assertEqual(url1, u"https://www.example.com?tata=2&toto=1")
 | 
			
		||||
        self.assertEqual(url2, u"https://www.example.com?tata=2&toto=1")
 | 
			
		||||
 | 
			
		||||
        url3 = utils.update_url(u"https://www.example.com?toto=1", {u"toto": u"2"})
 | 
			
		||||
        self.assertEqual(url3, u"https://www.example.com?toto=2")
 | 
			
		||||
 | 
			
		||||
    def test_crypt_salt_is_valid(self):
 | 
			
		||||
        """test the function crypt_salt_is_valid who test if a crypt salt is valid"""
 | 
			
		||||
        self.assertFalse(utils.crypt_salt_is_valid(""))  # len 0
 | 
			
		||||
        self.assertFalse(utils.crypt_salt_is_valid("a"))  # len 1
 | 
			
		||||
        self.assertFalse(utils.crypt_salt_is_valid("$$"))  # start with $ followed by $
 | 
			
		||||
        self.assertFalse(utils.crypt_salt_is_valid("$toto"))  # start with $ but no secondary $
 | 
			
		||||
        self.assertFalse(utils.crypt_salt_is_valid("$toto$toto"))  # algorithm toto not known
 | 
			
		||||
							
								
								
									
										1813
									
								
								cas_server/tests/test_view.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										1813
									
								
								cas_server/tests/test_view.py
									
									
									
									
									
										Normal file
									
								
							
										
											
												File diff suppressed because it is too large
												Load Diff
											
										
									
								
							
							
								
								
									
										22
									
								
								cas_server/tests/urls.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										22
									
								
								cas_server/tests/urls.py
									
									
									
									
									
										Normal file
									
								
							@@ -0,0 +1,22 @@
 | 
			
		||||
"""cas URL Configuration
 | 
			
		||||
 | 
			
		||||
The `urlpatterns` list routes URLs to views. For more information please see:
 | 
			
		||||
    https://docs.djangoproject.com/en/1.9/topics/http/urls/
 | 
			
		||||
Examples:
 | 
			
		||||
Function views
 | 
			
		||||
    1. Add an import:  from my_app import views
 | 
			
		||||
    2. Add a URL to urlpatterns:  url(r'^$', views.home, name='home')
 | 
			
		||||
Class-based views
 | 
			
		||||
    1. Add an import:  from other_app.views import Home
 | 
			
		||||
    2. Add a URL to urlpatterns:  url(r'^$', Home.as_view(), name='home')
 | 
			
		||||
Including another URLconf
 | 
			
		||||
    1. Import the include() function: from django.conf.urls import url, include, include
 | 
			
		||||
    2. Add a URL to urlpatterns:  url(r'^blog/', include('blog.urls'))
 | 
			
		||||
"""
 | 
			
		||||
from django.conf.urls import url, include
 | 
			
		||||
from django.contrib import admin
 | 
			
		||||
 | 
			
		||||
urlpatterns = [
 | 
			
		||||
    url(r'^admin/', admin.site.urls),
 | 
			
		||||
    url(r'^', include('cas_server.urls', namespace='cas_server')),
 | 
			
		||||
]
 | 
			
		||||
							
								
								
									
										180
									
								
								cas_server/tests/utils.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										180
									
								
								cas_server/tests/utils.py
									
									
									
									
									
										Normal file
									
								
							@@ -0,0 +1,180 @@
 | 
			
		||||
# ⁻*- coding: utf-8 -*-
 | 
			
		||||
# This program is distributed in the hope that it will be useful, but WITHOUT
 | 
			
		||||
# ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS
 | 
			
		||||
# FOR A PARTICULAR PURPOSE. See the GNU General Public License version 3 for
 | 
			
		||||
# more details.
 | 
			
		||||
#
 | 
			
		||||
# You should have received a copy of the GNU General Public License version 3
 | 
			
		||||
# along with this program; if not, write to the Free Software Foundation, Inc., 51
 | 
			
		||||
# Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA.
 | 
			
		||||
#
 | 
			
		||||
# (c) 2016 Valentin Samir
 | 
			
		||||
"""Some utils functions for tests"""
 | 
			
		||||
from cas_server.default_settings import settings
 | 
			
		||||
 | 
			
		||||
from django.test import Client
 | 
			
		||||
 | 
			
		||||
import cgi
 | 
			
		||||
from threading import Thread
 | 
			
		||||
from lxml import etree
 | 
			
		||||
from six.moves import BaseHTTPServer
 | 
			
		||||
from six.moves.urllib.parse import urlparse, parse_qsl
 | 
			
		||||
 | 
			
		||||
from cas_server import models
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def copy_form(form):
 | 
			
		||||
    """Copy form value into a dict"""
 | 
			
		||||
    params = {}
 | 
			
		||||
    for field in form:
 | 
			
		||||
        if field.value():
 | 
			
		||||
            params[field.name] = field.value()
 | 
			
		||||
        else:
 | 
			
		||||
            params[field.name] = ""
 | 
			
		||||
    return params
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def get_login_page_params(client=None):
 | 
			
		||||
    """Return a client and the POST params for the client to login"""
 | 
			
		||||
    if client is None:
 | 
			
		||||
        client = Client()
 | 
			
		||||
    response = client.get('/login')
 | 
			
		||||
    params = copy_form(response.context["form"])
 | 
			
		||||
    return client, params
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def get_auth_client(**update):
 | 
			
		||||
    """return a authenticated client"""
 | 
			
		||||
    client, params = get_login_page_params()
 | 
			
		||||
    params["username"] = settings.CAS_TEST_USER
 | 
			
		||||
    params["password"] = settings.CAS_TEST_PASSWORD
 | 
			
		||||
    params.update(update)
 | 
			
		||||
 | 
			
		||||
    client.post('/login', params)
 | 
			
		||||
    assert client.session.get("authenticated")
 | 
			
		||||
 | 
			
		||||
    return client
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def get_user_ticket_request(service):
 | 
			
		||||
    """Make an auth client to request a ticket for `service`, return the tuple (user, ticket)"""
 | 
			
		||||
    client = get_auth_client()
 | 
			
		||||
    response = client.get("/login", {"service": service})
 | 
			
		||||
    ticket_value = response['Location'].split('ticket=')[-1]
 | 
			
		||||
    user = models.User.objects.get(
 | 
			
		||||
        username=settings.CAS_TEST_USER,
 | 
			
		||||
        session_key=client.session.session_key
 | 
			
		||||
    )
 | 
			
		||||
    ticket = models.ServiceTicket.objects.get(value=ticket_value)
 | 
			
		||||
    return (user, ticket, client)
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def get_validated_ticket(service):
 | 
			
		||||
    """Return a tick that has being already validated. Used to test SLO"""
 | 
			
		||||
    (ticket, auth_client) = get_user_ticket_request(service)[1:3]
 | 
			
		||||
 | 
			
		||||
    client = Client()
 | 
			
		||||
    response = client.get('/validate', {'ticket': ticket.value, 'service': service})
 | 
			
		||||
    assert (response.status_code == 200)
 | 
			
		||||
    assert (response.content == b'yes\ntest\n')
 | 
			
		||||
 | 
			
		||||
    ticket = models.ServiceTicket.objects.get(value=ticket.value)
 | 
			
		||||
    return (auth_client, ticket)
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def get_pgt():
 | 
			
		||||
    """return a dict contening a service, user and PGT ticket for this service"""
 | 
			
		||||
    (httpd, host, port) = HttpParamsHandler.run()[0:3]
 | 
			
		||||
    service = "http://%s:%s" % (host, port)
 | 
			
		||||
 | 
			
		||||
    (user, ticket) = get_user_ticket_request(service)[:2]
 | 
			
		||||
 | 
			
		||||
    client = Client()
 | 
			
		||||
    client.get('/serviceValidate', {'ticket': ticket.value, 'service': service, 'pgtUrl': service})
 | 
			
		||||
    params = httpd.PARAMS
 | 
			
		||||
 | 
			
		||||
    params["service"] = service
 | 
			
		||||
    params["user"] = user
 | 
			
		||||
 | 
			
		||||
    return params
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def get_proxy_ticket(service):
 | 
			
		||||
    """Return a ProxyTicket waiting for validation"""
 | 
			
		||||
    params = get_pgt()
 | 
			
		||||
 | 
			
		||||
    # get a proxy ticket
 | 
			
		||||
    client = Client()
 | 
			
		||||
    response = client.get('/proxy', {'pgt': params['pgtId'], 'targetService': service})
 | 
			
		||||
    root = etree.fromstring(response.content)
 | 
			
		||||
    proxy_ticket = root.xpath(
 | 
			
		||||
        "//cas:proxyTicket",
 | 
			
		||||
        namespaces={'cas': "http://www.yale.edu/tp/cas"}
 | 
			
		||||
    )
 | 
			
		||||
    proxy_ticket = proxy_ticket[0].text
 | 
			
		||||
    ticket = models.ProxyTicket.objects.get(value=proxy_ticket)
 | 
			
		||||
    return ticket
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class HttpParamsHandler(BaseHTTPServer.BaseHTTPRequestHandler):
 | 
			
		||||
    """
 | 
			
		||||
        A simple http server that return 200 on GET or POST
 | 
			
		||||
        and store GET or POST parameters. Used in unit tests
 | 
			
		||||
    """
 | 
			
		||||
 | 
			
		||||
    def do_GET(self):
 | 
			
		||||
        """Called on a GET request on the BaseHTTPServer"""
 | 
			
		||||
        self.send_response(200)
 | 
			
		||||
        self.send_header(b"Content-type", "text/plain")
 | 
			
		||||
        self.end_headers()
 | 
			
		||||
        self.wfile.write(b"ok")
 | 
			
		||||
        url = urlparse(self.path)
 | 
			
		||||
        params = dict(parse_qsl(url.query))
 | 
			
		||||
        self.server.PARAMS = params
 | 
			
		||||
 | 
			
		||||
    def do_POST(self):
 | 
			
		||||
        """Called on a POST request on the BaseHTTPServer"""
 | 
			
		||||
        ctype, pdict = cgi.parse_header(self.headers.get('content-type'))
 | 
			
		||||
        if ctype == 'multipart/form-data':
 | 
			
		||||
            postvars = cgi.parse_multipart(self.rfile, pdict)
 | 
			
		||||
        elif ctype == 'application/x-www-form-urlencoded':
 | 
			
		||||
            length = int(self.headers.get('content-length'))
 | 
			
		||||
            postvars = cgi.parse_qs(self.rfile.read(length), keep_blank_values=1)
 | 
			
		||||
        else:
 | 
			
		||||
            postvars = {}
 | 
			
		||||
        self.server.PARAMS = postvars
 | 
			
		||||
 | 
			
		||||
    def log_message(self, *args):
 | 
			
		||||
        """silent any log message"""
 | 
			
		||||
        return
 | 
			
		||||
 | 
			
		||||
    @classmethod
 | 
			
		||||
    def run(cls):
 | 
			
		||||
        """Run a BaseHTTPServer using this class as handler"""
 | 
			
		||||
        server_class = BaseHTTPServer.HTTPServer
 | 
			
		||||
        httpd = server_class(("127.0.0.1", 0), cls)
 | 
			
		||||
        (host, port) = httpd.socket.getsockname()
 | 
			
		||||
 | 
			
		||||
        def lauch():
 | 
			
		||||
            """routine to lauch in a background thread"""
 | 
			
		||||
            httpd.handle_request()
 | 
			
		||||
            httpd.server_close()
 | 
			
		||||
 | 
			
		||||
        httpd_thread = Thread(target=lauch)
 | 
			
		||||
        httpd_thread.daemon = True
 | 
			
		||||
        httpd_thread.start()
 | 
			
		||||
        return (httpd, host, port)
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class Http404Handler(HttpParamsHandler):
 | 
			
		||||
    """A simple http server that always return 404 not found. Used in unit tests"""
 | 
			
		||||
    def do_GET(self):
 | 
			
		||||
        """Called on a GET request on the BaseHTTPServer"""
 | 
			
		||||
        self.send_response(404)
 | 
			
		||||
        self.send_header(b"Content-type", "text/plain")
 | 
			
		||||
        self.end_headers()
 | 
			
		||||
        self.wfile.write(b"error 404 not found")
 | 
			
		||||
 | 
			
		||||
    def do_POST(self):
 | 
			
		||||
        """Called on a POST request on the BaseHTTPServer"""
 | 
			
		||||
        return self.do_GET()
 | 
			
		||||
		Reference in New Issue
	
	Block a user