diff --git a/pyrad/client.py b/pyrad/client.py index 6964963..0b8b5c7 100644 --- a/pyrad/client.py +++ b/pyrad/client.py @@ -34,7 +34,7 @@ class Client(host.Host): :type timeout: float """ def __init__(self, server, authport=1812, acctport=1813, - coaport=3799, secret=six.b(''), dict=None, retries=3, timeout=5): + coaport=3799, secret=six.b(''), dict=None, retries=3, timeout=5, enforce_ma=False): """Constructor. @@ -50,6 +50,8 @@ def __init__(self, server, authport=1812, acctport=1813, :type secret: string :param dict: RADIUS dictionary :type dict: pyrad.dictionary.Dictionary + :param enforce_ma: Enforce usage and check of Message-Authenticator + :type enforce_ma: boolean """ host.Host.__init__(self, authport, acctport, coaport, dict) @@ -58,6 +60,7 @@ def __init__(self, server, authport=1812, acctport=1813, self._socket = None self.retries = retries self.timeout = timeout + self.enforce_ma = enforce_ma self._poll = select.poll() def bind(self, addr): @@ -100,6 +103,9 @@ def CreateAuthPacket(self, **args): :return: a new empty packet instance :rtype: pyrad.packet.AuthPacket """ + if self.enforce_ma: + return host.Host.CreateAuthPacket(self, secret=self.secret, + message_authenticator=True, **args) return host.Host.CreateAuthPacket(self, secret=self.secret, **args) def CreateAcctPacket(self, **args): @@ -163,7 +169,7 @@ def _SendPacket(self, pkt, port): try: reply = pkt.CreateReply(packet=rawreply) - if pkt.VerifyReply(reply, rawreply): + if pkt.VerifyReply(reply, rawreply, enforce_ma=self.enforce_ma): return reply except packet.PacketError: pass diff --git a/pyrad/client_async.py b/pyrad/client_async.py index e9d5df3..7a9d344 100644 --- a/pyrad/client_async.py +++ b/pyrad/client_async.py @@ -128,7 +128,7 @@ def datagram_received(self, data, addr): reply.dict = packet.dict reply.secret = packet.secret - if packet.VerifyReply(reply, data): + if packet.VerifyReply(reply, data, enforce_ma=self.client.enforce_ma): req['future'].set_result(reply) # Remove request for map del self.pending_requests[reply.id] @@ -177,7 +177,7 @@ class ClientAsync: def __init__(self, server, auth_port=1812, acct_port=1813, coa_port=3799, secret=six.b(''), dict=None, loop=None, retries=3, timeout=30, - logger_name='pyrad'): + logger_name='pyrad', enforce_ma=False): """Constructor. @@ -216,6 +216,7 @@ def __init__(self, server, auth_port=1812, acct_port=1813, self.protocol_coa = None self.coa_port = coa_port + self.enforce_ma = enforce_ma async def initialize_transports(self, enable_acct=False, enable_auth=False, enable_coa=False, @@ -325,6 +326,11 @@ def CreateAuthPacket(self, **args): """ if not self.protocol_auth: raise Exception('Transport not initialized') + if self.enforce_ma: + return AuthPacket(dict=self.dict, + id=self.protocol_auth.create_id(), + secret=self.secret, + message_authenticator=True, **args) return AuthPacket(dict=self.dict, id=self.protocol_auth.create_id(), diff --git a/pyrad/packet.py b/pyrad/packet.py index 821c02d..99f8819 100644 --- a/pyrad/packet.py +++ b/pyrad/packet.py @@ -422,7 +422,7 @@ def ReplyPacket(self): return header + authenticator + attr - def VerifyReply(self, reply, rawreply=None): + def VerifyReply(self, reply, rawreply=None, enforce_ma=False): if reply.id != self.id: return False @@ -443,6 +443,13 @@ def VerifyReply(self, reply, rawreply=None): if hash != rawreply[4:20]: return False + + if enforce_ma: + if self.message_authenticator is None: + return False + if not self.verify_message_authenticator(): + return False + return True def _PktEncodeAttribute(self, key, value):