Secure your code as it's written. Use Snyk Code to scan source code in minutes - no build needed - and fix issues immediately.
def process_packet(self, data):
    """Decode an incoming packet and react to a user auth failure

       The packet type is read from the first byte. Only
       MSG_USERAUTH_FAILURE is acted on by the code shown here.

    """

    pkt = SSHPacket(data)
    msg_type = pkt.get_byte()

    if msg_type == MSG_USERAUTH_FAILURE:
        # The name list is parsed to advance the packet, but its
        # value is discarded
        _ = pkt.get_namelist()
        is_partial = pkt.get_boolean()
        pkt.check_end()

        if not is_partial:
            self._auth.auth_failed()
        else: # pragma: no cover
            # Partial success not implemented yet
            self._auth.auth_succeeded()

        # NOTE(review): nesting reconstructed from a flattened paste --
        # confirm the waiter is meant to be notified for both outcomes.
        # Hand (False, password_changed) to the waiter, then clear the
        # references to the auth object and the waiter.
        self._auth_waiter.set_result((False, self._password_changed))
        self._auth = None
        self._auth_waiter = None
def process_packet(self, data):
    """Decode an incoming packet and pass it to the key exchange object"""

    pkt = SSHPacket(data)
    msg_type = pkt.get_byte()

    # The middle argument is always None here -- presumably the packet
    # id slot of the handler signature; verify against the kex handlers
    self._kex.process_packet(msg_type, None, pkt)
"""Unit tests for SSH packet encoding and decoding"""
import codecs
import unittest
from asyncssh.packet import Byte, Boolean, UInt32, UInt64, String, MPInt
from asyncssh.packet import NameList, PacketDecodeError, SSHPacket
class _TestPacket(unittest.TestCase):
"""Unit tests for SSH packet module"""
# pylint: disable=bad-whitespace
tests = [
(Byte, SSHPacket.get_byte, [
(0, '00'),
(127, '7f'),
(128, '80'),
(255, 'ff')
]),
(Boolean, SSHPacket.get_boolean, [
(False, '00'),
(True, '01')
]),
(UInt32, SSHPacket.get_uint32, [
(0, '00000000'),
(256, '00000100'),
(0x12345678, '12345678'),
(0x7fffffff, '7fffffff'),
def decode_ssh_certificate(data, comment=None):
    """Decode a packetized SSH certificate

       The algorithm name is read from the start of the packet and
       looked up in the certificate algorithm map. The matching
       certificate handler then constructs the result from the rest
       of the packet.

       :param data:
           The packetized certificate data
       :param comment:
           Optional comment passed through to the certificate handler

       :returns: Whatever the certificate handler's construct() returns

       :raises: KeyImportError if the algorithm is not in the map, or
                if decoding raises PacketDecodeError or ValueError

    """

    try:
        packet = SSHPacket(data)
        alg = packet.get_string()
        key_handler, cert_handler = _certificate_alg_map.get(alg, (None, None))

        if cert_handler:
            return cert_handler.construct(packet, alg, key_handler, comment)
    except (PacketDecodeError, ValueError):
        raise KeyImportError('Invalid OpenSSH certificate') from None

    # Raised outside the try block on purpose: if KeyImportError derives
    # from ValueError (it appears to in asyncssh -- confirm), raising it
    # inside the try would be caught by the handler above and the
    # specific message replaced by the generic one.
    raise KeyImportError('Unknown certificate algorithm: %s' %
                         alg.decode('ascii', errors='replace'))
except ValueError:
raise KeyEncryptionError('Invalid OpenSSH '
'private key') from None
cipher = get_encryption(cipher_name, key[:key_size], key[key_size:])
key_data = cipher.decrypt_packet(0, b'', key_data, 0, mac)
if key_data is None:
raise KeyEncryptionError('Incorrect passphrase')
block_size = max(block_size, 8)
else:
block_size = 8
packet = SSHPacket(key_data)
check1 = packet.get_uint32()
check2 = packet.get_uint32()
if check1 != check2:
if cipher_name != b'none':
raise KeyEncryptionError('Incorrect passphrase') from None
else:
raise KeyImportError('Invalid OpenSSH private key')
alg = packet.get_string()
handler = _public_key_alg_map.get(alg)
if not handler:
raise KeyImportError('Unknown OpenSSH private key algorithm')
key_params = handler.decode_ssh_private(packet)
comment = packet.get_string()
def _decode_openssh_private(data, passphrase):
"""Decode an OpenSSH format private key"""
try:
if not data.startswith(_OPENSSH_KEY_V1):
raise KeyImportError('Unrecognized OpenSSH private key type')
data = data[len(_OPENSSH_KEY_V1):]
packet = SSHPacket(data)
cipher_name = packet.get_string()
kdf = packet.get_string()
kdf_data = packet.get_string()
nkeys = packet.get_uint32()
_ = packet.get_string() # public_key
key_data = packet.get_string()
mac = packet.get_remaining_payload()
if nkeys != 1:
raise KeyImportError('Invalid OpenSSH private key')
if cipher_name != b'none':
if passphrase is None:
raise KeyImportError('Passphrase must be specified to import '
'encrypted private keys')
def _process_secret(self, _pkttype, _pktid, packet):
    """Handle a KEXRSA secret message

       Stores the encrypted secret from the packet, decrypts it with
       self._trans_key and decodes the result as an mpint into
       self._k. The hash from _compute_hash() is then signed with the
       server host key, the signature is sent in a KEXRSA done message
       and send_newkeys() is called on the connection.

       :raises: ProtocolError if the connection reports it is a client,
                KeyExchangeFailed if decryption returns nothing

    """

    if self._conn.is_client():
        raise ProtocolError('Unexpected KEXRSA secret msg')

    self._encrypted_k = packet.get_string()
    packet.check_end()

    plaintext = self._trans_key.decrypt(self._encrypted_k, self.algorithm)
    if not plaintext:
        raise KeyExchangeFailed('Key exchange decryption failed')

    # The decrypted data is itself packet-encoded: a single mpint
    secret_packet = SSHPacket(plaintext)
    self._k = secret_packet.get_mpint()
    secret_packet.check_end()

    host_key = self._conn.get_server_host_key()
    exchange_hash = self._compute_hash()
    signature = host_key.sign(exchange_hash)

    self.send_packet(MSG_KEXRSA_DONE, String(signature))
    self._conn.send_newkeys(self._k, exchange_hash)
try:
key_size, iv_size, block_size, _, _, _ = \
get_encryption_params(cipher_name)
except KeyError:
raise KeyEncryptionError('Unknown cipher: %s' %
cipher_name.decode('ascii')) from None
if kdf != b'bcrypt':
raise KeyEncryptionError('Unknown kdf: %s' %
kdf.decode('ascii'))
if not _bcrypt_available: # pragma: no cover
raise KeyEncryptionError('OpenSSH private key encryption '
'requires bcrypt with KDF support')
packet = SSHPacket(kdf_data)
salt = packet.get_string()
rounds = packet.get_uint32()
packet.check_end()
if isinstance(passphrase, str):
passphrase = passphrase.encode('utf-8')
try:
key = bcrypt.kdf(passphrase, salt, key_size + iv_size,
rounds, ignore_few_rounds=True)
except ValueError:
raise KeyEncryptionError('Invalid OpenSSH '
'private key') from None
cipher = get_encryption(cipher_name, key[:key_size], key[key_size:])