diff --git a/example/acct_async.py b/example/acct_async.py new file mode 100644 index 0000000..9df18cd --- /dev/null +++ b/example/acct_async.py @@ -0,0 +1,101 @@ +#!/usr/bin/python + +import asyncio + +import logging +import traceback +from pyrad.dictionary import Dictionary +from pyrad.client_async import ClientAsync +from pyrad.packet import AccountingResponse + +logging.basicConfig(level="DEBUG", + format="%(asctime)s [%(levelname)-8s] %(message)s") +client = ClientAsync(server="127.0.0.1", + secret=b"Kah3choteereethiejeimaeziecumi", + timeout=3, debug=True, + dict=Dictionary("dictionary")) + +loop = asyncio.get_event_loop() + + +def create_request(client, user): + req = client.CreateAcctPacket(User_Name=user) + + req["NAS-IP-Address"] = "192.168.1.10" + req["NAS-Port"] = 0 + req["Service-Type"] = "Login-User" + req["NAS-Identifier"] = "trillian" + req["Called-Station-Id"] = "00-04-5F-00-0F-D1" + req["Calling-Station-Id"] = "00-01-24-80-B3-9C" + req["Framed-IP-Address"] = "10.0.0.100" + + return req + + +def print_reply(reply): + print("Received Accounting-Response") + + print("Attributes returned by server:") + for i in reply.keys(): + print("%s: %s" % (i, reply[i])) + + +def test_acct1(enable_message_authenticator=False): + + global client + + try: + # Initialize transports + loop.run_until_complete( + asyncio.ensure_future( + client.initialize_transports(enable_auth=True, + # local_addr='127.0.0.1', + # local_auth_port=8000, + enable_acct=True, + enable_coa=True))) + + req = create_request(client, "wichert") + if enable_message_authenticator: + req.add_message_authenticator() + + future = client.SendPacket(req) + + # loop.run_until_complete(future) + loop.run_until_complete(asyncio.ensure_future( + asyncio.gather( + future, + return_exceptions=True + ) + + )) + + if future.exception(): + print('EXCEPTION ', future.exception()) + else: + reply = future.result() + + if reply.code == AccountingResponse: + print("Accounting accepted") + + print("Attributes returned by server:") + for i in reply.keys(): + print("%s: %s" % (i, reply[i])) + + # Close transports + loop.run_until_complete(asyncio.ensure_future( + client.deinitialize_transports())) + print('END') + + del client + except Exception as exc: + print('Error: ', exc) + print('\n'.join(traceback.format_exc().splitlines())) + # Close transports + loop.run_until_complete(asyncio.ensure_future( + client.deinitialize_transports())) + + loop.close() + + +#test_acct1() +test_acct1(enable_message_authenticator=True) diff --git a/example/auth_async.py b/example/auth_async.py index 5ae6c24..89a0198 100644 --- a/example/auth_async.py +++ b/example/auth_async.py @@ -10,7 +10,7 @@ logging.basicConfig(level="DEBUG", format="%(asctime)s [%(levelname)-8s] %(message)s") -client = ClientAsync(server="localhost", +client = ClientAsync(server="127.0.0.1", secret=b"Kah3choteereethiejeimaeziecumi", timeout=3, debug=True, dict=Dictionary("dictionary")) @@ -31,6 +31,7 @@ def create_request(client, user): return req + def print_reply(reply): if reply.code == AccessAccept: print("Access accepted") @@ -41,6 +42,7 @@ def print_reply(reply): for i in reply.keys(): print("%s: %s" % (i, reply[i])) + def test_auth1(): global client @@ -50,13 +52,11 @@ def test_auth1(): loop.run_until_complete( asyncio.ensure_future( client.initialize_transports(enable_auth=True, - #local_addr='127.0.0.1', - #local_auth_port=8000, + # local_addr='127.0.0.1', + # local_auth_port=8000, enable_acct=True, enable_coa=True))) - - req = client.CreateAuthPacket(User_Name="wichert") req["NAS-IP-Address"] = "192.168.1.10" @@ -107,6 +107,7 @@ def test_auth1(): loop.close() + def test_multi_auth(): global client @@ -117,12 +118,10 @@ def test_multi_auth(): asyncio.ensure_future( client.initialize_transports(enable_auth=True, local_addr='127.0.0.1', - #local_auth_port=8000, + # local_auth_port=8000, enable_acct=True, enable_coa=True))) - - reqs = [] for i in range(150): req = create_request(client, "user%s" % i) @@ -162,6 +161,7 @@ def test_multi_auth(): loop.close() + def test_multi_client(): clients = [] @@ -189,8 +189,8 @@ def test_multi_client(): enable_coa=False))) # Send - for i in range(n_req4client): - req = create_request(client, "user%s" % i) + for j in range(n_req4client): + req = create_request(client, "user%s" % j) print('CREATE REQUEST with id %d' % req.id) future = client.SendPacket(req) reqs.append(future) @@ -240,6 +240,64 @@ def test_multi_client(): loop.close() -#test_multi_auth() -#test_auth1() -test_multi_client() +def test_auth1_msg_authenticator(): + global client + + try: + # Initialize transports + loop.run_until_complete( + asyncio.ensure_future( + client.initialize_transports(enable_auth=True, + # local_addr='127.0.0.1', + # local_auth_port=8000, + enable_acct=True, + enable_coa=True))) + + req = create_request(client, "wichert") + req.add_message_authenticator() + + future = client.SendPacket(req) + + # loop.run_until_complete(future) + loop.run_until_complete(asyncio.ensure_future( + asyncio.gather( + future, + return_exceptions=True + ) + + )) + + if future.exception(): + print('EXCEPTION ', future.exception()) + else: + reply = future.result() + + if reply.code == AccessAccept: + print("Access accepted") + else: + print("Access denied") + + print("Attributes returned by server:") + for i in reply.keys(): + print("%s: %s" % (i, reply[i])) + + # Close transports + loop.run_until_complete(asyncio.ensure_future( + client.deinitialize_transports())) + print('END') + + del client + except Exception as exc: + print('Error: ', exc) + print('\n'.join(traceback.format_exc().splitlines())) + # Close transports + loop.run_until_complete(asyncio.ensure_future( + client.deinitialize_transports())) + + loop.close() + + +# test_multi_auth() +# test_auth1() +# test_multi_client() +test_auth1_msg_authenticator() diff --git a/example/pyrad.log b/example/pyrad.log deleted file mode 100644 index e69de29..0000000 diff --git a/example/server_async.py b/example/server_async.py index d1f63ac..0999219 100644 --- a/example/server_async.py +++ b/example/server_async.py @@ -21,10 +21,11 @@ class FakeServer(ServerAsync): - def __init__(self, loop, dictionary): + def __init__(self, loop, dictionary, enable_message_authenticator=False): ServerAsync.__init__(self, loop=loop, dictionary=dictionary, enable_pkt_verify=True, debug=True) + self.enable_message_authenticator = enable_message_authenticator def handle_auth_packet(self, protocol, pkt, addr): @@ -43,6 +44,10 @@ def handle_auth_packet(self, protocol, pkt, addr): }) reply.code = AccessAccept + + if self.enable_message_authenticator and pkt.message_authenticator: + reply.add_message_authenticator() + protocol.send_response(reply, addr) def handle_acct_packet(self, protocol, pkt, addr): @@ -53,6 +58,9 @@ def handle_acct_packet(self, protocol, pkt, addr): print("%s: %s" % (attr, pkt[attr])) reply = self.CreateReplyPacket(pkt) + + if self.enable_message_authenticator and pkt.message_authenticator: + reply.add_message_authenticator() protocol.send_response(reply, addr) def handle_coa_packet(self, protocol, pkt, addr): @@ -63,6 +71,8 @@ def handle_coa_packet(self, protocol, pkt, addr): print("%s: %s" % (attr, pkt[attr])) reply = self.CreateReplyPacket(pkt) + if self.enable_message_authenticator and pkt.message_authenticator: + reply.add_message_authenticator() protocol.send_response(reply, addr) def handle_disconnect_packet(self, protocol, pkt, addr): @@ -75,6 +85,9 @@ def handle_disconnect_packet(self, protocol, pkt, addr): reply = self.CreateReplyPacket(pkt) # COA NAK reply.code = 45 + + if self.enable_message_authenticator and pkt.message_authenticator: + reply.add_message_authenticator() protocol.send_response(reply, addr) @@ -82,7 +95,8 @@ def handle_disconnect_packet(self, protocol, pkt, addr): # create server and read dictionary loop = asyncio.get_event_loop() - server = FakeServer(loop=loop, dictionary=Dictionary('dictionary')) + server = FakeServer(loop=loop, dictionary=Dictionary('dictionary'), + enable_message_authenticator=True) # add clients (address, secret, name) server.hosts["127.0.0.1"] = RemoteHost("127.0.0.1", diff --git a/pyrad/client_async.py b/pyrad/client_async.py index c7b4c02..8512d45 100644 --- a/pyrad/client_async.py +++ b/pyrad/client_async.py @@ -132,16 +132,29 @@ def datagram_received(self, data, addr): reply = Packet(packet=data, dict=self.client.dict) - if reply and reply.id in self.pending_requests: + if reply is not None and reply.id in self.pending_requests: req = self.pending_requests[reply.id] packet = req['packet'] reply.secret = packet.secret if packet.VerifyReply(reply, data): - req['future'].set_result(reply) - # Remove request for map - del self.pending_requests[reply.id] + + if reply.message_authenticator and not \ + reply.verify_message_authenticator( + original_authenticator=packet.authenticator): + self.logger.warn( + '[%s:%d] Received invalid reply for id %d. %s' % ( + self.server, self.port, reply.id, + 'Invalid Message-Authenticator. Ignoring it.' + ) + ) + self.errors += 1 + else: + + req['future'].set_result(reply) + # Remove request for map + del self.pending_requests[reply.id] else: self.logger.warn( '[%s:%d] Received invalid reply for id %d. %s' % ( @@ -430,9 +443,14 @@ def SendPacket(self, pkt): if not self.protocol_acct: raise Exception('Transport not initialized') + self.protocol_acct.send_packet(pkt, ans) + elif isinstance(pkt, CoAPacket): if not self.protocol_coa: raise Exception('Transport not initialized') + + self.protocol_coa.send_packet(pkt, ans) + else: raise Exception('Unsupported packet') diff --git a/pyrad/packet.py b/pyrad/packet.py index 86a7e69..379fb10 100644 --- a/pyrad/packet.py +++ b/pyrad/packet.py @@ -4,9 +4,11 @@ # # A RADIUS packet as defined in RFC 2138 - +from collections import OrderedDict import struct import random +# Hmac needed for Message-Authenticator +import hmac try: import hashlib md5_constructor = hashlib.md5 @@ -44,7 +46,7 @@ class PacketError(Exception): pass -class Packet(dict): +class Packet(OrderedDict): """Packet acts like a standard python map to provide simple access to the RADIUS attributes. Since RADIUS allows for repeated attributes the value will always be a sequence. pyrad makes sure @@ -60,7 +62,8 @@ class Packet(dict): :obj:`AuthPacket` or :obj:`AcctPacket` classes. """ - def __init__(self, code=0, id=None, secret=six.b(''), authenticator=None, **attributes): + def __init__(self, code=0, id=None, secret=six.b(''), authenticator=None, + **attributes): """Constructor :param dict: RADIUS dictionary @@ -74,7 +77,7 @@ def __init__(self, code=0, id=None, secret=six.b(''), authenticator=None, **attr :param packet: raw packet to decode :type packet: string """ - dict.__init__(self) + OrderedDict.__init__(self) self.code = code if id is not None: self.id = id @@ -87,6 +90,7 @@ def __init__(self, code=0, id=None, secret=six.b(''), authenticator=None, **attr not isinstance(authenticator, six.binary_type): raise TypeError('authenticator must be a binary string') self.authenticator = authenticator + self.message_authenticator = None if 'dict' in attributes: self.dict = attributes['dict'] @@ -94,12 +98,113 @@ def __init__(self, code=0, id=None, secret=six.b(''), authenticator=None, **attr if 'packet' in attributes: self.DecodePacket(attributes['packet']) + if 'message_authenticator' in attributes: + self.message_authenticator = attributes['message_authenticator'] + for (key, value) in attributes.items(): - if key in ['dict', 'fd', 'packet']: + if key in ['dict', 'fd', 'packet', 'message_authenticator']: continue key = key.replace('_', '-') self.AddAttribute(key, value) + def add_message_authenticator(self): + + self.message_authenticator = True + # Maintain a zero octects content for md5 and hmac calculation. + self['Message-Authenticator'] = 16 * six.b('\00') + + if self.id is None: + self.id = self.CreateID() + + if self.authenticator is None and self.code == AccessRequest: + self.authenticator = self.CreateAuthenticator() + self._refresh_message_authenticator() + + def get_message_authenticator(self): + self._refresh_message_authenticator() + return self.message_authenticator + + def _refresh_message_authenticator(self): + hmac_constructor = hmac.new(self.secret) + + # Maintain a zero octects content for md5 and hmac calculation. + self['Message-Authenticator'] = 16 * six.b('\00') + attr = self._PktEncodeAttributes() + + header = struct.pack('!BBH', self.code, self.id, + (20 + len(attr))) + + hmac_constructor.update(header[0:4]) + if self.code in (AccountingRequest, DisconnectRequest, + CoARequest, AccountingResponse): + hmac_constructor.update(16 * six.b('\00')) + else: + # NOTE: self.authenticator on reply packet is initialized + # with request authenticator by design. + # For AccessAccept, AccessReject and AccessChallenge + # it is needed use original Authenticato. + # For AccessAccept, AccessReject and AccessChallenge + # it is needed use original Authenticator. + if self.authenticator is None: + raise Exception('No authenticator found') + hmac_constructor.update(self.authenticator) + + hmac_constructor.update(attr) + self['Message-Authenticator'] = hmac_constructor.digest() + + def verify_message_authenticator(self, secret=None, + original_authenticator=None, + original_code=None): + """Verify packet Message-Authenticator. + + :return: False if verification failed else True + :rtype: boolean + """ + if self.message_authenticator is None: + raise Exception('No Message-Authenticator AVP present') + + prev_ma = self['Message-Authenticator'] + # Set zero bytes for Message-Authenticator for md5 calculation + if secret is None and self.secret is None: + raise Exception('Missing secret for HMAC/MD5 verification') + + if secret: + key = secret + else: + key = self.secret + + self['Message-Authenticator'] = 16 * six.b('\00') + attr = self._PktEncodeAttributes() + + header = struct.pack('!BBH', self.code, self.id, + (20 + len(attr))) + + hmac_constructor = hmac.new(key) + hmac_constructor.update(header) + if self.code in (AccountingRequest, DisconnectRequest, + CoARequest, AccountingResponse): + if original_code is None or original_code != StatusServer: + # TODO: Handle Status-Server response correctly. + hmac_constructor.update(16 * six.b('\00')) + elif self.code in (AccessAccept, AccessChallenge, + AccessReject): + if original_authenticator is None: + if self.authenticator: + # NOTE: self.authenticator on reply packet is initialized + # with request authenticator by design. + original_authenticator = self.authenticator + else: + raise Exception('Missing original authenticator') + + hmac_constructor.update(original_authenticator) + else: + # On Access-Request and Status-Server use dynamic authenticator + hmac_constructor.update(self.authenticator) + + hmac_constructor.update(attr) + self['Message-Authenticator'] = prev_ma[0] + return prev_ma[0] == hmac_constructor.digest() + def CreateReply(self, **attributes): """Create a new packet as a reply to this one. This method makes sure the authenticator and secret are copied over @@ -187,9 +292,9 @@ def AddAttribute(self, key, value): def __getitem__(self, key): if not isinstance(key, six.string_types): - return dict.__getitem__(self, key) + return OrderedDict.__getitem__(self, key) - values = dict.__getitem__(self, self._EncodeKey(key)) + values = OrderedDict.__getitem__(self, self._EncodeKey(key)) attr = self.dict.attributes[key] if attr.type == 'tlv': # return map from sub attribute code to its values res = {} @@ -207,25 +312,24 @@ def __getitem__(self, key): def __contains__(self, key): try: - return dict.__contains__(self, self._EncodeKey(key)) + return OrderedDict.__contains__(self, self._EncodeKey(key)) except KeyError: return False has_key = __contains__ def __delitem__(self, key): - dict.__delitem__(self, self._EncodeKey(key)) + OrderedDict.__delitem__(self, self._EncodeKey(key)) def __setitem__(self, key, item): if isinstance(key, six.string_types): (key, item) = self._EncodeKeyValues(key, [item]) - dict.__setitem__(self, key, item) + OrderedDict.__setitem__(self, key, item) else: - assert isinstance(item, list) - dict.__setitem__(self, key, item) + OrderedDict.__setitem__(self, key, item) def keys(self): - return [self._DecodeKey(key) for key in dict.keys(self)] + return [self._DecodeKey(key) for key in OrderedDict.keys(self)] @staticmethod def CreateAuthenticator(): @@ -269,11 +373,15 @@ def ReplyPacket(self): assert(self.authenticator) assert(self.secret is not None) + if self.message_authenticator: + self._refresh_message_authenticator() + attr = self._PktEncodeAttributes() header = struct.pack('!BBH', self.code, self.id, (20 + len(attr))) authenticator = md5_constructor(header[0:4] + self.authenticator - + attr + self.secret).digest() + + attr + self.secret).digest() + return header + authenticator + attr def VerifyReply(self, reply, rawreply=None): @@ -283,8 +391,17 @@ def VerifyReply(self, reply, rawreply=None): if rawreply is None: rawreply = reply.ReplyPacket() + attr = reply._PktEncodeAttributes() + # The Authenticator field in an Accounting-Response packet is called + # the Response Authenticator, and contains a one-way MD5 hash + # calculated over a stream of octets consisting of the Accounting + # Response Code, Identifier, Length, the Request Authenticator field + # from the Accounting-Request packet being replied to, and the + # response attributes if any, followed by the shared secret. The + # resulting 16 octet MD5 hash value is stored in the Authenticator + # field of the Accounting-Response packet. hash = md5_constructor(rawreply[0:4] + self.authenticator + - rawreply[20:] + self.secret).digest() + attr + self.secret).digest() if hash != rawreply[4:20]: return False @@ -368,7 +485,6 @@ def _PktDecodeVendorAttribute(self, data): return tlvs def _PktDecodeTlvAttribute(self, code, data): - sub_attributes = self.setdefault(code, {}) loc = 0 @@ -412,6 +528,11 @@ def DecodePacket(self, packet): if key == 26: for (key, value) in self._PktDecodeVendorAttribute(value): self.setdefault(key, []).append(value) + elif key == 80: + # POST: Message Authenticator AVP is present. + self.message_authenticator = True + self.setdefault(key, []).append(value) + elif self.dict.attributes[self._DecodeKey(key)].type == 'tlv': self._PktDecodeTlvAttribute(key,value) else: @@ -489,8 +610,8 @@ def CreateReply(self, **attributes): to the new instance. """ return AuthPacket(AccessAccept, self.id, - self.secret, self.authenticator, dict=self.dict, - **attributes) + self.secret, self.authenticator, dict=self.dict, + **attributes) def RequestPacket(self): """Create a ready-to-transmit authentication request packet. @@ -500,14 +621,16 @@ def RequestPacket(self): :return: raw packet :rtype: string """ - attr = self._PktEncodeAttributes() - if self.authenticator is None: self.authenticator = self.CreateAuthenticator() if self.id is None: self.id = self.CreateID() + if self.message_authenticator: + self._refresh_message_authenticator() + + attr = self._PktEncodeAttributes() header = struct.pack('!BBH16s', self.code, self.id, (20 + len(attr)), self.authenticator) @@ -611,7 +734,10 @@ def VerifyChapPasswd(self, userpwd): if 'CHAP-Challenge' in self: challenge = self['CHAP-Challenge'][0] - return password == md5_constructor("%s%s%s" % (chapid, userpwd, challenge)).digest() + return password == md5_constructor( + "%s%s%s" % ( + chapid, userpwd, challenge) + ).digest() class AcctPacket(Packet): @@ -620,7 +746,7 @@ class AcctPacket(Packet): """ def __init__(self, code=AccountingRequest, id=None, secret=six.b(''), - authenticator=None, **attributes): + authenticator=None, **attributes): """Constructor :param dict: RADIUS dictionary @@ -644,8 +770,8 @@ def CreateReply(self, **attributes): to the new instance. """ return AcctPacket(AccountingResponse, self.id, - self.secret, self.authenticator, dict=self.dict, - **attributes) + self.secret, self.authenticator, dict=self.dict, + **attributes) def VerifyAcctRequest(self): """Verify request authenticator. @@ -669,15 +795,21 @@ def RequestPacket(self): :rtype: string """ - attr = self._PktEncodeAttributes() - if self.id is None: self.id = self.CreateID() + if self.message_authenticator: + self._refresh_message_authenticator() + + attr = self._PktEncodeAttributes() header = struct.pack('!BBH', self.code, self.id, (20 + len(attr))) self.authenticator = md5_constructor(header[0:4] + 16 * six.b('\x00') + attr + self.secret).digest() - return header + self.authenticator + attr + + ans = header + self.authenticator + attr + + return ans + class CoAPacket(Packet): """RADIUS CoA packets. This class is a specialization @@ -709,8 +841,8 @@ def CreateReply(self, **attributes): to the new instance. """ return CoAPacket(CoAACK, self.id, - self.secret, self.authenticator, dict=self.dict, - **attributes) + self.secret, self.authenticator, dict=self.dict, + **attributes) def VerifyCoARequest(self): """Verify request authenticator. @@ -740,8 +872,16 @@ def RequestPacket(self): header = struct.pack('!BBH', self.code, self.id, (20 + len(attr))) self.authenticator = md5_constructor(header[0:4] + 16 * six.b('\x00') + attr + self.secret).digest() + + if self.message_authenticator: + self._refresh_message_authenticator() + attr = self._PktEncodeAttributes() + self.authenticator = md5_constructor(header[0:4] + 16 * six.b('\x00') + + attr + self.secret).digest() + return header + self.authenticator + attr + def CreateID(): """Generate a packet ID. diff --git a/pyrad/server_async.py b/pyrad/server_async.py index 381f285..3619f7f 100644 --- a/pyrad/server_async.py +++ b/pyrad/server_async.py @@ -64,7 +64,8 @@ def __get_remote_host__(self, addr): return ans def datagram_received(self, data, addr): - self.logger.debug('[%s:%d] Received %d bytes from %s', self.ip, self.port, len(data), addr) + self.logger.debug('[%s:%d] Received %d bytes from %s', self.ip, + self.port, len(data), addr) receive_date = datetime.utcnow() @@ -114,6 +115,13 @@ def datagram_received(self, data, addr): dict=self.server.dict, packet=data) + if self.server.enable_pkt_verify and \ + req.message_authenticator and \ + not req.verify_message_authenticator(): + raise PacketError( + 'Received invalid Message-Authenticator' + ) + elif self.server_type == ServerType.Coa: if req.code != DisconnectRequest and \ @@ -127,6 +135,11 @@ def datagram_received(self, data, addr): if self.server.enable_pkt_verify: if not req.VerifyCoARequest(): raise PacketError('Packet verification failed') + if req.message_authenticator and \ + not req.verify_message_authenticator(): + raise PacketError( + 'Received invalid Message-Authenticator' + ) elif self.server_type == ServerType.Acct: @@ -142,6 +155,11 @@ def datagram_received(self, data, addr): if self.server.enable_pkt_verify: if not req.VerifyAcctRequest(): raise PacketError('Packet verification failed') + if req.message_authenticator and not \ + req.verify_message_authenticator(): + raise PacketError( + 'Received invalid Message-Authenticator' + ) # Call request callback self.request_callback(self, req, addr)