More unit tests (essentially for the login view) and some docstrings
This commit is contained in:
		@@ -5,3 +5,4 @@ exclude_lines =
 | 
			
		||||
    def __unicode__
 | 
			
		||||
    raise AssertionError
 | 
			
		||||
    raise NotImplementedError
 | 
			
		||||
    if six.PY3:
 | 
			
		||||
 
 | 
			
		||||
							
								
								
									
										3
									
								
								Makefile
									
									
									
									
									
								
							
							
						
						
									
										3
									
								
								Makefile
									
									
									
									
									
								
							@@ -49,8 +49,9 @@ coverage: test_venv
 | 
			
		||||
	test_venv/bin/pip install coverage
 | 
			
		||||
	test_venv/bin/coverage run --source='cas_server' --omit='cas_server/migrations*' run_tests
 | 
			
		||||
	test_venv/bin/coverage html
 | 
			
		||||
	test_venv/bin/coverage xml
 | 
			
		||||
	rm htmlcov/coverage_html.js  # I am really pissed off by those keybord shortcuts
 | 
			
		||||
 | 
			
		||||
coverage_codacy: coverage
 | 
			
		||||
	test_venv/bin/coverage xml
 | 
			
		||||
	test_venv/bin/pip install codacy-coverage
 | 
			
		||||
	test_venv/bin/python-codacy-coverage -r coverage.xml
 | 
			
		||||
 
 | 
			
		||||
@@ -219,7 +219,8 @@ Test backend settings. Only usefull if you are using the test authentication bac
 | 
			
		||||
* ``CAS_TEST_USER``: Username of the test user. The default is ``"test"``.
 | 
			
		||||
* ``CAS_TEST_PASSWORD``: Password of the test user. The default is ``"test"``.
 | 
			
		||||
* ``CAS_TEST_ATTRIBUTES``: Attributes of the test user. The default is
 | 
			
		||||
  ``{'nom': 'Nymous', 'prenom': 'Ano', 'email': 'anonymous@example.net'}``.
 | 
			
		||||
  ``{'nom': 'Nymous', 'prenom': 'Ano', 'email': 'anonymous@example.net',
 | 
			
		||||
  'alias': ['demo1', 'demo2']}``.
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
Authentication backend
 | 
			
		||||
 
 | 
			
		||||
@@ -78,5 +78,10 @@ setting_default('CAS_TEST_USER', 'test')
 | 
			
		||||
setting_default('CAS_TEST_PASSWORD', 'test')
 | 
			
		||||
setting_default(
 | 
			
		||||
    'CAS_TEST_ATTRIBUTES',
 | 
			
		||||
    {'nom': 'Nymous', 'prenom': 'Ano', 'email': 'anonymous@example.net'}
 | 
			
		||||
    {
 | 
			
		||||
        'nom': 'Nymous',
 | 
			
		||||
        'prenom': 'Ano',
 | 
			
		||||
        'email': 'anonymous@example.net',
 | 
			
		||||
        'alias': ['demo1', 'demo2']
 | 
			
		||||
    }
 | 
			
		||||
)
 | 
			
		||||
 
 | 
			
		||||
@@ -3,36 +3,49 @@ from .default_settings import settings
 | 
			
		||||
from django.test import TestCase
 | 
			
		||||
from django.test import Client
 | 
			
		||||
 | 
			
		||||
import re
 | 
			
		||||
import six
 | 
			
		||||
import random
 | 
			
		||||
from lxml import etree
 | 
			
		||||
from six.moves import range
 | 
			
		||||
 | 
			
		||||
from cas_server import models
 | 
			
		||||
from cas_server import utils
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def get_login_page_params():
 | 
			
		||||
    client = Client()
 | 
			
		||||
    response = client.get('/login')
 | 
			
		||||
    form = response.context["form"]
 | 
			
		||||
def copy_form(form):
 | 
			
		||||
    """Copy form value into a dict"""
 | 
			
		||||
    params = {}
 | 
			
		||||
    for field in form:
 | 
			
		||||
        if field.value():
 | 
			
		||||
            params[field.name] = field.value()
 | 
			
		||||
        else:
 | 
			
		||||
            params[field.name] = ""
 | 
			
		||||
    return params
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def get_login_page_params(client=None):
 | 
			
		||||
    """Return a client and the POST params for the client to login"""
 | 
			
		||||
    if client is None:
 | 
			
		||||
        client = Client()
 | 
			
		||||
    response = client.get('/login')
 | 
			
		||||
    params = copy_form(response.context["form"])
 | 
			
		||||
    return client, params
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def get_auth_client():
 | 
			
		||||
def get_auth_client(**update):
 | 
			
		||||
    """return a authenticated client"""
 | 
			
		||||
    client, params = get_login_page_params()
 | 
			
		||||
    params["username"] = settings.CAS_TEST_USER
 | 
			
		||||
    params["password"] = settings.CAS_TEST_PASSWORD
 | 
			
		||||
    params.update(update)
 | 
			
		||||
 | 
			
		||||
    client.post('/login', params)
 | 
			
		||||
    return client
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def get_user_ticket_request(service):
 | 
			
		||||
    """Make an auth client to request a ticket for `service`, return the tuple (user, ticket)"""
 | 
			
		||||
    client = get_auth_client()
 | 
			
		||||
    response = client.get("/login", {"service": service})
 | 
			
		||||
    ticket_value = response['Location'].split('ticket=')[-1]
 | 
			
		||||
@@ -45,6 +58,7 @@ def get_user_ticket_request(service):
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def get_pgt():
 | 
			
		||||
    """return a dict contening a service, user and PGT ticket for this service"""
 | 
			
		||||
    (host, port) = utils.PGTUrlHandler.run()[1:3]
 | 
			
		||||
    service = "http://%s:%s" % (host, port)
 | 
			
		||||
 | 
			
		||||
@@ -110,7 +124,7 @@ class CheckPasswordCase(TestCase):
 | 
			
		||||
        self.assertTrue(utils.check_password("hex_md5", self.password1, hashed_password1, "utf8"))
 | 
			
		||||
        self.assertFalse(utils.check_password("hex_md5", self.password2, hashed_password1, "utf8"))
 | 
			
		||||
 | 
			
		||||
    def test_hox_sha512(self):
 | 
			
		||||
    def test_hex_sha512(self):
 | 
			
		||||
        """test the hex_sha512 auth method"""
 | 
			
		||||
        hashed_password1 = utils.hashlib.sha512(self.password1).hexdigest()
 | 
			
		||||
 | 
			
		||||
@@ -123,29 +137,83 @@ class CheckPasswordCase(TestCase):
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class LoginTestCase(TestCase):
 | 
			
		||||
 | 
			
		||||
    """Tests for the login view"""
 | 
			
		||||
    def setUp(self):
 | 
			
		||||
        """
 | 
			
		||||
            Prepare the test context:
 | 
			
		||||
                * set the auth class to 'cas_server.auth.TestAuthUser'
 | 
			
		||||
                * create a service pattern for https://www.example.com/**
 | 
			
		||||
                * Set the service pattern to return all user attributes
 | 
			
		||||
        """
 | 
			
		||||
        settings.CAS_AUTH_CLASS = 'cas_server.auth.TestAuthUser'
 | 
			
		||||
 | 
			
		||||
        # For general purpose testing
 | 
			
		||||
        self.service_pattern = models.ServicePattern.objects.create(
 | 
			
		||||
            name="example",
 | 
			
		||||
            pattern="^https://www\.example\.com(/.*)?$",
 | 
			
		||||
        )
 | 
			
		||||
        models.ReplaceAttributName.objects.create(name="*", service_pattern=self.service_pattern)
 | 
			
		||||
 | 
			
		||||
    def test_login_view_post_goodpass_goodlt(self):
 | 
			
		||||
        client, params = get_login_page_params()
 | 
			
		||||
        params["username"] = settings.CAS_TEST_USER
 | 
			
		||||
        params["password"] = settings.CAS_TEST_PASSWORD
 | 
			
		||||
        # For testing the restrict_users attributes
 | 
			
		||||
        self.service_pattern_restrict_user_fail = models.ServicePattern.objects.create(
 | 
			
		||||
            name="restrict_user_fail",
 | 
			
		||||
            pattern="^https://restrict_user_fail\.example\.com(/.*)?$",
 | 
			
		||||
            restrict_users=True,
 | 
			
		||||
        )
 | 
			
		||||
        self.service_pattern_restrict_user_success = models.ServicePattern.objects.create(
 | 
			
		||||
            name="restrict_user_success",
 | 
			
		||||
            pattern="^https://restrict_user_success\.example\.com(/.*)?$",
 | 
			
		||||
            restrict_users=True,
 | 
			
		||||
        )
 | 
			
		||||
        models.Username.objects.create(
 | 
			
		||||
            value=settings.CAS_TEST_USER,
 | 
			
		||||
            service_pattern=self.service_pattern_restrict_user_success
 | 
			
		||||
        )
 | 
			
		||||
 | 
			
		||||
        response = client.post('/login', params)
 | 
			
		||||
        # For testing the user attributes filtering conditions
 | 
			
		||||
        self.service_pattern_filter_fail = models.ServicePattern.objects.create(
 | 
			
		||||
            name="filter_fail",
 | 
			
		||||
            pattern="^https://filter_fail\.example\.com(/.*)?$",
 | 
			
		||||
        )
 | 
			
		||||
        models.FilterAttributValue.objects.create(
 | 
			
		||||
            attribut="right",
 | 
			
		||||
            pattern="^admin$",
 | 
			
		||||
            service_pattern=self.service_pattern_filter_fail
 | 
			
		||||
        )
 | 
			
		||||
        self.service_pattern_filter_success = models.ServicePattern.objects.create(
 | 
			
		||||
            name="filter_success",
 | 
			
		||||
            pattern="^https://filter_success\.example\.com(/.*)?$",
 | 
			
		||||
        )
 | 
			
		||||
        models.FilterAttributValue.objects.create(
 | 
			
		||||
            attribut="email",
 | 
			
		||||
            pattern="^%s$" % re.escape(settings.CAS_TEST_ATTRIBUTES['email']),
 | 
			
		||||
            service_pattern=self.service_pattern_filter_success
 | 
			
		||||
        )
 | 
			
		||||
 | 
			
		||||
        self.assertEqual(response.status_code, 200)
 | 
			
		||||
        # For testing the user_field attributes
 | 
			
		||||
        self.service_pattern_field_needed_fail = models.ServicePattern.objects.create(
 | 
			
		||||
            name="field_needed_fail",
 | 
			
		||||
            pattern="^https://field_needed_fail\.example\.com(/.*)?$",
 | 
			
		||||
            user_field="uid"
 | 
			
		||||
        )
 | 
			
		||||
        self.service_pattern_field_needed_success = models.ServicePattern.objects.create(
 | 
			
		||||
            name="field_needed_success",
 | 
			
		||||
            pattern="^https://field_needed_success\.example\.com(/.*)?$",
 | 
			
		||||
            user_field="nom"
 | 
			
		||||
        )
 | 
			
		||||
 | 
			
		||||
    def assert_logged(self, client, response, warn=False, code=200):
 | 
			
		||||
        """Assertions testing that client is well authenticated"""
 | 
			
		||||
        self.assertEqual(response.status_code, code)
 | 
			
		||||
        self.assertTrue(
 | 
			
		||||
            (
 | 
			
		||||
                b"You have successfully logged into "
 | 
			
		||||
                b"the Central Authentication Service"
 | 
			
		||||
            ) in response.content
 | 
			
		||||
        )
 | 
			
		||||
        self.assertTrue(client.session["username"] == settings.CAS_TEST_USER)
 | 
			
		||||
        self.assertTrue(client.session["warn"] is warn)
 | 
			
		||||
        self.assertTrue(client.session["authenticated"] is True)
 | 
			
		||||
 | 
			
		||||
        self.assertTrue(
 | 
			
		||||
            models.User.objects.get(
 | 
			
		||||
@@ -154,7 +222,59 @@ class LoginTestCase(TestCase):
 | 
			
		||||
            )
 | 
			
		||||
        )
 | 
			
		||||
 | 
			
		||||
    def assert_login_failed(self, client, response, code=200):
 | 
			
		||||
        """Assertions testing a failed login attempt"""
 | 
			
		||||
        self.assertEqual(response.status_code, code)
 | 
			
		||||
        self.assertFalse(
 | 
			
		||||
            (
 | 
			
		||||
                b"You have successfully logged into "
 | 
			
		||||
                b"the Central Authentication Service"
 | 
			
		||||
            ) in response.content
 | 
			
		||||
        )
 | 
			
		||||
 | 
			
		||||
        self.assertTrue(client.session.get("username") is None)
 | 
			
		||||
        self.assertTrue(client.session.get("warn") is None)
 | 
			
		||||
        self.assertTrue(client.session.get("authenticated") is None)
 | 
			
		||||
 | 
			
		||||
    def test_login_view_post_goodpass_goodlt(self):
 | 
			
		||||
        """Test a successul login"""
 | 
			
		||||
        client, params = get_login_page_params()
 | 
			
		||||
        params["username"] = settings.CAS_TEST_USER
 | 
			
		||||
        params["password"] = settings.CAS_TEST_PASSWORD
 | 
			
		||||
        self.assertTrue(params['lt'] in client.session['lt'])
 | 
			
		||||
 | 
			
		||||
        response = client.post('/login', params)
 | 
			
		||||
        self.assert_logged(client, response)
 | 
			
		||||
        # LoginTicket conssumed
 | 
			
		||||
        self.assertTrue(params['lt'] not in client.session['lt'])
 | 
			
		||||
 | 
			
		||||
    def test_login_view_post_goodpass_goodlt_warn(self):
 | 
			
		||||
        """Test a successul login requesting to be warned before creating services tickets"""
 | 
			
		||||
        client, params = get_login_page_params()
 | 
			
		||||
        params["username"] = settings.CAS_TEST_USER
 | 
			
		||||
        params["password"] = settings.CAS_TEST_PASSWORD
 | 
			
		||||
        params["warn"] = "on"
 | 
			
		||||
 | 
			
		||||
        response = client.post('/login', params)
 | 
			
		||||
        self.assert_logged(client, response, warn=True)
 | 
			
		||||
 | 
			
		||||
    def test_lt_max(self):
 | 
			
		||||
        """Check we only keep the last 100 Login Ticket for a user"""
 | 
			
		||||
        client, params = get_login_page_params()
 | 
			
		||||
        current_lt = params["lt"]
 | 
			
		||||
        i_in_test = random.randint(0, 100)
 | 
			
		||||
        i_not_in_test = random.randint(100, 150)
 | 
			
		||||
        for i in range(150):
 | 
			
		||||
            if i == i_in_test:
 | 
			
		||||
                self.assertTrue(current_lt in client.session['lt'])
 | 
			
		||||
            if i == i_not_in_test:
 | 
			
		||||
                self.assertTrue(current_lt not in client.session['lt'])
 | 
			
		||||
                self.assertTrue(len(client.session['lt']) <= 100)
 | 
			
		||||
            client, params = get_login_page_params(client)
 | 
			
		||||
        self.assertTrue(len(client.session['lt']) <= 100)
 | 
			
		||||
 | 
			
		||||
    def test_login_view_post_badlt(self):
 | 
			
		||||
        """Login attempt with a bad LoginTicket"""
 | 
			
		||||
        client, params = get_login_page_params()
 | 
			
		||||
        params["username"] = settings.CAS_TEST_USER
 | 
			
		||||
        params["password"] = settings.CAS_TEST_PASSWORD
 | 
			
		||||
@@ -162,47 +282,26 @@ class LoginTestCase(TestCase):
 | 
			
		||||
 | 
			
		||||
        response = client.post('/login', params)
 | 
			
		||||
 | 
			
		||||
        self.assertEqual(response.status_code, 200)
 | 
			
		||||
        self.assert_login_failed(client, response)
 | 
			
		||||
        self.assertTrue(b"Invalid login ticket" in response.content)
 | 
			
		||||
        self.assertFalse(
 | 
			
		||||
            (
 | 
			
		||||
                b"You have successfully logged into "
 | 
			
		||||
                b"the Central Authentication Service"
 | 
			
		||||
            ) in response.content
 | 
			
		||||
        )
 | 
			
		||||
 | 
			
		||||
    def test_login_view_post_badpass_good_lt(self):
 | 
			
		||||
        """Login attempt with a bad password"""
 | 
			
		||||
        client, params = get_login_page_params()
 | 
			
		||||
        params["username"] = settings.CAS_TEST_USER
 | 
			
		||||
        params["password"] = "test2"
 | 
			
		||||
        response = client.post('/login', params)
 | 
			
		||||
 | 
			
		||||
        self.assertEqual(response.status_code, 200)
 | 
			
		||||
        self.assert_login_failed(client, response)
 | 
			
		||||
        self.assertTrue(
 | 
			
		||||
            (
 | 
			
		||||
                b"The credentials you provided cannot be "
 | 
			
		||||
                b"determined to be authentic"
 | 
			
		||||
            ) in response.content
 | 
			
		||||
        )
 | 
			
		||||
        self.assertFalse(
 | 
			
		||||
            (
 | 
			
		||||
                b"You have successfully logged into "
 | 
			
		||||
                b"the Central Authentication Service"
 | 
			
		||||
            ) in response.content
 | 
			
		||||
        )
 | 
			
		||||
 | 
			
		||||
    def test_view_login_get_auth_allowed_service(self):
 | 
			
		||||
        client = get_auth_client()
 | 
			
		||||
        response = client.get("/login?service=https://www.example.com")
 | 
			
		||||
        self.assertEqual(response.status_code, 302)
 | 
			
		||||
        self.assertTrue(response.has_header('Location'))
 | 
			
		||||
        self.assertTrue(
 | 
			
		||||
            response['Location'].startswith(
 | 
			
		||||
                "https://www.example.com?ticket=%s-" % settings.CAS_SERVICE_TICKET_PREFIX
 | 
			
		||||
            )
 | 
			
		||||
        )
 | 
			
		||||
 | 
			
		||||
        ticket_value = response['Location'].split('ticket=')[-1]
 | 
			
		||||
    def assert_ticket_attributes(self, client, ticket_value):
 | 
			
		||||
        """check the ticket attributes in the db"""
 | 
			
		||||
        user = models.User.objects.get(
 | 
			
		||||
            username=settings.CAS_TEST_USER,
 | 
			
		||||
            session_key=client.session.session_key
 | 
			
		||||
@@ -214,12 +313,136 @@ class LoginTestCase(TestCase):
 | 
			
		||||
        self.assertEqual(ticket.validate, False)
 | 
			
		||||
        self.assertEqual(ticket.service_pattern, self.service_pattern)
 | 
			
		||||
 | 
			
		||||
    def assert_service_ticket(self, client, response):
 | 
			
		||||
        """check that a ticket is well emited when requested on a allowed service"""
 | 
			
		||||
        self.assertEqual(response.status_code, 302)
 | 
			
		||||
        self.assertTrue(response.has_header('Location'))
 | 
			
		||||
        self.assertTrue(
 | 
			
		||||
            response['Location'].startswith(
 | 
			
		||||
                "https://www.example.com?ticket=%s-" % settings.CAS_SERVICE_TICKET_PREFIX
 | 
			
		||||
            )
 | 
			
		||||
        )
 | 
			
		||||
 | 
			
		||||
        ticket_value = response['Location'].split('ticket=')[-1]
 | 
			
		||||
        self.assert_ticket_attributes(client, ticket_value)
 | 
			
		||||
 | 
			
		||||
    def test_view_login_get_allowed_service(self):
 | 
			
		||||
        """Request a ticket for an allowed service by an unauthenticated client"""
 | 
			
		||||
        client = Client()
 | 
			
		||||
        response = client.get("/login?service=https://www.example.com")
 | 
			
		||||
        self.assertEqual(response.status_code, 200)
 | 
			
		||||
        self.assertTrue(
 | 
			
		||||
            (
 | 
			
		||||
                "Authentication required by service "
 | 
			
		||||
                "example (https://www.example.com)"
 | 
			
		||||
            ) in response.content
 | 
			
		||||
        )
 | 
			
		||||
 | 
			
		||||
    def test_view_login_get_denied_service(self):
 | 
			
		||||
        """Request a ticket for an denied service by an unauthenticated client"""
 | 
			
		||||
        client = Client()
 | 
			
		||||
        response = client.get("/login?service=https://www.example.net")
 | 
			
		||||
        self.assertEqual(response.status_code, 200)
 | 
			
		||||
        self.assertTrue("Service https://www.example.net non allowed" in response.content)
 | 
			
		||||
 | 
			
		||||
    def test_view_login_get_auth_allowed_service(self):
 | 
			
		||||
        """Request a ticket for an allowed service by an authenticated client"""
 | 
			
		||||
        # client is already authenticated
 | 
			
		||||
        client = get_auth_client()
 | 
			
		||||
        response = client.get("/login?service=https://www.example.com")
 | 
			
		||||
        self.assert_service_ticket(client, response)
 | 
			
		||||
 | 
			
		||||
    def test_view_login_get_auth_allowed_service_warn(self):
 | 
			
		||||
        """Request a ticket for an allowed service by an authenticated client"""
 | 
			
		||||
        # client is already authenticated
 | 
			
		||||
        client = get_auth_client(warn="on")
 | 
			
		||||
        response = client.get("/login?service=https://www.example.com")
 | 
			
		||||
        self.assertEqual(response.status_code, 200)
 | 
			
		||||
        self.assertTrue(
 | 
			
		||||
            (
 | 
			
		||||
                "Authentication has been required by service "
 | 
			
		||||
                "example (https://www.example.com)"
 | 
			
		||||
            ) in response.content
 | 
			
		||||
        )
 | 
			
		||||
 | 
			
		||||
        params = copy_form(response.context["form"])
 | 
			
		||||
        response = client.post("/login", params)
 | 
			
		||||
        self.assert_service_ticket(client, response)
 | 
			
		||||
 | 
			
		||||
    def test_view_login_get_auth_denied_service(self):
 | 
			
		||||
        """Request a ticket for a not allowed service by an authenticated client"""
 | 
			
		||||
        client = get_auth_client()
 | 
			
		||||
        response = client.get("/login?service=https://www.example.org")
 | 
			
		||||
        self.assertEqual(response.status_code, 200)
 | 
			
		||||
        self.assertTrue(b"Service https://www.example.org non allowed" in response.content)
 | 
			
		||||
 | 
			
		||||
    def test_user_logged_not_in_db(self):
 | 
			
		||||
        """If the user is logged but has been delete from the database, it should be logged out"""
 | 
			
		||||
        client = get_auth_client()
 | 
			
		||||
        models.User.objects.get(
 | 
			
		||||
            username=settings.CAS_TEST_USER,
 | 
			
		||||
            session_key=client.session.session_key
 | 
			
		||||
        ).delete()
 | 
			
		||||
        response = client.get("/login")
 | 
			
		||||
 | 
			
		||||
        self.assert_login_failed(client, response, code=302)
 | 
			
		||||
        self.assertEqual(response["Location"], "/login?")
 | 
			
		||||
 | 
			
		||||
    def test_service_restrict_user(self):
 | 
			
		||||
        """Testing the restric user capability fro a service"""
 | 
			
		||||
        service = "https://restrict_user_fail.example.com"
 | 
			
		||||
        client = get_auth_client()
 | 
			
		||||
        response = client.get("/login", {'service': service})
 | 
			
		||||
        self.assertEqual(response.status_code, 200)
 | 
			
		||||
        self.assertTrue("Username non allowed" in response.content)
 | 
			
		||||
 | 
			
		||||
        service = "https://restrict_user_success.example.com"
 | 
			
		||||
        response = client.get("/login", {'service': service})
 | 
			
		||||
        self.assertEqual(response.status_code, 302)
 | 
			
		||||
        self.assertTrue(response["Location"].startswith("%s?ticket=" % service))
 | 
			
		||||
 | 
			
		||||
    def test_service_filter(self):
 | 
			
		||||
        """Test the filtering on user attributes"""
 | 
			
		||||
        service = "https://filter_fail.example.com"
 | 
			
		||||
        client = get_auth_client()
 | 
			
		||||
        response = client.get("/login", {'service': service})
 | 
			
		||||
        self.assertEqual(response.status_code, 200)
 | 
			
		||||
        self.assertTrue("User charateristics non allowed" in response.content)
 | 
			
		||||
 | 
			
		||||
        service = "https://filter_success.example.com"
 | 
			
		||||
        response = client.get("/login", {'service': service})
 | 
			
		||||
        self.assertEqual(response.status_code, 302)
 | 
			
		||||
        self.assertTrue(response["Location"].startswith("%s?ticket=" % service))
 | 
			
		||||
 | 
			
		||||
    def test_service_user_field(self):
 | 
			
		||||
        """Test using a user attribute as username: case on if the attribute exists or not"""
 | 
			
		||||
        service = "https://field_needed_fail.example.com"
 | 
			
		||||
        client = get_auth_client()
 | 
			
		||||
        response = client.get("/login", {'service': service})
 | 
			
		||||
        self.assertEqual(response.status_code, 200)
 | 
			
		||||
        self.assertTrue("The attribut uid is needed to use that service" in response.content)
 | 
			
		||||
 | 
			
		||||
        service = "https://field_needed_success.example.com"
 | 
			
		||||
        response = client.get("/login", {'service': service})
 | 
			
		||||
        self.assertEqual(response.status_code, 302)
 | 
			
		||||
        self.assertTrue(response["Location"].startswith("%s?ticket=" % service))
 | 
			
		||||
 | 
			
		||||
    def test_gateway(self):
 | 
			
		||||
        """test gateway parameter"""
 | 
			
		||||
 | 
			
		||||
        # First with an authenticated client that fail to get a ticket for a service
 | 
			
		||||
        service = "https://restrict_user_fail.example.com"
 | 
			
		||||
        client = get_auth_client()
 | 
			
		||||
        response = client.get("/login", {'service': service, 'gateway': 'on'})
 | 
			
		||||
        self.assertEqual(response.status_code, 302)
 | 
			
		||||
        self.assertEqual(response["Location"], service)
 | 
			
		||||
 | 
			
		||||
        # second for an user not yet authenticated on a valid service
 | 
			
		||||
        client = Client()
 | 
			
		||||
        response = client.get('/login', {'service': service, 'gateway': 'on'})
 | 
			
		||||
        self.assertEqual(response.status_code, 302)
 | 
			
		||||
        self.assertEqual(response["Location"], service)
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class LogoutTestCase(TestCase):
 | 
			
		||||
 | 
			
		||||
@@ -454,17 +677,24 @@ class ValidateServiceTestCase(TestCase):
 | 
			
		||||
            namespaces={'cas': "http://www.yale.edu/tp/cas"}
 | 
			
		||||
        )
 | 
			
		||||
        self.assertEqual(len(attributes), 1)
 | 
			
		||||
        attrs1 = {}
 | 
			
		||||
        attrs1 = set()
 | 
			
		||||
        for attr in attributes[0]:
 | 
			
		||||
            attrs1[attr.tag[len("http://www.yale.edu/tp/cas")+2:]] = attr.text
 | 
			
		||||
            attrs1.add((attr.tag[len("http://www.yale.edu/tp/cas")+2:], attr.text))
 | 
			
		||||
 | 
			
		||||
        attributes = root.xpath("//cas:attribute", namespaces={'cas': "http://www.yale.edu/tp/cas"})
 | 
			
		||||
        self.assertEqual(len(attributes), len(attrs1))
 | 
			
		||||
        attrs2 = {}
 | 
			
		||||
        attrs2 = set()
 | 
			
		||||
        for attr in attributes:
 | 
			
		||||
            attrs2[attr.attrib['name']] = attr.attrib['value']
 | 
			
		||||
            attrs2.add((attr.attrib['name'], attr.attrib['value']))
 | 
			
		||||
        original = set()
 | 
			
		||||
        for key, value in settings.CAS_TEST_ATTRIBUTES.items():
 | 
			
		||||
            if isinstance(value, list):
 | 
			
		||||
                for v in value:
 | 
			
		||||
                    original.add((key, v))
 | 
			
		||||
            else:
 | 
			
		||||
                original.add((key, value))
 | 
			
		||||
        self.assertEqual(attrs1, attrs2)
 | 
			
		||||
        self.assertEqual(attrs1, settings.CAS_TEST_ATTRIBUTES)
 | 
			
		||||
        self.assertEqual(attrs1, original)
 | 
			
		||||
 | 
			
		||||
    def test_validate_service_view_badservice(self):
 | 
			
		||||
        ticket = get_user_ticket_request(self.service)[1]
 | 
			
		||||
@@ -623,17 +853,24 @@ class ProxyTestCase(TestCase):
 | 
			
		||||
            namespaces={'cas': "http://www.yale.edu/tp/cas"}
 | 
			
		||||
        )
 | 
			
		||||
        self.assertEqual(len(attributes), 1)
 | 
			
		||||
        attrs1 = {}
 | 
			
		||||
        attrs1 = set()
 | 
			
		||||
        for attr in attributes[0]:
 | 
			
		||||
            attrs1[attr.tag[len("http://www.yale.edu/tp/cas")+2:]] = attr.text
 | 
			
		||||
            attrs1.add((attr.tag[len("http://www.yale.edu/tp/cas")+2:], attr.text))
 | 
			
		||||
 | 
			
		||||
        attributes = root.xpath("//cas:attribute", namespaces={'cas': "http://www.yale.edu/tp/cas"})
 | 
			
		||||
        self.assertEqual(len(attributes), len(attrs1))
 | 
			
		||||
        attrs2 = {}
 | 
			
		||||
        attrs2 = set()
 | 
			
		||||
        for attr in attributes:
 | 
			
		||||
            attrs2[attr.attrib['name']] = attr.attrib['value']
 | 
			
		||||
            attrs2.add((attr.attrib['name'], attr.attrib['value']))
 | 
			
		||||
        original = set()
 | 
			
		||||
        for key, value in settings.CAS_TEST_ATTRIBUTES.items():
 | 
			
		||||
            if isinstance(value, list):
 | 
			
		||||
                for v in value:
 | 
			
		||||
                    original.add((key, v))
 | 
			
		||||
            else:
 | 
			
		||||
                original.add((key, value))
 | 
			
		||||
        self.assertEqual(attrs1, attrs2)
 | 
			
		||||
        self.assertEqual(attrs1, settings.CAS_TEST_ATTRIBUTES)
 | 
			
		||||
        self.assertEqual(attrs1, original)
 | 
			
		||||
 | 
			
		||||
    def test_validate_proxy_bad(self):
 | 
			
		||||
        params = get_pgt()
 | 
			
		||||
 
 | 
			
		||||
@@ -105,6 +105,7 @@ class LogoutView(View, LogoutMixin):
 | 
			
		||||
    service = None
 | 
			
		||||
 | 
			
		||||
    def init_get(self, request):
 | 
			
		||||
        """Initialize GET received parameters"""
 | 
			
		||||
        self.request = request
 | 
			
		||||
        self.service = request.GET.get('service')
 | 
			
		||||
        self.url = request.GET.get('url')
 | 
			
		||||
@@ -196,6 +197,7 @@ class LoginView(View, LogoutMixin):
 | 
			
		||||
    USER_NOT_AUTHENTICATED = 6
 | 
			
		||||
 | 
			
		||||
    def init_post(self, request):
 | 
			
		||||
        """Initialize POST received parameters"""
 | 
			
		||||
        self.request = request
 | 
			
		||||
        self.service = request.POST.get('service')
 | 
			
		||||
        self.renew = bool(request.POST.get('renew') and request.POST['renew'] != "False")
 | 
			
		||||
@@ -205,15 +207,19 @@ class LoginView(View, LogoutMixin):
 | 
			
		||||
        if request.POST.get('warned') and request.POST['warned'] != "False":
 | 
			
		||||
            self.warned = True
 | 
			
		||||
 | 
			
		||||
    def check_lt(self):
 | 
			
		||||
        # save LT for later check
 | 
			
		||||
        lt_valid = self.request.session.get('lt', [])
 | 
			
		||||
        lt_send = self.request.POST.get('lt')
 | 
			
		||||
        # generate a new LT (by posting the LT has been consumed)
 | 
			
		||||
    def gen_lt(self):
 | 
			
		||||
        """Generate a new LoginTicket and add it to the list of valid LT for the user"""
 | 
			
		||||
        self.request.session['lt'] = self.request.session.get('lt', []) + [utils.gen_lt()]
 | 
			
		||||
        if len(self.request.session['lt']) > 100:
 | 
			
		||||
            self.request.session['lt'] = self.request.session['lt'][-100:]
 | 
			
		||||
 | 
			
		||||
    def check_lt(self):
 | 
			
		||||
        """Check is the POSTed LoginTicket is valid, if yes invalide it"""
 | 
			
		||||
        # save LT for later check
 | 
			
		||||
        lt_valid = self.request.session.get('lt', [])
 | 
			
		||||
        lt_send = self.request.POST.get('lt')
 | 
			
		||||
        # generate a new LT (by posting the LT has been consumed)
 | 
			
		||||
        self.gen_lt()
 | 
			
		||||
        # check if send LT is valid
 | 
			
		||||
        if lt_valid is None or lt_send not in lt_valid:
 | 
			
		||||
            return False
 | 
			
		||||
@@ -238,7 +244,7 @@ class LoginView(View, LogoutMixin):
 | 
			
		||||
                    username=self.request.session['username'],
 | 
			
		||||
                    session_key=self.request.session.session_key
 | 
			
		||||
                )
 | 
			
		||||
                self.user.save()
 | 
			
		||||
                self.user.save()  # pragma: no cover (should not happend)
 | 
			
		||||
            except models.User.DoesNotExist:
 | 
			
		||||
                self.user = models.User.objects.create(
 | 
			
		||||
                    username=self.request.session['username'],
 | 
			
		||||
@@ -250,10 +256,15 @@ class LoginView(View, LogoutMixin):
 | 
			
		||||
        elif ret == self.USER_ALREADY_LOGGED:
 | 
			
		||||
            pass
 | 
			
		||||
        else:
 | 
			
		||||
            raise EnvironmentError("invalid output for LoginView.process_post")
 | 
			
		||||
            raise EnvironmentError("invalid output for LoginView.process_post")  # pragma: no cover
 | 
			
		||||
        return self.common()
 | 
			
		||||
 | 
			
		||||
    def process_post(self):
 | 
			
		||||
        """
 | 
			
		||||
            Analyse the POST request:
 | 
			
		||||
                * check that the LoginTicket is valid
 | 
			
		||||
                * check that the user sumited credentials are valid
 | 
			
		||||
        """
 | 
			
		||||
        if not self.check_lt():
 | 
			
		||||
            values = self.request.POST.copy()
 | 
			
		||||
            # if not set a new LT and fail
 | 
			
		||||
@@ -280,6 +291,7 @@ class LoginView(View, LogoutMixin):
 | 
			
		||||
            return self.USER_ALREADY_LOGGED
 | 
			
		||||
 | 
			
		||||
    def init_get(self, request):
 | 
			
		||||
        """Initialize GET received parameters"""
 | 
			
		||||
        self.request = request
 | 
			
		||||
        self.service = request.GET.get('service')
 | 
			
		||||
        self.renew = bool(request.GET.get('renew') and request.GET['renew'] != "False")
 | 
			
		||||
@@ -294,15 +306,16 @@ class LoginView(View, LogoutMixin):
 | 
			
		||||
        return self.common()
 | 
			
		||||
 | 
			
		||||
    def process_get(self):
 | 
			
		||||
        # generate a new LT if none is present
 | 
			
		||||
        self.request.session['lt'] = self.request.session.get('lt', []) + [utils.gen_lt()]
 | 
			
		||||
 | 
			
		||||
        """Analyse the GET request"""
 | 
			
		||||
        # generate a new LT
 | 
			
		||||
        self.gen_lt()
 | 
			
		||||
        if not self.request.session.get("authenticated") or self.renew:
 | 
			
		||||
            self.init_form()
 | 
			
		||||
            return self.USER_NOT_AUTHENTICATED
 | 
			
		||||
        return self.USER_AUTHENTICATED
 | 
			
		||||
 | 
			
		||||
    def init_form(self, values=None):
 | 
			
		||||
        """Initialization of the good form depending of POST and GET parameters"""
 | 
			
		||||
        self.form = forms.UserCredential(
 | 
			
		||||
            values,
 | 
			
		||||
            initial={
 | 
			
		||||
 
 | 
			
		||||
@@ -52,7 +52,7 @@ MIDDLEWARE_CLASSES = [
 | 
			
		||||
    'django.middleware.locale.LocaleMiddleware',
 | 
			
		||||
]
 | 
			
		||||
 | 
			
		||||
ROOT_URLCONF = 'cas_server.urls'
 | 
			
		||||
ROOT_URLCONF = 'urls_tests'
 | 
			
		||||
 | 
			
		||||
# Database
 | 
			
		||||
# https://docs.djangoproject.com/en/1.9/ref/settings/#databases
 | 
			
		||||
 
 | 
			
		||||
							
								
								
									
										22
									
								
								urls_tests.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										22
									
								
								urls_tests.py
									
									
									
									
									
										Normal file
									
								
							@@ -0,0 +1,22 @@
 | 
			
		||||
"""cas URL Configuration
 | 
			
		||||
 | 
			
		||||
The `urlpatterns` list routes URLs to views. For more information please see:
 | 
			
		||||
    https://docs.djangoproject.com/en/1.9/topics/http/urls/
 | 
			
		||||
Examples:
 | 
			
		||||
Function views
 | 
			
		||||
    1. Add an import:  from my_app import views
 | 
			
		||||
    2. Add a URL to urlpatterns:  url(r'^$', views.home, name='home')
 | 
			
		||||
Class-based views
 | 
			
		||||
    1. Add an import:  from other_app.views import Home
 | 
			
		||||
    2. Add a URL to urlpatterns:  url(r'^$', Home.as_view(), name='home')
 | 
			
		||||
Including another URLconf
 | 
			
		||||
    1. Import the include() function: from django.conf.urls import url, include, include
 | 
			
		||||
    2. Add a URL to urlpatterns:  url(r'^blog/', include('blog.urls'))
 | 
			
		||||
"""
 | 
			
		||||
from django.conf.urls import url, include
 | 
			
		||||
from django.contrib import admin
 | 
			
		||||
 | 
			
		||||
urlpatterns = [
 | 
			
		||||
    url(r'^admin/', admin.site.urls),
 | 
			
		||||
    url(r'^', include('cas_server.urls', namespace='cas_server')),
 | 
			
		||||
]
 | 
			
		||||
		Reference in New Issue
	
	Block a user