From c8837a2f9aaae7b4c27cb9d33dddf4c4a2c0daff Mon Sep 17 00:00:00 2001 From: Geaaru Date: Tue, 3 Jul 2018 17:16:30 +0200 Subject: [PATCH 1/3] Add async client implementation --- example/auth_async.py | 164 +++++++++++++++ pyrad/client_async.py | 449 ++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 613 insertions(+) create mode 100644 example/auth_async.py create mode 100644 pyrad/client_async.py diff --git a/example/auth_async.py b/example/auth_async.py new file mode 100644 index 0000000..9ce4a41 --- /dev/null +++ b/example/auth_async.py @@ -0,0 +1,164 @@ +#!/usr/bin/python + +import asyncio + +import logging +import traceback +from pyrad.dictionary import Dictionary +from pyrad.client_async import ClientAsync +from pyrad.packet import AccessAccept + +logging.basicConfig(level="DEBUG", + format="%(asctime)s [%(levelname)-8s] %(message)s") +client = ClientAsync(server="localhost", + secret=b"Kah3choteereethiejeimaeziecumi", + timeout=4, + dict=Dictionary("dictionary")) + +loop = asyncio.get_event_loop() + + +def create_request(client, user): + req = client.CreateAuthPacket(User_Name=user) + + req["NAS-IP-Address"] = "192.168.1.10" + req["NAS-Port"] = 0 + req["Service-Type"] = "Login-User" + req["NAS-Identifier"] = "trillian" + req["Called-Station-Id"] = "00-04-5F-00-0F-D1" + req["Calling-Station-Id"] = "00-01-24-80-B3-9C" + req["Framed-IP-Address"] = "10.0.0.100" + + return req + +def print_reply(reply): + if reply.code == AccessAccept: + print("Access accepted") + else: + print("Access denied") + + print("Attributes returned by server:") + for i in reply.keys(): + print("%s: %s" % (i, reply[i])) + +def test_auth1(): + + global client + + try: + # Initialize transports + loop.run_until_complete( + asyncio.ensure_future( + client.initialize_transports(enable_auth=True, + local_addr='127.0.0.1', + local_auth_port=8000, + enable_acct=True, + enable_coa=True))) + + + + req = client.CreateAuthPacket(User_Name="wichert") + + req["NAS-IP-Address"] = "192.168.1.10" + req["NAS-Port"] = 0 + req["Service-Type"] = "Login-User" + req["NAS-Identifier"] = "trillian" + req["Called-Station-Id"] = "00-04-5F-00-0F-D1" + req["Calling-Station-Id"] = "00-01-24-80-B3-9C" + req["Framed-IP-Address"] = "10.0.0.100" + + future = client.SendPacket(req) + + # loop.run_until_complete(future) + loop.run_until_complete(asyncio.ensure_future( + asyncio.gather( + future, + return_exceptions=True + ) + + )) + + if future.exception(): + print('EXCEPTION ', future.exception()) + else: + reply = future.result() + + if reply.code == AccessAccept: + print("Access accepted") + else: + print("Access denied") + + print("Attributes returned by server:") + for i in reply.keys(): + print("%s: %s" % (i, reply[i])) + + # Close transports + loop.run_until_complete(asyncio.ensure_future( + client.deinitialize_transports())) + print('END') + + del client + except Exception as exc: + print('Error: ', exc) + print('\n'.join(traceback.format_exc().splitlines())) + # Close transports + loop.run_until_complete(asyncio.ensure_future( + client.deinitialize_transports())) + + loop.close() + +def test_multi_auth(): + + global client + + try: + # Initialize transports + loop.run_until_complete( + asyncio.ensure_future( + client.initialize_transports(enable_auth=True, + local_addr='127.0.0.1', + local_auth_port=8000, + enable_acct=True, + enable_coa=True))) + + + + reqs = [] + for i in range(255): + req = create_request(client, "user%s" % i) + future = client.SendPacket(req) + reqs.append(future) + + # loop.run_until_complete(future) + loop.run_until_complete(asyncio.ensure_future( + asyncio.gather( + *reqs, + return_exceptions=True + ) + + )) + + for future in reqs: + if future.exception(): + print('EXCEPTION ', future.exception()) + else: + reply = future.result() + print_reply(reply) + + # Close transports + loop.run_until_complete(asyncio.ensure_future( + client.deinitialize_transports())) + print('END') + + del client + except Exception as exc: + print('Error: ', exc) + print('\n'.join(traceback.format_exc().splitlines())) + # Close transports + loop.run_until_complete(asyncio.ensure_future( + client.deinitialize_transports())) + + loop.close() + +#test_multi_auth() +test_auth1() diff --git a/pyrad/client_async.py b/pyrad/client_async.py new file mode 100644 index 0000000..906bba2 --- /dev/null +++ b/pyrad/client_async.py @@ -0,0 +1,449 @@ +# client_async.py +# +# Copyright 2018-2020 Geaaru gmail.com> + +__docformat__ = "epytext en" + +from datetime import datetime +import asyncio +import six +import logging +import random + +from pyrad.packet import Packet, AuthPacket, AcctPacket, CoAPacket + + +class DatagramProtocolClient(asyncio.Protocol): + + def __init__(self, server, port, logger, + client, retries=3, timeout=30): + self.transport = None + self.port = port + self.server = server + self.logger = logger + self.retries = retries + self.timeout = timeout + self.client = client + + # Map of pending requests + self.pending_requests = {} + + # Use cryptographic-safe random generator as provided by the OS. + random_generator = random.SystemRandom() + self.packet_id = random_generator.randrange(0, 256) + + self.timeout_future = None + + async def __timeout_handler__(self): + + try: + + while True: + + req2delete = [] + now = datetime.now() + next_weak_up = self.timeout + # noinspection PyShadowingBuiltins + for id, req in self.pending_requests.items(): + + secs = (req['send_date'] - now).seconds + if secs > self.timeout: + if req['retries'] == self.retries: + self.logger.debug( + '[%s:%d] For request %d execute all retries.' % ( + self.server, self.port, id + ) + ) + req['future'].set_exception( + TimeoutError('Timeout on Reply') + ) + req2delete.append(id) + else: + # Send again packet + req['send_date'] = now + req['retries'] += 1 + self.logger.debug( + '[%s:%d] For request %d execute retry %d.' % ( + self.server, self.port, id, + req['retries'] + ) + ) + self.transport.sendto(req['packet'].RequestPacket()) + elif next_weak_up > secs: + next_weak_up = secs + + # noinspection PyShadowingBuiltins + for id in req2delete: + # Remove request for map + del self.pending_requests[id] + + await asyncio.sleep(next_weak_up) + + except asyncio.CancelledError: + pass + + def send_packet(self, packet, future): + if packet.id in self.pending_requests: + raise Exception('Packet with id %d already present' % ( + packet.id + )) + + # Store packet on pending requests map + self.pending_requests[packet.id] = { + 'packet': packet, + 'creation_date': datetime.now(), + 'retries': 0, + 'future': future, + 'send_date': datetime.now() + } + + # In queue packet raw on socket buffer + self.transport.sendto(packet.RequestPacket()) + + def connection_made(self, transport): + self.transport = transport + socket = transport.get_extra_info('socket') + self.logger.info( + '[%s:%d] Transport created with binding in %s:%d.' % ( + self.server, self.port, + socket.getsockname()[0], + socket.getsockname()[1] + ) + ) + + pre_loop = asyncio.get_event_loop() + asyncio.set_event_loop(loop=self.client.loop) + # Start asynchronous timer handler + self.timeout_future = asyncio.ensure_future( + self.__timeout_handler__() + ) + asyncio.set_event_loop(loop=pre_loop) + + def error_received(self, exc): + self.logger.error( + '[%s:%d] Error received: %s.' % ( + self.server, self.port, exc + ) + ) + + def connection_lost(self, exc): + if exc: + self.logger.warn( + '[%s:%d] Connection lost: %s.' % ( + self.server, self.port, str(exc) + ) + ) + else: + self.logger.info( + '[%s:%d] Transport closed.' % (self.server, self.port) + ) + + # noinspection PyUnusedLocal + def datagram_received(self, data, addr): + try: + + reply = Packet(packet=data) + + if reply and reply.id in self.pending_requests: + req = self.pending_requests[reply.id] + packet = req['packet'] + + reply.dict = packet.dict + reply.secret = packet.secret + + if packet.VerifyReply(reply, data): + req['future'].set_result(reply) + # Remove request for map + del self.pending_requests[reply.id] + else: + self.logger.warn( + '[%s:%d] Received invalid reply for id %d. %s' % ( + self.server, self.port, reply.id, + 'Ignoring it.' + ) + ) + else: + self.logger.warn( + '[%s:%d] Received invalid reply: %d.\nIgnoring it.' % ( + self.server, self.port, data, + ) + ) + + except Exception as exc: + self.logger.error( + '[%s:%d] Error on decode packet: %s.' % ( + self.server, self.port, exc + ) + ) + pass + + async def close_transport(self): + if self.transport: + self.logger.debug( + '[%s:%d] Closing transport...' % (self.server, self.port) + ) + self.transport.close() + self.transport = None + if self.timeout_future: + self.timeout_future.cancel() + await self.timeout_future + self.timeout_future = None + + def create_id(self): + self.packet_id = (self.packet_id + 1) % 256 + return self.packet_id + + def __str__(self): + return 'DatagramProtocolClient: { server: %s, port: %d }' % ( + self.server, + self.port + ) + + # Used as protocol_factory + def __call__(self): + return self + + +class ClientAsync: + """Basic RADIUS client. + This class implements a basic RADIUS client. It can send requests + to a RADIUS server, taking care of timeouts and retries, and + validate its replies. + + :ivar retries: number of times to retry sending a RADIUS request + :type retries: integer + :ivar timeout: number of seconds to wait for an answer + :type timeout: integer + """ + # noinspection PyShadowingBuiltins + def __init__(self, server, auth_port=1812, acct_port=1813, + coa_port=3799, secret=six.b(''), dict=None, + loop=None, retries=3, timeout=30, + logger_name='pyrad'): + + """Constructor. + + :param server: hostname or IP address of RADIUS server + :type server: string + :param auth_port: port to use for authentication packets + :type auth_port: integer + :param acct_port: port to use for accounting packets + :type acct_port: integer + :param coa_port: port to use for CoA packets + :type coa_port: integer + :param secret: RADIUS secret + :type secret: string + :param dict: RADIUS dictionary + :type dict: pyrad.dictionary.Dictionary + :param loop: Python loop handler + :type loop: asyncio event loop + """ + if not loop: + self.loop = asyncio.get_event_loop() + else: + self.loop = loop + self.logger = logging.getLogger(logger_name) + + self.server = server + self.secret = secret + self.retries = retries + self.timeout = timeout + self.dict = dict + + self.auth_port = auth_port + self.protocol_auth = None + + self.acct_port = acct_port + self.protocol_acct = None + + self.protocol_coa = None + self.coa_port = coa_port + + async def initialize_transports(self, enable_acct=False, + enable_auth=False, enable_coa=False, + local_addr=None, local_auth_port=None, + local_acct_port=None, local_coa_port=None): + + task_list = [] + + if not enable_acct and not enable_auth and not enable_coa: + raise Exception('No transports selected') + + if enable_acct and not self.protocol_acct: + self.protocol_acct = DatagramProtocolClient( + self.server, + self.acct_port, + self.logger, self, + retries=self.retries, + timeout=self.timeout + ) + bind_addr = None + if local_addr and local_acct_port: + bind_addr = (local_addr, local_acct_port) + + acct_connect = self.loop.create_datagram_endpoint( + self.protocol_acct, + reuse_address=True, reuse_port=True, + remote_addr=(self.server, self.acct_port), + local_addr=bind_addr + ) + task_list.append(acct_connect) + + if enable_auth and not self.protocol_auth: + self.protocol_auth = DatagramProtocolClient( + self.server, + self.auth_port, + self.logger, self, + retries=self.retries, + timeout=self.timeout + ) + bind_addr = None + if local_addr and local_auth_port: + bind_addr = (local_addr, local_auth_port) + + auth_connect = self.loop.create_datagram_endpoint( + self.protocol_auth, + reuse_address=True, reuse_port=True, + remote_addr=(self.server, self.auth_port), + local_addr=bind_addr + ) + task_list.append(auth_connect) + + if enable_coa and not self.protocol_coa: + self.protocol_coa = DatagramProtocolClient( + self.server, + self.coa_port, + self.logger, self, + retries=self.retries, + timeout=self.timeout + ) + bind_addr = None + if local_addr and local_coa_port: + bind_addr = (local_addr, local_coa_port) + + coa_connect = self.loop.create_datagram_endpoint( + self.protocol_coa, + reuse_address=True, reuse_port=True, + remote_addr=(self.server, self.coa_port), + local_addr=bind_addr + ) + task_list.append(coa_connect) + + await asyncio.ensure_future( + asyncio.gather( + *task_list, + return_exceptions=False, + ), + loop=self.loop + ) + + # noinspection SpellCheckingInspection + async def deinitialize_transports(self, deinit_coa=True, + deinit_auth=True, + deinit_acct=True): + if self.protocol_coa and deinit_coa: + await self.protocol_coa.close_transport() + del self.protocol_coa + self.protocol_coa = None + if self.protocol_auth and deinit_auth: + await self.protocol_auth.close_transport() + del self.protocol_auth + self.protocol_auth = None + if self.protocol_acct and deinit_acct: + await self.protocol_acct.close_transport() + del self.protocol_acct + self.protocol_acct = None + + # noinspection PyPep8Naming + def CreateAuthPacket(self, **args): + """Create a new RADIUS packet. + This utility function creates a new RADIUS packet which can + be used to communicate with the RADIUS server this client + talks to. This is initializing the new packet with the + dictionary and secret used for the client. + + :return: a new empty packet instance + :rtype: pyrad.packet.Packet + """ + if not self.protocol_auth: + raise Exception('Transport not initialized') + + return AuthPacket(dict=self.dict, + id=self.protocol_auth.create_id(), + secret=self.secret, **args) + + # noinspection PyPep8Naming + def CreateAcctPacket(self, **args): + """Create a new RADIUS packet. + This utility function creates a new RADIUS packet which can + be used to communicate with the RADIUS server this client + talks to. This is initializing the new packet with the + dictionary and secret used for the client. + + :return: a new empty packet instance + :rtype: pyrad.packet.Packet + """ + if not self.protocol_acct: + raise Exception('Transport not initialized') + + return AcctPacket(id=self.protocol_acct.create_id(), + dict=self.dict, + secret=self.secret, **args) + + # noinspection PyPep8Naming + def CreateCoAPacket(self, **args): + """Create a new RADIUS packet. + This utility function creates a new RADIUS packet which can + be used to communicate with the RADIUS server this client + talks to. This is initializing the new packet with the + dictionary and secret used for the client. + + :return: a new empty packet instance + :rtype: pyrad.packet.Packet + """ + + if not self.protocol_acct: + raise Exception('Transport not initialized') + + return CoAPacket(id=self.protocol_coa.create_id(), + dict=self.dict, + secret=self.secret, **args) + + # noinspection PyPep8Naming + # noinspection PyShadowingBuiltins + def CreatePacket(self, id, **args): + if not id: + raise Exception('Missing mandatory packet id') + + return Packet(id=id, dict=self.dict, + secret=self.secret, **args) + + # noinspection PyPep8Naming + def SendPacket(self, pkt): + """Send a packet to a RADIUS server. + + :param pkt: the packet to send + :type pkt: pyrad.packet.Packet + :return: Future related with packet to send + :rtype: asyncio.Future + """ + + ans = asyncio.Future(loop=self.loop) + + if isinstance(pkt, AuthPacket): + if not self.protocol_auth: + raise Exception('Transport not initialized') + + self.protocol_auth.send_packet(pkt, ans) + + elif isinstance(pkt, AcctPacket): + if not self.protocol_acct: + raise Exception('Transport not initialized') + + elif isinstance(pkt, CoAPacket): + if not self.protocol_coa: + raise Exception('Transport not initialized') + else: + raise Exception('Unsupported packet') + + return ans From 51e3499a1a209e5651fa8b3a59a5e345aeb5fc2f Mon Sep 17 00:00:00 2001 From: Geaaru Date: Thu, 5 Jul 2018 18:17:32 +0200 Subject: [PATCH 2/3] Add VerifyAuthRequest on AuthPacket --- pyrad/packet.py | 16 +++++++++++++++- 1 file changed, 15 insertions(+), 1 deletion(-) diff --git a/pyrad/packet.py b/pyrad/packet.py index dfb9b58..da6bcbc 100644 --- a/pyrad/packet.py +++ b/pyrad/packet.py @@ -325,6 +325,7 @@ def DecodePacket(self, packet): try: (self.code, self.id, length, self.authenticator) = \ struct.unpack('!BBH16s', packet[0:20]) + except struct.error: raise PacketError('Packet header is corrupt') if len(packet) != length: @@ -415,6 +416,8 @@ def __init__(self, code=AccessRequest, id=None, secret=six.b(''), :type packet: string """ Packet.__init__(self, code, id, secret, authenticator, **attributes) + if 'packet' in attributes: + self.raw_packet = attributes['packet'] def CreateReply(self, **attributes): """Create a new packet as a reply to this one. This method @@ -546,6 +549,17 @@ def VerifyChapPasswd(self, userpwd): return password == md5_constructor("%s%s%s" % (chapid, userpwd, challenge)).digest() + def VerifyAuthRequest(self): + """Verify request authenticator. + + :return: True if verification failed else False + :rtype: boolean + """ + assert(self.raw_packet) + hash = md5_constructor(self.raw_packet[0:4] + 16 * six.b('\x00') + + self.raw_packet[20:] + self.secret).digest() + return hash == self.authenticator + class AcctPacket(Packet): """RADIUS accounting packets. This class is a specialization @@ -651,7 +665,7 @@ def VerifyCoARequest(self): """ assert(self.raw_packet) hash = md5_constructor(self.raw_packet[0:4] + 16 * six.b('\x00') + - self.raw_packet[20:] + self.secret).digest() + self.raw_packet[20:] + self.secret).digest() return hash == self.authenticator def RequestPacket(self): From 54f915c66e321829170517d2777dc3b9a3550414 Mon Sep 17 00:00:00 2001 From: Geaaru Date: Thu, 5 Jul 2018 18:18:17 +0200 Subject: [PATCH 3/3] Add Async Server implementation --- example/server_async.py | 117 +++++++++++ pyrad/server_async.py | 431 ++++++++++++++++++++++++++++++++++++++++ 2 files changed, 548 insertions(+) create mode 100644 example/server_async.py create mode 100644 pyrad/server_async.py diff --git a/example/server_async.py b/example/server_async.py new file mode 100644 index 0000000..3b893da --- /dev/null +++ b/example/server_async.py @@ -0,0 +1,117 @@ +#!/usr/bin/python + +import asyncio + +import logging +import traceback +from pyrad.dictionary import Dictionary +from pyrad.server_async import ServerAsync +from pyrad.packet import AccessAccept +from pyrad.server import RemoteHost + +try: + import uvloop + asyncio.set_event_loop_policy(uvloop.EventLoopPolicy()) +except: + pass + +logging.basicConfig(level="DEBUG", + format="%(asctime)s [%(levelname)-8s] %(message)s") + +class FakeServer(ServerAsync): + + def __init__(self, loop, dictionary): + + ServerAsync.__init__(self, loop=loop, dictionary=dictionary, + enable_pkt_verify=True, debug=True) + + + def handle_auth_packet(self, protocol, pkt, addr): + + print("Received an authentication request with id ", pkt.id) + print('Authenticator ', pkt.authenticator.hex()) + print('Secret ', pkt.secret) + print("Attributes: ") + for attr in pkt.keys(): + print("%s: %s" % (attr, pkt[attr])) + + reply = self.CreateReplyPacket(pkt, **{ + "Service-Type": "Framed-User", + "Framed-IP-Address": '192.168.0.1', + "Framed-IPv6-Prefix": "fc66::1/64" + }) + + reply.code = AccessAccept + protocol.send_response(reply, addr) + + def handle_acct_packet(self, protocol, pkt, addr): + + print("Received an accounting request") + print("Attributes: ") + for attr in pkt.keys(): + print("%s: %s" % (attr, pkt[attr])) + + reply = self.CreateReplyPacket(pkt) + protocol.send_response(reply, addr) + + def handle_coa_packet(self, protocol, pkt, addr): + + print("Received an coa request") + print("Attributes: ") + for attr in pkt.keys(): + print("%s: %s" % (attr, pkt[attr])) + + reply = self.CreateReplyPacket(pkt) + protocol.send_response(reply, addr) + + def handle_disconnect_packet(self, protocol, pkt, addr): + + print("Received an disconnect request") + print("Attributes: ") + for attr in pkt.keys(): + print("%s: %s" % (attr, pkt[attr])) + + reply = self.CreateReplyPacket(pkt) + # COA NAK + reply.code = 45 + protocol.send_response(reply, addr) + + +if __name__ == '__main__': + + # create server and read dictionary + loop = asyncio.get_event_loop() + server = FakeServer(loop=loop, dictionary=Dictionary('dictionary')) + + # add clients (address, secret, name) + server.hosts["127.0.0.1"] = RemoteHost("127.0.0.1", + b"Kah3choteereethiejeimaeziecumi", + "localhost") + + try: + + # Initialize transports + loop.run_until_complete( + asyncio.ensure_future( + server.initialize_transports(enable_auth=True, + enable_acct=True, + enable_coa=True))) + + try: + # start server + loop.run_forever() + except KeyboardInterrupt as k: + pass + + # Close transports + loop.run_until_complete(asyncio.ensure_future( + server.deinitialize_transports())) + + except Exception as exc: + print('Error: ', exc) + print('\n'.join(traceback.format_exc().splitlines())) + # Close transports + loop.run_until_complete(asyncio.ensure_future( + server.deinitialize_transports())) + + loop.close() diff --git a/pyrad/server_async.py b/pyrad/server_async.py new file mode 100644 index 0000000..e85cd2c --- /dev/null +++ b/pyrad/server_async.py @@ -0,0 +1,431 @@ +# server_async.py +# +# Copyright 2018-2019 Geaaru + +import asyncio +import logging +import traceback + +from abc import abstractmethod, ABCMeta +from enum import Enum +from datetime import datetime +from pyrad.packet import Packet, AccessAccept, AccessReject, \ + AccountingRequest, AccountingResponse, \ + DisconnectACK, DisconnectNAK, DisconnectRequest, CoARequest, \ + CoAACK, CoANAK, AccessRequest, AuthPacket, AcctPacket, CoAPacket, \ + PacketError + +from pyrad.server import ServerPacketError + + +class ServerType(Enum): + Auth = 'Authentication' + Acct = 'Accounting' + Coa = 'Coa' + + +class DatagramProtocolServer(asyncio.Protocol): + + def __init__(self, ip, port, logger, server, server_type, hosts, + request_callback): + self.transport = None + self.ip = ip + self.port = port + self.logger = logger + self.server = server + self.hosts = hosts + self.server_type = server_type + self.request_callback = request_callback + + def __get_remote_host__(self, addr): + ans = None + if addr in self.hosts.keys(): + ans = self.hosts[addr] + return ans + + def connection_made(self, transport): + self.transport = transport + self.logger.info( + '[%s:%d] Transport created.' % ( + self.ip, self.port, + ) + ) + + def connection_lost(self, exc): + if exc: + self.logger.warn( + '[%s:%d] Connection lost: %s.' % ( + self.ip, self.port, str(exc) + ) + ) + else: + self.logger.info( + '[%s:%d] Transport closed.' % (self.ip, self.port) + ) + + def send_response(self, reply, addr): + self.transport.sendto(reply.ReplyPacket(), addr) + + def datagram_received(self, data, addr): + + self.logger.debug( + '[%s:%d] Received %d bytes from %s.' % ( + self.ip, self.port, len(data), addr + ) + ) + + receive_date = datetime.utcnow() + req = None + + try: + remote_host = self.__get_remote_host__(addr[0]) + + if remote_host: + + try: + if self.server.debug: + self.logger.info( + '[%s:%d] Received from %s packet: %s.' % ( + self.ip, self.port, addr, data.hex() + ) + ) + req = Packet(packet=data, dict=self.server.dict) + + except Exception as exc: + self.logger.error( + '[%s:%d] Error on decode packet: %s. Ignore it.' % ( + self.ip, self.port, exc + ) + ) + + if not req: + return + + if req.code in ( + AccountingResponse, + AccessAccept, + AccessReject, + CoANAK, + CoAACK, + DisconnectNAK, + DisconnectACK): + raise ServerPacketError('Invalid response packet %d' % + req.code) + + elif self.server_type == ServerType.Auth: + + if req.code != AccessRequest: + raise ServerPacketError( + 'Received not-authentication packet ' + 'on authentication port') + req = AuthPacket(secret=remote_host.secret, + dict=self.server.dict, + packet=data) + if self.server.enable_pkt_verify: + if req.VerifyAuthRequest(): + raise PacketError('Packet verification failed') + + elif self.server_type == ServerType.Coa: + + if req.code != DisconnectRequest and \ + req.code != CoARequest: + raise ServerPacketError( + 'Received not-coa packet on coa port' + ) + req = CoAPacket(secret=remote_host.secret, + dict=self.server.dict, + packet=data) + if self.server.enable_pkt_verify: + if req.VerifyCoARequest(): + raise PacketError('Packet verification failed') + + elif self.server_type == ServerType.Acct: + + if req.code != AccountingRequest: + raise ServerPacketError( + 'Received not-accounting packet on ' + 'accounting port' + ) + req = AcctPacket(secret=remote_host.secret, + dict=self.server.dict, + packet=data) + + if self.server.enable_pkt_verify: + if req.VerifyAcctRequest(): + raise PacketError('Packet verification failed') + + # Call request callback + self.request_callback(self, req, addr) + + else: + # POST: Unknown source + self.logger.warn( + '[%s:%d] Received package from unknown source %s. Drop data.' % ( + self.ip, self.port, addr + ) + ) + except Exception as e: + self.logger.error( + '[%s:%d] Unexpected error for packet from %s: %s' % ( + self.ip, self.port, addr, + (e, '\n'.join(traceback.format_exc().splitlines()))[ + self.server.debug + ] + ) + ) + + process_date = datetime.utcnow() + + self.logger.debug( + '[%s:%d] Request from %s processed in %d ms.' % ( + self.ip, self.port, addr, + (process_date-receive_date).microseconds/1000 + ) + ) + + def error_received(self, exc): + self.logger.error( + '[%s:%d] Error received: %s.' % ( + self.ip, self.port, exc + ) + ) + + async def close_transport(self): + if self.transport: + self.logger.debug( + '[%s:%d] Closing transport...' % (self.ip, self.port) + ) + self.transport.close() + self.transport = None + + def __str__(self): + return 'DatagramProtocolServer: { ip: %s, port: %d }' % ( + self.ip, self.port + ) + + # Used as protocol_factory + def __call__(self): + return self + + +class ServerAsync(metaclass=ABCMeta): + + def __init__(self, auth_port=1812, acct_port=1813, + coa_port=3799, hosts=None, dictionary=None, + loop=None, logger_name='pyrad', + enable_pkt_verify=False, + debug=False): + + if not loop: + self.loop = asyncio.get_event_loop() + else: + self.loop = loop + self.logger = logging.getLogger(logger_name) + + if hosts is None: + self.hosts = {} + else: + self.hosts = hosts + + self.auth_port = auth_port + self.auth_protocols = [] + + self.acct_port = acct_port + self.acct_protocols = [] + + self.coa_port = coa_port + self.coa_protocols = [] + + self.dict = dictionary + self.enable_pkt_verify = enable_pkt_verify + + self.debug = debug + + def __request_handler__(self, protocol, req, addr): + + try: + if protocol.server_type == ServerType.Acct: + self.handle_acct_packet(protocol, req, addr) + elif protocol.server_type == ServerType.Auth: + self.handle_auth_packet(protocol, req, addr) + elif protocol.server_type == ServerType.Coa and \ + req.code == CoARequest: + self.handle_coa_packet(protocol, req, addr) + elif protocol.server_type == ServerType.Coa and \ + req.code == DisconnectRequest: + self.handle_disconnect_packet(protocol, req, addr) + else: + self.logger.error( + '[%s:%s] Unexpected request found.' % ( + protocol.ip, protocol.port + ) + ) + except Exception as exc: + self.logger.error( + '[%s:%s] Unexpected error catched: %s' % ( + protocol.ip, protocol.port, + (exc, '\n'.join(traceback.format_exc().splitlines()))[ + self.debug + ] + ) + ) + + def __is_present_proto__(self, ip, port): + ans = False + + if port == self.auth_port: + for proto in self.auth_protocols: + if proto.ip == ip: + ans = True + break + + if port == self.acct_port: + for proto in self.acct_protocols: + if proto.ip == ip: + ans = True + break + + if port == self.coa_port: + for proto in self.coa_protocols: + if proto.ip == ip: + ans = True + break + + return ans + + # noinspection PyPep8Naming + @staticmethod + def CreateReplyPacket(pkt, **attributes): + """Create a reply packet. + Create a new packet which can be returned as a reply to a received + packet. + + :param pkt: original packet + :type pkt: Packet instance + """ + reply = pkt.CreateReply(**attributes) + return reply + + async def initialize_transports(self, enable_acct=False, + enable_auth=False, enable_coa=False, + addresses=None): + + task_list = [] + + if not enable_acct and not enable_auth and not enable_coa: + raise Exception('No transports selected') + if not addresses or len(addresses) == 0: + addresses = ['127.0.0.1'] + + # noinspection SpellCheckingInspection + for addr in addresses: + + if enable_acct and not self.__is_present_proto__(addr, + self.acct_port): + protocol_acct = DatagramProtocolServer( + addr, + self.acct_port, + self.logger, self, + ServerType.Acct, + self.hosts, + self.__request_handler__ + ) + + bind_addr = (addr, self.acct_port) + acct_connect = self.loop.create_datagram_endpoint( + protocol_acct, + reuse_address=True, reuse_port=True, + local_addr=bind_addr + ) + self.acct_protocols.append(protocol_acct) + task_list.append(acct_connect) + + if enable_auth and not self.__is_present_proto__(addr, + self.auth_port): + protocol_auth = DatagramProtocolServer( + addr, + self.auth_port, + self.logger, self, + ServerType.Auth, + self.hosts, + self.__request_handler__ + ) + bind_addr = (addr, self.auth_port) + + auth_connect = self.loop.create_datagram_endpoint( + protocol_auth, + reuse_address=True, reuse_port=True, + local_addr=bind_addr + ) + self.auth_protocols.append(protocol_auth) + task_list.append(auth_connect) + + if enable_coa and not self.__is_present_proto__(addr, + self.coa_port): + protocol_coa = DatagramProtocolServer( + addr, + self.coa_port, + self.logger, self, + ServerType.Coa, + self.hosts, + self.__request_handler__ + ) + bind_addr = (addr, self.coa_port) + + coa_connect = self.loop.create_datagram_endpoint( + protocol_coa, + reuse_address=True, reuse_port=True, + local_addr=bind_addr + ) + self.coa_protocols.append(protocol_coa) + task_list.append(coa_connect) + + await asyncio.ensure_future( + asyncio.gather( + *task_list, + return_exceptions=False, + ), + loop=self.loop + ) + + # noinspection SpellCheckingInspection + async def deinitialize_transports(self, deinit_coa=True, + deinit_auth=True, + deinit_acct=True): + + if deinit_coa: + for proto in self.coa_protocols: + await proto.close_transport() + del proto + + self.coa_protocols = [] + + if deinit_auth: + for proto in self.auth_protocols: + await proto.close_transport() + del proto + + self.auth_protocols = [] + + if deinit_acct: + for proto in self.acct_protocols: + await proto.close_transport() + del proto + + self.acct_protocols = [] + + @abstractmethod + def handle_auth_packet(self, protocol, pkt, addr): + pass + + @abstractmethod + def handle_acct_packet(self, protocol, pkt, addr): + pass + + @abstractmethod + def handle_coa_packet(self, protocol, pkt, addr): + pass + + @abstractmethod + def handle_disconnect_packet(self, protocol, pkt, addr): + pass