Add unit test for the utils function check_password
This commit is contained in:
		@@ -15,7 +15,6 @@ from django.contrib.auth import get_user_model
 | 
			
		||||
try:  # pragma: no cover
 | 
			
		||||
    import MySQLdb
 | 
			
		||||
    import MySQLdb.cursors
 | 
			
		||||
    import crypt
 | 
			
		||||
    from utils import check_password
 | 
			
		||||
except ImportError:
 | 
			
		||||
    MySQLdb = None
 | 
			
		||||
 
 | 
			
		||||
@@ -3,6 +3,7 @@ from .default_settings import settings
 | 
			
		||||
from django.test import TestCase
 | 
			
		||||
from django.test import Client
 | 
			
		||||
 | 
			
		||||
import six
 | 
			
		||||
from lxml import etree
 | 
			
		||||
 | 
			
		||||
from cas_server import models
 | 
			
		||||
@@ -59,6 +60,60 @@ def get_pgt():
 | 
			
		||||
    return params
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class CheckPasswordCase(TestCase):
 | 
			
		||||
 | 
			
		||||
    def setUp(self):
 | 
			
		||||
        self.password1 = utils.gen_saml_id()
 | 
			
		||||
        self.password2 = utils.gen_saml_id()
 | 
			
		||||
        if not isinstance(self.password1, bytes):
 | 
			
		||||
            self.password1 = self.password1.encode("utf8")
 | 
			
		||||
            self.password2 = self.password2.encode("utf8")
 | 
			
		||||
 | 
			
		||||
    def test_setup(self):
 | 
			
		||||
        self.assertIsInstance(self.password1, bytes)
 | 
			
		||||
        self.assertIsInstance(self.password2, bytes)
 | 
			
		||||
 | 
			
		||||
    def test_plain(self):
 | 
			
		||||
        self.assertTrue(utils.check_password("plain", self.password1, self.password1, "utf8"))
 | 
			
		||||
        self.assertFalse(utils.check_password("plain", self.password1, self.password2, "utf8"))
 | 
			
		||||
 | 
			
		||||
    def test_crypt(self):
 | 
			
		||||
        if six.PY3:
 | 
			
		||||
            hashed_password1 = utils.crypt.crypt(
 | 
			
		||||
                self.password1.decode("utf8"),
 | 
			
		||||
                "$6$UVVAQvrMyXMF3FF3"
 | 
			
		||||
            ).encode("utf8")
 | 
			
		||||
        else:
 | 
			
		||||
            hashed_password1 = utils.crypt.crypt(self.password1, "$6$UVVAQvrMyXMF3FF3")
 | 
			
		||||
 | 
			
		||||
        self.assertTrue(utils.check_password("crypt", self.password1, hashed_password1, "utf8"))
 | 
			
		||||
        self.assertFalse(utils.check_password("crypt", self.password2, hashed_password1, "utf8"))
 | 
			
		||||
 | 
			
		||||
    def test_ldap_ssha(self):
 | 
			
		||||
        salt = b"UVVAQvrMyXMF3FF3"
 | 
			
		||||
        hashed_password1 = utils.LdapHashUserPassword.hash(b'{SSHA}', self.password1, salt, "utf8")
 | 
			
		||||
 | 
			
		||||
        self.assertIsInstance(hashed_password1, bytes)
 | 
			
		||||
        self.assertTrue(utils.check_password("ldap", self.password1, hashed_password1, "utf8"))
 | 
			
		||||
        self.assertFalse(utils.check_password("ldap", self.password2, hashed_password1, "utf8"))
 | 
			
		||||
 | 
			
		||||
    def test_hex_md5(self):
 | 
			
		||||
        hashed_password1 = utils.hashlib.md5(self.password1).hexdigest()
 | 
			
		||||
 | 
			
		||||
        self.assertTrue(utils.check_password("hex_md5", self.password1, hashed_password1, "utf8"))
 | 
			
		||||
        self.assertFalse(utils.check_password("hex_md5", self.password2, hashed_password1, "utf8"))
 | 
			
		||||
 | 
			
		||||
    def test_hox_sha512(self):
 | 
			
		||||
        hashed_password1 = utils.hashlib.sha512(self.password1).hexdigest()
 | 
			
		||||
 | 
			
		||||
        self.assertTrue(
 | 
			
		||||
            utils.check_password("hex_sha512", self.password1, hashed_password1, "utf8")
 | 
			
		||||
        )
 | 
			
		||||
        self.assertFalse(
 | 
			
		||||
            utils.check_password("hex_sha512", self.password2, hashed_password1, "utf8")
 | 
			
		||||
        )
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class LoginTestCase(TestCase):
 | 
			
		||||
 | 
			
		||||
    def setUp(self):
 | 
			
		||||
 
 | 
			
		||||
@@ -177,6 +177,7 @@ class PGTUrlHandler(BaseHTTPServer.BaseHTTPRequestHandler):
 | 
			
		||||
        httpd_thread.start()
 | 
			
		||||
        return (httpd_thread, host, port)
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class LdapHashUserPassword(object):
 | 
			
		||||
    """Please see https://tools.ietf.org/id/draft-stroeder-hashed-userpassword-values-01.html"""
 | 
			
		||||
 | 
			
		||||
@@ -204,8 +205,6 @@ class LdapHashUserPassword(object):
 | 
			
		||||
        b"{SSHA512}": 64,
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
    class BadScheme(ValueError):
 | 
			
		||||
        pass
 | 
			
		||||
 | 
			
		||||
@@ -217,9 +216,9 @@ class LdapHashUserPassword(object):
 | 
			
		||||
 | 
			
		||||
    @classmethod
 | 
			
		||||
    def _raise_bad_scheme(cls, scheme, valid, msg):
 | 
			
		||||
        valid_schemes = [s for s in valid]
 | 
			
		||||
        valid_schemes = [s.decode() for s in valid]
 | 
			
		||||
        valid_schemes.sort()
 | 
			
		||||
        raise cls.BadScheme(msg % (scheme, ", ".join(valid_schemes)))
 | 
			
		||||
        raise cls.BadScheme(msg % (scheme, u", ".join(valid_schemes)))
 | 
			
		||||
 | 
			
		||||
    @classmethod
 | 
			
		||||
    def _test_scheme(cls, scheme):
 | 
			
		||||
@@ -258,7 +257,9 @@ class LdapHashUserPassword(object):
 | 
			
		||||
        elif salt is not None:
 | 
			
		||||
            cls._test_scheme_salt(scheme)
 | 
			
		||||
        try:
 | 
			
		||||
            return scheme + base64.b64encode(cls._schemes_to_hash[scheme](password + salt).digest() + salt)
 | 
			
		||||
            return scheme + base64.b64encode(
 | 
			
		||||
                cls._schemes_to_hash[scheme](password + salt).digest() + salt
 | 
			
		||||
            )
 | 
			
		||||
        except KeyError:
 | 
			
		||||
            if six.PY3:
 | 
			
		||||
                password = password.decode(charset)
 | 
			
		||||
@@ -272,13 +273,12 @@ class LdapHashUserPassword(object):
 | 
			
		||||
 | 
			
		||||
    @classmethod
 | 
			
		||||
    def get_scheme(cls, hashed_passord):
 | 
			
		||||
        if not hashed_passord[0] == b'{' or not b'}' in hashed_passord:
 | 
			
		||||
        if not hashed_passord[0] == b'{'[0] or b'}' not in hashed_passord:
 | 
			
		||||
            raise cls.BadHash("%r should start with the scheme enclosed with { }" % hashed_passord)
 | 
			
		||||
        scheme = hashed_passord.split(b'}', 1)[0]
 | 
			
		||||
        scheme = scheme.upper() + b"}"
 | 
			
		||||
        return scheme
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
    @classmethod
 | 
			
		||||
    def get_salt(cls, hashed_passord):
 | 
			
		||||
        scheme = cls.get_scheme(hashed_passord)
 | 
			
		||||
@@ -294,7 +294,6 @@ class LdapHashUserPassword(object):
 | 
			
		||||
            return hashed_passord[cls._schemes_to_len[scheme]:]
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def check_password(method, password, hashed_password, charset):
 | 
			
		||||
    if not isinstance(password, six.binary_type):
 | 
			
		||||
        password = password.encode(charset)
 | 
			
		||||
@@ -325,6 +324,9 @@ def check_password(method, password, hashed_password, charset):
 | 
			
		||||
       method.startswith("hex_") and
 | 
			
		||||
       method[4:] in {"md5", "sha1", "sha224", "sha256", "sha384", "sha512"}
 | 
			
		||||
    ):
 | 
			
		||||
        return getattr(hashlib, method[4:])(password).hexdigest() == hashed_password.lower()
 | 
			
		||||
        return getattr(
 | 
			
		||||
            hashlib,
 | 
			
		||||
            method[4:]
 | 
			
		||||
        )(password).hexdigest().encode("ascii") == hashed_password.lower()
 | 
			
		||||
    else:
 | 
			
		||||
        raise ValueError("Unknown password method check %r" % method)
 | 
			
		||||
 
 | 
			
		||||
		Reference in New Issue
	
	Block a user