From bc14a4cac51202f4b91e518c1d7d86daa5cb22d6 Mon Sep 17 00:00:00 2001 From: gic Date: Fri, 6 Jul 2018 11:30:10 +0200 Subject: [PATCH] issue #85 async client/server enhancements --- CHANGES.rst | 2 +- pyrad/client_async.py | 68 ++++++----------- pyrad/server_async.py | 172 +++++++++++++++--------------------------- 3 files changed, 85 insertions(+), 157 deletions(-) diff --git a/CHANGES.rst b/CHANGES.rst index 9692943..5ec3079 100644 --- a/CHANGES.rst +++ b/CHANGES.rst @@ -1,7 +1,7 @@ Changelog ========= -* Add async client and server implementation for python >=3.5. +* Add experimental async client and server implementation for python >=3.5. * Add IPv6 bind support for client and server. diff --git a/pyrad/client_async.py b/pyrad/client_async.py index ed3b4d8..c7b4c02 100644 --- a/pyrad/client_async.py +++ b/pyrad/client_async.py @@ -52,11 +52,7 @@ async def __timeout_handler__(self): secs = (req['send_date'] - now).seconds if secs > self.timeout: if req['retries'] == self.retries: - self.logger.debug( - '[%s:%d] For request %d execute all retries.' % ( - self.server, self.port, id - ) - ) + self.logger.debug('[%s:%d] For request %d execute all retries', self.server, self.port, id) req['future'].set_exception( TimeoutError('Timeout on Reply') ) @@ -72,6 +68,7 @@ async def __timeout_handler__(self): req['retries'] ) ) + self.transport.sendto(req['packet'].RequestPacket()) elif next_weak_up > secs: next_weak_up = secs @@ -88,9 +85,7 @@ async def __timeout_handler__(self): def send_packet(self, packet, future): if packet.id in self.pending_requests: - raise Exception('Packet with id %d already present' % ( - packet.id - )) + raise Exception('Packet with id %d already present' % packet.id) # Store packet on pending requests map self.pending_requests[packet.id] = { @@ -108,11 +103,10 @@ def connection_made(self, transport): self.transport = transport socket = transport.get_extra_info('socket') self.logger.info( - '[%s:%d] Transport created with binding in %s:%d.' % ( - self.server, self.port, - socket.getsockname()[0], - socket.getsockname()[1] - ) + '[%s:%d] Transport created with binding in %s:%d', + self.server, self.port, + socket.getsockname()[0], + socket.getsockname()[1] ) pre_loop = asyncio.get_event_loop() @@ -124,23 +118,13 @@ def connection_made(self, transport): asyncio.set_event_loop(loop=pre_loop) def error_received(self, exc): - self.logger.error( - '[%s:%d] Error received: %s.' % ( - self.server, self.port, exc - ) - ) + self.logger.error('[%s:%d] Error received: %s', self.server, self.port, exc) def connection_lost(self, exc): if exc: - self.logger.warn( - '[%s:%d] Connection lost: %s.' % ( - self.server, self.port, str(exc) - ) - ) + self.logger.warn('[%s:%d] Connection lost: %s', self.server, self.port, str(exc)) else: - self.logger.info( - '[%s:%d] Transport closed.' % (self.server, self.port) - ) + self.logger.info('[%s:%d] Transport closed', self.server, self.port) # noinspection PyUnusedLocal def datagram_received(self, data, addr): @@ -185,13 +169,10 @@ def datagram_received(self, data, addr): ] ) ) - pass async def close_transport(self): if self.transport: - self.logger.debug( - '[%s:%d] Closing transport...' % (self.server, self.port) - ) + self.logger.debug('[%s:%d] Closing transport...', self.server, self.port) self.transport.close() self.transport = None if self.timeout_future: @@ -204,10 +185,7 @@ def create_id(self): return self.packet_id def __str__(self): - return 'DatagramProtocolClient: { server: %s, port: %d }' % ( - self.server, - self.port - ) + return 'DatagramProtocolClient(server?=%s, port=%d)' % (self.server, self.port) # Used as protocol_factory def __call__(self): @@ -233,20 +211,20 @@ def __init__(self, server, auth_port=1812, acct_port=1813, """Constructor. - :param server: hostname or IP address of RADIUS server - :type server: string + :param server: hostname or IP address of RADIUS server + :type server: string :param auth_port: port to use for authentication packets :type auth_port: integer :param acct_port: port to use for accounting packets :type acct_port: integer - :param coa_port: port to use for CoA packets - :type coa_port: integer - :param secret: RADIUS secret - :type secret: string - :param dict: RADIUS dictionary - :type dict: pyrad.dictionary.Dictionary - :param loop: Python loop handler - :type loop: asyncio event loop + :param coa_port: port to use for CoA packets + :type coa_port: integer + :param secret: RADIUS secret + :type secret: string + :param dict: RADIUS dictionary + :type dict: pyrad.dictionary.Dictionary + :param loop: Python loop handler + :type loop: asyncio event loop """ if not loop: self.loop = asyncio.get_event_loop() @@ -435,7 +413,7 @@ def SendPacket(self, pkt): """Send a packet to a RADIUS server. :param pkt: the packet to send - :type pkt: pyrad.packet.Packet + :type pkt: pyrad.packet.Packet :return: Future related with packet to send :rtype: asyncio.Future """ diff --git a/pyrad/server_async.py b/pyrad/server_async.py index 9170bb6..381f285 100644 --- a/pyrad/server_async.py +++ b/pyrad/server_async.py @@ -38,31 +38,15 @@ def __init__(self, ip, port, logger, server, server_type, hosts, self.request_callback = request_callback self.requests = 0 - def __get_remote_host__(self, addr): - ans = None - if addr in self.hosts.keys(): - ans = self.hosts[addr] - return ans - def connection_made(self, transport): self.transport = transport - self.logger.info( - '[%s:%d] Transport created.' % ( - self.ip, self.port, - ) - ) + self.logger.info('[%s:%d] Transport created', self.ip, self.port) def connection_lost(self, exc): if exc: - self.logger.warn( - '[%s:%d] Connection lost: %s.' % ( - self.ip, self.port, str(exc) - ) - ) + self.logger.warn('[%s:%d] Connection lost: %s', self.ip, self.port, str(exc)) else: - self.logger.info( - '[%s:%d] Transport closed.' % (self.ip, self.port) - ) + self.logger.info('[%s:%d] Transport closed', self.ip, self.port) def send_response(self, reply, addr): if self.server.debug: @@ -73,41 +57,42 @@ def send_response(self, reply, addr): ) self.transport.sendto(reply.ReplyPacket(), addr) - def datagram_received(self, data, addr): + def __get_remote_host__(self, addr): + ans = None + if addr in self.hosts.keys(): + ans = self.hosts[addr] + return ans - self.logger.debug( - '[%s:%d] Received %d bytes from %s.' % ( - self.ip, self.port, len(data), addr - ) - ) + def datagram_received(self, data, addr): + self.logger.debug('[%s:%d] Received %d bytes from %s', self.ip, self.port, len(data), addr) receive_date = datetime.utcnow() - req = None - try: - remote_host = self.__get_remote_host__(addr[0]) + remote_host = self.__get_remote_host__(addr[0]) - if remote_host: + if remote_host: - try: - if self.server.debug: - self.logger.info( - '[%s:%d] Received from %s packet: %s.' % ( - self.ip, self.port, addr, data.hex() - ) + try: + if self.server.debug: + self.logger.info( + '[%s:%d] Received from %s packet: %s.' % ( + self.ip, self.port, addr, data.hex() ) - req = Packet(packet=data, dict=self.server.dict) + ) + req = Packet(packet=data, dict=self.server.dict) - except Exception as exc: - self.logger.error( - '[%s:%d] Error on decode packet: %s. Ignore it.' % ( - self.ip, self.port, exc - ) + except Exception as exc: + self.logger.error( + '[%s:%d] Error on decode packet: %s. Ignore it.' % ( + self.ip, self.port, exc ) + ) + req = None - if not req: - return + if not req: + return + try: if req.code in ( AccountingResponse, AccessAccept, @@ -163,51 +148,36 @@ def datagram_received(self, data, addr): self.requests += 1 - else: - # POST: Unknown source - self.logger.warn( - '[%s:%d] Received package from unknown source %s. Drop data.' % ( - self.ip, self.port, addr + except Exception as e: + self.logger.error( + '[%s:%d] Unexpected error for packet from %s: %s' % ( + self.ip, self.port, addr, + (e, '\n'.join(traceback.format_exc().splitlines()))[ + self.server.debug + ] ) ) - except Exception as e: - self.logger.error( - '[%s:%d] Unexpected error for packet from %s: %s' % ( - self.ip, self.port, addr, - (e, '\n'.join(traceback.format_exc().splitlines()))[ - self.server.debug - ] - ) - ) - process_date = datetime.utcnow() + else: + self.logger.error('[%s:%d] Drop package from unknown source %s', + self.ip, self.port, addr) - self.logger.debug( - '[%s:%d] Request from %s processed in %d ms.' % ( - self.ip, self.port, addr, - (process_date-receive_date).microseconds/1000 - ) - ) + process_date = datetime.utcnow() + self.logger.debug('[%s:%d] Request from %s processed in %d ms', + self.ip, self.port, addr, + (process_date-receive_date).microseconds/1000) def error_received(self, exc): - self.logger.error( - '[%s:%d] Error received: %s.' % ( - self.ip, self.port, exc - ) - ) + self.logger.error('[%s:%d] Error received: %s', self.ip, self.port, exc) async def close_transport(self): if self.transport: - self.logger.debug( - '[%s:%d] Closing transport...' % (self.ip, self.port) - ) + self.logger.debug('[%s:%d] Close transport...', self.ip, self.port) self.transport.close() self.transport = None def __str__(self): - return 'DatagramProtocolServer: { ip: %s, port: %d }' % ( - self.ip, self.port - ) + return 'DatagramProtocolServer(ip=%s, port=%d)' % (self.ip, self.port) # Used as protocol_factory def __call__(self): @@ -261,43 +231,28 @@ def __request_handler__(self, protocol, req, addr): req.code == DisconnectRequest: self.handle_disconnect_packet(protocol, req, addr) else: - self.logger.error( - '[%s:%s] Unexpected request found.' % ( - protocol.ip, protocol.port - ) - ) + self.logger.error('[%s:%s] Unexpected request found', protocol.ip, protocol.port) except Exception as exc: - self.logger.error( - '[%s:%s] Unexpected error catched: %s' % ( - protocol.ip, protocol.port, - (exc, '\n'.join(traceback.format_exc().splitlines()))[ - self.debug - ] - ) - ) + if self.debug: + self.logger.exception('[%s:%s] Unexpected error', protocol.ip, protocol.port) - def __is_present_proto__(self, ip, port): - ans = False + else: + self.logger.error('[%s:%s] Unexpected error: %s', protocol.ip, protocol.port, exc) + def __is_present_proto__(self, ip, port): if port == self.auth_port: for proto in self.auth_protocols: if proto.ip == ip: - ans = True - break - - if port == self.acct_port: + return True + elif port == self.acct_port: for proto in self.acct_protocols: if proto.ip == ip: - ans = True - break - - if port == self.coa_port: + return True + elif port == self.coa_port: for proto in self.coa_protocols: if proto.ip == ip: - ans = True - break - - return ans + return True + return False # noinspection PyPep8Naming @staticmethod @@ -326,8 +281,7 @@ async def initialize_transports(self, enable_acct=False, # noinspection SpellCheckingInspection for addr in addresses: - if enable_acct and not self.__is_present_proto__(addr, - self.acct_port): + if enable_acct and not self.__is_present_proto__(addr, self.acct_port): protocol_acct = DatagramProtocolServer( addr, self.acct_port, @@ -346,8 +300,7 @@ async def initialize_transports(self, enable_acct=False, self.acct_protocols.append(protocol_acct) task_list.append(acct_connect) - if enable_auth and not self.__is_present_proto__(addr, - self.auth_port): + if enable_auth and not self.__is_present_proto__(addr, self.auth_port): protocol_auth = DatagramProtocolServer( addr, self.auth_port, @@ -366,8 +319,7 @@ async def initialize_transports(self, enable_acct=False, self.auth_protocols.append(protocol_auth) task_list.append(auth_connect) - if enable_coa and not self.__is_present_proto__(addr, - self.coa_port): + if enable_coa and not self.__is_present_proto__(addr, self.coa_port): protocol_coa = DatagramProtocolServer( addr, self.coa_port, @@ -407,9 +359,7 @@ def stats(self): return ans # noinspection SpellCheckingInspection - async def deinitialize_transports(self, deinit_coa=True, - deinit_auth=True, - deinit_acct=True): + async def deinitialize_transports(self, deinit_coa=True, deinit_auth=True, deinit_acct=True): if deinit_coa: for proto in self.coa_protocols: