Current File : //opt/alt/python37/lib/python3.7/site-packages/jwt/algorithms.py
import hashlib
import hmac
import json

from .exceptions import InvalidKeyError
from .utils import (
    base64url_decode,
    base64url_encode,
    der_to_raw_signature,
    force_bytes,
    from_base64url_uint,
    raw_to_der_signature,
    to_base64url_uint,
)

try:
    import cryptography.exceptions
    from cryptography.exceptions import InvalidSignature
    from cryptography.hazmat.primitives import hashes
    from cryptography.hazmat.primitives.asymmetric import ec, padding
    from cryptography.hazmat.primitives.asymmetric.ec import (
        EllipticCurvePrivateKey,
        EllipticCurvePublicKey,
    )
    from cryptography.hazmat.primitives.asymmetric.ed25519 import (
        Ed25519PrivateKey,
        Ed25519PublicKey,
    )
    from cryptography.hazmat.primitives.asymmetric.rsa import (
        RSAPrivateKey,
        RSAPrivateNumbers,
        RSAPublicKey,
        RSAPublicNumbers,
        rsa_crt_dmp1,
        rsa_crt_dmq1,
        rsa_crt_iqmp,
        rsa_recover_prime_factors,
    )
    from cryptography.hazmat.primitives.serialization import (
        Encoding,
        NoEncryption,
        PrivateFormat,
        PublicFormat,
        load_pem_private_key,
        load_pem_public_key,
        load_ssh_public_key,
    )

    has_crypto = True
except ModuleNotFoundError:
    has_crypto = False

requires_cryptography = {
    "RS256",
    "RS384",
    "RS512",
    "ES256",
    "ES256K",
    "ES384",
    "ES521",
    "ES512",
    "PS256",
    "PS384",
    "PS512",
    "EdDSA",
}


def get_default_algorithms():
    """
    Returns the algorithms that are implemented by the library.
    """
    default_algorithms = {
        "none": NoneAlgorithm(),
        "HS256": HMACAlgorithm(HMACAlgorithm.SHA256),
        "HS384": HMACAlgorithm(HMACAlgorithm.SHA384),
        "HS512": HMACAlgorithm(HMACAlgorithm.SHA512),
    }

    if has_crypto:
        default_algorithms.update(
            {
                "RS256": RSAAlgorithm(RSAAlgorithm.SHA256),
                "RS384": RSAAlgorithm(RSAAlgorithm.SHA384),
                "RS512": RSAAlgorithm(RSAAlgorithm.SHA512),
                "ES256": ECAlgorithm(ECAlgorithm.SHA256),
                "ES256K": ECAlgorithm(ECAlgorithm.SHA256),
                "ES384": ECAlgorithm(ECAlgorithm.SHA384),
                "ES521": ECAlgorithm(ECAlgorithm.SHA512),
                "ES512": ECAlgorithm(
                    ECAlgorithm.SHA512
                ),  # Backward compat for #219 fix
                "PS256": RSAPSSAlgorithm(RSAPSSAlgorithm.SHA256),
                "PS384": RSAPSSAlgorithm(RSAPSSAlgorithm.SHA384),
                "PS512": RSAPSSAlgorithm(RSAPSSAlgorithm.SHA512),
                "EdDSA": Ed25519Algorithm(),
            }
        )

    return default_algorithms


class Algorithm:
    """
    The interface for an algorithm used to sign and verify tokens.
    """

    def prepare_key(self, key):
        """
        Performs necessary validation and conversions on the key and returns
        the key value in the proper format for sign() and verify().
        """
        raise NotImplementedError

    def sign(self, msg, key):
        """
        Returns a digital signature for the specified message
        using the specified key value.
        """
        raise NotImplementedError

    def verify(self, msg, key, sig):
        """
        Verifies that the specified digital signature is valid
        for the specified message and key values.
        """
        raise NotImplementedError

    @staticmethod
    def to_jwk(key_obj):
        """
        Serializes a given RSA key into a JWK
        """
        raise NotImplementedError

    @staticmethod
    def from_jwk(jwk):
        """
        Deserializes a given RSA key from JWK back into a PublicKey or PrivateKey object
        """
        raise NotImplementedError


class NoneAlgorithm(Algorithm):
    """
    Placeholder for use when no signing or verification
    operations are required.
    """

    def prepare_key(self, key):
        if key == "":
            key = None

        if key is not None:
            raise InvalidKeyError('When alg = "none", key value must be None.')

        return key

    def sign(self, msg, key):
        return b""

    def verify(self, msg, key, sig):
        return False


class HMACAlgorithm(Algorithm):
    """
    Performs signing and verification operations using HMAC
    and the specified hash function.
    """

    SHA256 = hashlib.sha256
    SHA384 = hashlib.sha384
    SHA512 = hashlib.sha512

    def __init__(self, hash_alg):
        self.hash_alg = hash_alg

    def prepare_key(self, key):
        key = force_bytes(key)

        invalid_strings = [
            b"-----BEGIN PUBLIC KEY-----",
            b"-----BEGIN CERTIFICATE-----",
            b"-----BEGIN RSA PUBLIC KEY-----",
            b"ssh-rsa",
        ]

        if any(string_value in key for string_value in invalid_strings):
            raise InvalidKeyError(
                "The specified key is an asymmetric key or x509 certificate and"
                " should not be used as an HMAC secret."
            )

        return key

    @staticmethod
    def to_jwk(key_obj):
        return json.dumps(
            {
                "k": base64url_encode(force_bytes(key_obj)).decode(),
                "kty": "oct",
            }
        )

    @staticmethod
    def from_jwk(jwk):
        try:
            if isinstance(jwk, str):
                obj = json.loads(jwk)
            elif isinstance(jwk, dict):
                obj = jwk
            else:
                raise ValueError
        except ValueError:
            raise InvalidKeyError("Key is not valid JSON")

        if obj.get("kty") != "oct":
            raise InvalidKeyError("Not an HMAC key")

        return base64url_decode(obj["k"])

    def sign(self, msg, key):
        return hmac.new(key, msg, self.hash_alg).digest()

    def verify(self, msg, key, sig):
        return hmac.compare_digest(sig, self.sign(msg, key))


if has_crypto:

    class RSAAlgorithm(Algorithm):
        """
        Performs signing and verification operations using
        RSASSA-PKCS-v1_5 and the specified hash function.
        """

        SHA256 = hashes.SHA256
        SHA384 = hashes.SHA384
        SHA512 = hashes.SHA512

        def __init__(self, hash_alg):
            self.hash_alg = hash_alg

        def prepare_key(self, key):
            if isinstance(key, RSAPrivateKey) or isinstance(key, RSAPublicKey):
                return key

            if isinstance(key, (bytes, str)):
                key = force_bytes(key)

                try:
                    if key.startswith(b"ssh-rsa"):
                        key = load_ssh_public_key(key)
                    else:
                        key = load_pem_private_key(key, password=None)
                except ValueError:
                    key = load_pem_public_key(key)
            else:
                raise TypeError("Expecting a PEM-formatted key.")

            return key

        @staticmethod
        def to_jwk(key_obj):
            obj = None

            if getattr(key_obj, "private_numbers", None):
                # Private key
                numbers = key_obj.private_numbers()

                obj = {
                    "kty": "RSA",
                    "key_ops": ["sign"],
                    "n": to_base64url_uint(numbers.public_numbers.n).decode(),
                    "e": to_base64url_uint(numbers.public_numbers.e).decode(),
                    "d": to_base64url_uint(numbers.d).decode(),
                    "p": to_base64url_uint(numbers.p).decode(),
                    "q": to_base64url_uint(numbers.q).decode(),
                    "dp": to_base64url_uint(numbers.dmp1).decode(),
                    "dq": to_base64url_uint(numbers.dmq1).decode(),
                    "qi": to_base64url_uint(numbers.iqmp).decode(),
                }

            elif getattr(key_obj, "verify", None):
                # Public key
                numbers = key_obj.public_numbers()

                obj = {
                    "kty": "RSA",
                    "key_ops": ["verify"],
                    "n": to_base64url_uint(numbers.n).decode(),
                    "e": to_base64url_uint(numbers.e).decode(),
                }
            else:
                raise InvalidKeyError("Not a public or private key")

            return json.dumps(obj)

        @staticmethod
        def from_jwk(jwk):
            try:
                if isinstance(jwk, str):
                    obj = json.loads(jwk)
                elif isinstance(jwk, dict):
                    obj = jwk
                else:
                    raise ValueError
            except ValueError:
                raise InvalidKeyError("Key is not valid JSON")

            if obj.get("kty") != "RSA":
                raise InvalidKeyError("Not an RSA key")

            if "d" in obj and "e" in obj and "n" in obj:
                # Private key
                if "oth" in obj:
                    raise InvalidKeyError(
                        "Unsupported RSA private key: > 2 primes not supported"
                    )

                other_props = ["p", "q", "dp", "dq", "qi"]
                props_found = [prop in obj for prop in other_props]
                any_props_found = any(props_found)

                if any_props_found and not all(props_found):
                    raise InvalidKeyError(
                        "RSA key must include all parameters if any are present besides d"
                    )

                public_numbers = RSAPublicNumbers(
                    from_base64url_uint(obj["e"]),
                    from_base64url_uint(obj["n"]),
                )

                if any_props_found:
                    numbers = RSAPrivateNumbers(
                        d=from_base64url_uint(obj["d"]),
                        p=from_base64url_uint(obj["p"]),
                        q=from_base64url_uint(obj["q"]),
                        dmp1=from_base64url_uint(obj["dp"]),
                        dmq1=from_base64url_uint(obj["dq"]),
                        iqmp=from_base64url_uint(obj["qi"]),
                        public_numbers=public_numbers,
                    )
                else:
                    d = from_base64url_uint(obj["d"])
                    p, q = rsa_recover_prime_factors(
                        public_numbers.n, d, public_numbers.e
                    )

                    numbers = RSAPrivateNumbers(
                        d=d,
                        p=p,
                        q=q,
                        dmp1=rsa_crt_dmp1(d, p),
                        dmq1=rsa_crt_dmq1(d, q),
                        iqmp=rsa_crt_iqmp(p, q),
                        public_numbers=public_numbers,
                    )

                return numbers.private_key()
            elif "n" in obj and "e" in obj:
                # Public key
                numbers = RSAPublicNumbers(
                    from_base64url_uint(obj["e"]),
                    from_base64url_uint(obj["n"]),
                )

                return numbers.public_key()
            else:
                raise InvalidKeyError("Not a public or private key")

        def sign(self, msg, key):
            return key.sign(msg, padding.PKCS1v15(), self.hash_alg())

        def verify(self, msg, key, sig):
            try:
                key.verify(sig, msg, padding.PKCS1v15(), self.hash_alg())
                return True
            except InvalidSignature:
                return False

    class ECAlgorithm(Algorithm):
        """
        Performs signing and verification operations using
        ECDSA and the specified hash function
        """

        SHA256 = hashes.SHA256
        SHA384 = hashes.SHA384
        SHA512 = hashes.SHA512

        def __init__(self, hash_alg):
            self.hash_alg = hash_alg

        def prepare_key(self, key):
            if isinstance(key, EllipticCurvePrivateKey) or isinstance(
                key, EllipticCurvePublicKey
            ):
                return key

            if isinstance(key, (bytes, str)):
                key = force_bytes(key)

                # Attempt to load key. We don't know if it's
                # a Signing Key or a Verifying Key, so we try
                # the Verifying Key first.
                try:
                    if key.startswith(b"ecdsa-sha2-"):
                        key = load_ssh_public_key(key)
                    else:
                        key = load_pem_public_key(key)
                except ValueError:
                    key = load_pem_private_key(key, password=None)

            else:
                raise TypeError("Expecting a PEM-formatted key.")

            return key

        def sign(self, msg, key):
            der_sig = key.sign(msg, ec.ECDSA(self.hash_alg()))

            return der_to_raw_signature(der_sig, key.curve)

        def verify(self, msg, key, sig):
            try:
                der_sig = raw_to_der_signature(sig, key.curve)
            except ValueError:
                return False

            try:
                if isinstance(key, EllipticCurvePrivateKey):
                    key = key.public_key()
                key.verify(der_sig, msg, ec.ECDSA(self.hash_alg()))
                return True
            except InvalidSignature:
                return False

        @staticmethod
        def from_jwk(jwk):
            try:
                if isinstance(jwk, str):
                    obj = json.loads(jwk)
                elif isinstance(jwk, dict):
                    obj = jwk
                else:
                    raise ValueError
            except ValueError:
                raise InvalidKeyError("Key is not valid JSON")

            if obj.get("kty") != "EC":
                raise InvalidKeyError("Not an Elliptic curve key")

            if "x" not in obj or "y" not in obj:
                raise InvalidKeyError("Not an Elliptic curve key")

            x = base64url_decode(obj.get("x"))
            y = base64url_decode(obj.get("y"))

            curve = obj.get("crv")
            if curve == "P-256":
                if len(x) == len(y) == 32:
                    curve_obj = ec.SECP256R1()
                else:
                    raise InvalidKeyError("Coords should be 32 bytes for curve P-256")
            elif curve == "P-384":
                if len(x) == len(y) == 48:
                    curve_obj = ec.SECP384R1()
                else:
                    raise InvalidKeyError("Coords should be 48 bytes for curve P-384")
            elif curve == "P-521":
                if len(x) == len(y) == 66:
                    curve_obj = ec.SECP521R1()
                else:
                    raise InvalidKeyError("Coords should be 66 bytes for curve P-521")
            elif curve == "secp256k1":
                if len(x) == len(y) == 32:
                    curve_obj = ec.SECP256K1()
                else:
                    raise InvalidKeyError(
                        "Coords should be 32 bytes for curve secp256k1"
                    )
            else:
                raise InvalidKeyError(f"Invalid curve: {curve}")

            public_numbers = ec.EllipticCurvePublicNumbers(
                x=int.from_bytes(x, byteorder="big"),
                y=int.from_bytes(y, byteorder="big"),
                curve=curve_obj,
            )

            if "d" not in obj:
                return public_numbers.public_key()

            d = base64url_decode(obj.get("d"))
            if len(d) != len(x):
                raise InvalidKeyError(
                    "D should be {} bytes for curve {}", len(x), curve
                )

            return ec.EllipticCurvePrivateNumbers(
                int.from_bytes(d, byteorder="big"), public_numbers
            ).private_key()

    class RSAPSSAlgorithm(RSAAlgorithm):
        """
        Performs a signature using RSASSA-PSS with MGF1
        """

        def sign(self, msg, key):
            return key.sign(
                msg,
                padding.PSS(
                    mgf=padding.MGF1(self.hash_alg()),
                    salt_length=self.hash_alg.digest_size,
                ),
                self.hash_alg(),
            )

        def verify(self, msg, key, sig):
            try:
                key.verify(
                    sig,
                    msg,
                    padding.PSS(
                        mgf=padding.MGF1(self.hash_alg()),
                        salt_length=self.hash_alg.digest_size,
                    ),
                    self.hash_alg(),
                )
                return True
            except InvalidSignature:
                return False

    class Ed25519Algorithm(Algorithm):
        """
        Performs signing and verification operations using Ed25519

        This class requires ``cryptography>=2.6`` to be installed.
        """

        def __init__(self, **kwargs):
            pass

        def prepare_key(self, key):

            if isinstance(key, (Ed25519PrivateKey, Ed25519PublicKey)):
                return key

            if isinstance(key, (bytes, str)):
                if isinstance(key, str):
                    key = key.encode("utf-8")
                str_key = key.decode("utf-8")

                if "-----BEGIN PUBLIC" in str_key:
                    return load_pem_public_key(key)
                if "-----BEGIN PRIVATE" in str_key:
                    return load_pem_private_key(key, password=None)
                if str_key[0:4] == "ssh-":
                    return load_ssh_public_key(key)

            raise TypeError("Expecting a PEM-formatted or OpenSSH key.")

        def sign(self, msg, key):
            """
            Sign a message ``msg`` using the Ed25519 private key ``key``
            :param str|bytes msg: Message to sign
            :param Ed25519PrivateKey key: A :class:`.Ed25519PrivateKey` instance
            :return bytes signature: The signature, as bytes
            """
            msg = bytes(msg, "utf-8") if type(msg) is not bytes else msg
            return key.sign(msg)

        def verify(self, msg, key, sig):
            """
            Verify a given ``msg`` against a signature ``sig`` using the Ed25519 key ``key``

            :param str|bytes sig: Ed25519 signature to check ``msg`` against
            :param str|bytes msg: Message to sign
            :param Ed25519PrivateKey|Ed25519PublicKey key: A private or public Ed25519 key instance
            :return bool verified: True if signature is valid, False if not.
            """
            try:
                msg = bytes(msg, "utf-8") if type(msg) is not bytes else msg
                sig = bytes(sig, "utf-8") if type(sig) is not bytes else sig

                if isinstance(key, Ed25519PrivateKey):
                    key = key.public_key()
                key.verify(sig, msg)
                return True  # If no exception was raised, the signature is valid.
            except cryptography.exceptions.InvalidSignature:
                return False

        @staticmethod
        def to_jwk(key):
            if isinstance(key, Ed25519PublicKey):
                x = key.public_bytes(
                    encoding=Encoding.Raw,
                    format=PublicFormat.Raw,
                )

                return json.dumps(
                    {
                        "x": base64url_encode(force_bytes(x)).decode(),
                        "kty": "OKP",
                        "crv": "Ed25519",
                    }
                )

            if isinstance(key, Ed25519PrivateKey):
                d = key.private_bytes(
                    encoding=Encoding.Raw,
                    format=PrivateFormat.Raw,
                    encryption_algorithm=NoEncryption(),
                )

                x = key.public_key().public_bytes(
                    encoding=Encoding.Raw,
                    format=PublicFormat.Raw,
                )

                return json.dumps(
                    {
                        "x": base64url_encode(force_bytes(x)).decode(),
                        "d": base64url_encode(force_bytes(d)).decode(),
                        "kty": "OKP",
                        "crv": "Ed25519",
                    }
                )

            raise InvalidKeyError("Not a public or private key")

        @staticmethod
        def from_jwk(jwk):
            try:
                if isinstance(jwk, str):
                    obj = json.loads(jwk)
                elif isinstance(jwk, dict):
                    obj = jwk
                else:
                    raise ValueError
            except ValueError:
                raise InvalidKeyError("Key is not valid JSON")

            if obj.get("kty") != "OKP":
                raise InvalidKeyError("Not an Octet Key Pair")

            curve = obj.get("crv")
            if curve != "Ed25519":
                raise InvalidKeyError(f"Invalid curve: {curve}")

            if "x" not in obj:
                raise InvalidKeyError('OKP should have "x" parameter')
            x = base64url_decode(obj.get("x"))

            try:
                if "d" not in obj:
                    return Ed25519PublicKey.from_public_bytes(x)
                d = base64url_decode(obj.get("d"))
                return Ed25519PrivateKey.from_private_bytes(d)
            except ValueError as err:
                raise InvalidKeyError("Invalid key parameter") from err