From bdbc50fff5f59fa132675b0f33686ad39ba2e1b1 Mon Sep 17 00:00:00 2001 From: Lex Li Date: Sun, 11 Feb 2024 01:23:38 -0500 Subject: [PATCH] Simplified dispatch. Added some type annotations. --- pysnmp/carrier/asyncio/dgram/udp6.py | 4 +- pysnmp/carrier/asyncio/dispatch.py | 31 ++----- pysnmp/carrier/base.py | 131 +++++++++++++++------------ 3 files changed, 83 insertions(+), 83 deletions(-) diff --git a/pysnmp/carrier/asyncio/dgram/udp6.py b/pysnmp/carrier/asyncio/dgram/udp6.py index 358aa7d4e..694ffe5c3 100644 --- a/pysnmp/carrier/asyncio/dgram/udp6.py +++ b/pysnmp/carrier/asyncio/dgram/udp6.py @@ -5,10 +5,12 @@ # License: https://www.pysnmp.com/pysnmp/license.html # import socket +from typing import Tuple from pysnmp.carrier.base import AbstractTransportAddress from pysnmp.carrier.asyncio.dgram.base import DgramAsyncioProtocol - +domainName: Tuple[int, ...] +snmpUDP6Domain: Tuple[int, ...] domainName = snmpUDP6Domain = (1, 3, 6, 1, 2, 1, 100, 1, 2) diff --git a/pysnmp/carrier/asyncio/dispatch.py b/pysnmp/carrier/asyncio/dispatch.py index d0c1b0314..961b0557c 100644 --- a/pysnmp/carrier/asyncio/dispatch.py +++ b/pysnmp/carrier/asyncio/dispatch.py @@ -34,10 +34,9 @@ # THE POSSIBILITY OF SUCH DAMAGE. # import sys -import platform -from time import time import traceback -from pysnmp.carrier.base import AbstractTransportDispatcher +from typing import Tuple +from pysnmp.carrier.base import AbstractTransport, AbstractTransportDispatcher from pysnmp.error import PySnmpError import asyncio @@ -46,20 +45,15 @@ class AsyncioDispatcher(AbstractTransportDispatcher): """AsyncioDispatcher based on asyncio event loop""" + loop: asyncio.AbstractEventLoop + __transportCount: int + def __init__(self, *args, **kwargs): AbstractTransportDispatcher.__init__(self) self.__transportCount = 0 - if 'timeout' in kwargs: - self.setTimerResolution(kwargs['timeout']) - self.loopingcall = None self.loop = kwargs.pop('loop', asyncio.get_event_loop()) - async def handle_timeout(self): - while True: - await asyncio.sleep(self.getTimerResolution()) - self.handleTimerTick(time()) - - def runDispatcher(self, timeout=0.0): + def runDispatcher(self, timeout: float = 0.0): if not self.loop.is_running(): try: if timeout > 0: @@ -75,23 +69,14 @@ def __closeDispatcher(self): self.loop.stop() super().closeDispatcher() - def registerTransport(self, tDomain, transport): - if self.loopingcall is None and self.getTimerResolution() > 0: - self.loopingcall = asyncio.ensure_future(self.handle_timeout()) + def registerTransport(self, tDomain: Tuple[int, ...], transport: AbstractTransport): AbstractTransportDispatcher.registerTransport( self, tDomain, transport ) self.__transportCount += 1 - def unregisterTransport(self, tDomain): + def unregisterTransport(self, tDomain: Tuple[int, ...]): t = AbstractTransportDispatcher.getTransport(self, tDomain) if t is not None: AbstractTransportDispatcher.unregisterTransport(self, tDomain) self.__transportCount -= 1 - - # The last transport has been removed, stop the timeout - if self.__transportCount == 0 and not self.loopingcall.done(): - self.loopingcall.cancel() - self.loopingcall = None - - diff --git a/pysnmp/carrier/base.py b/pysnmp/carrier/base.py index e70ad3b8d..4f0f604a4 100644 --- a/pysnmp/carrier/base.py +++ b/pysnmp/carrier/base.py @@ -2,11 +2,14 @@ # This file is part of pysnmp software. # # Copyright (c) 2005-2019, Ilya Etingof +# Copyright (C) 2024, LeXtudio Inc. # License: https://www.pysnmp.com/pysnmp/license.html # import sys +from typing import Tuple from pysnmp.carrier import error +from typing import Tuple class TimerCallable: @@ -48,7 +51,65 @@ def interval(self, callInterval): self.__callInterval = callInterval +class AbstractTransportAddress: + _localAddress = None + + def setLocalAddress(self, s): + self._localAddress = s + return self + + def getLocalAddress(self): + return self._localAddress + + def clone(self, localAddress=None): + return self.__class__(self).setLocalAddress(localAddress is None and self.getLocalAddress() or localAddress) + + +class AbstractTransport: + protoTransportDispatcher = None + addressType = AbstractTransportAddress + _cbFun = None + + @classmethod + def isCompatibleWithDispatcher(cls, transportDispatcher): + return isinstance(transportDispatcher, cls.protoTransportDispatcher) + + def registerCbFun(self, cbFun): + if self._cbFun: + raise error.CarrierError( + f'Callback function {self._cbFun} already registered at {self}' + ) + self._cbFun = cbFun + + def unregisterCbFun(self): + self._cbFun = None + + def closeTransport(self): + self.unregisterCbFun() + + # Public API + + def openClientMode(self, iface=None): + raise error.CarrierError('Method not implemented') + + def openServerMode(self, iface): + raise error.CarrierError('Method not implemented') + + def sendMessage(self, outgoingMessage, transportAddress: AbstractTransportAddress): + raise error.CarrierError('Method not implemented') + + class AbstractTransportDispatcher: + __transports: dict[Tuple[int, ...], AbstractTransport] + __transportDomainMap: dict[AbstractTransport, Tuple[int, ...]] + __recvCallables: dict[str, callable] + __timerCallables: list[TimerCallable] + __ticks: int + __timerResolution: float + __timerResolution: float + __timerDelta: float + __nextTime: float + def __init__(self): self.__transports = {} self.__transportDomainMap = {} @@ -61,7 +122,7 @@ def __init__(self): self.__nextTime = 0 self.__routingCbFun = None - def _cbFun(self, incomingTransport, transportAddress, incomingMessage): + def _cbFun(self, incomingTransport: AbstractTransport, transportAddress: AbstractTransportAddress, incomingMessage): if incomingTransport in self.__transportDomainMap: transportDomain = self.__transportDomainMap[incomingTransport] else: @@ -120,7 +181,7 @@ def unregisterTimerCbFun(self, timerCbFun=None): else: self.__timerCallables = [] - def registerTransport(self, tDomain, transport): + def registerTransport(self, tDomain: Tuple[int, ...], transport: AbstractTransport): if tDomain in self.__transports: raise error.CarrierError( f'Transport {tDomain} already registered' @@ -129,7 +190,7 @@ def registerTransport(self, tDomain, transport): self.__transports[tDomain] = transport self.__transportDomainMap[transport] = tDomain - def unregisterTransport(self, tDomain): + def unregisterTransport(self, tDomain: Tuple[int, ...]): if tDomain not in self.__transports: raise error.CarrierError( f'Transport {tDomain} not registered' @@ -138,15 +199,15 @@ def unregisterTransport(self, tDomain): del self.__transportDomainMap[self.__transports[tDomain]] del self.__transports[tDomain] - def getTransport(self, transportDomain): + def getTransport(self, transportDomain: Tuple[int, ...]): if transportDomain in self.__transports: return self.__transports[transportDomain] raise error.CarrierError( f'Transport {transportDomain} not registered' ) - def sendMessage(self, outgoingMessage, transportDomain, - transportAddress): + def sendMessage(self, outgoingMessage, transportDomain: Tuple[int, ...], + transportAddress: AbstractTransportAddress): if transportDomain in self.__transports: self.__transports[transportDomain].sendMessage( outgoingMessage, transportAddress @@ -159,7 +220,7 @@ def sendMessage(self, outgoingMessage, transportDomain, def getTimerResolution(self): return self.__timerResolution - def setTimerResolution(self, timerResolution): + def setTimerResolution(self, timerResolution: float): if timerResolution < 0.01 or timerResolution > 10: raise error.CarrierError('Impossible timer resolution') @@ -174,7 +235,7 @@ def setTimerResolution(self, timerResolution): def getTimerTicks(self): return self.__ticks - def handleTimerTick(self, timeNow): + def handleTimerTick(self, timeNow: float): if self.__nextTime == 0: # initial initialization self.__nextTime = timeNow + self.__timerResolution - self.__timerDelta @@ -187,13 +248,13 @@ def handleTimerTick(self, timeNow): for timerCallable in self.__timerCallables: timerCallable(timeNow) - def jobStarted(self, jobId, count=1): + def jobStarted(self, jobId, count: int = 1): if jobId in self.__jobs: self.__jobs[jobId] += count else: self.__jobs[jobId] = count - def jobFinished(self, jobId, count=1): + def jobFinished(self, jobId, count: int = 1): self.__jobs[jobId] -= count if self.__jobs[jobId] == 0: del self.__jobs[jobId] @@ -201,7 +262,7 @@ def jobFinished(self, jobId, count=1): def jobsArePending(self): return bool(self.__jobs) - def runDispatcher(self, timeout=0.0): + def runDispatcher(self, timeout: float = 0.0): raise error.CarrierError('Method not implemented') def closeDispatcher(self): @@ -211,51 +272,3 @@ def closeDispatcher(self): self.__transports.clear() self.unregisterRecvCbFun() self.unregisterTimerCbFun() - - -class AbstractTransportAddress: - _localAddress = None - - def setLocalAddress(self, s): - self._localAddress = s - return self - - def getLocalAddress(self): - return self._localAddress - - def clone(self, localAddress=None): - return self.__class__(self).setLocalAddress(localAddress is None and self.getLocalAddress() or localAddress) - - -class AbstractTransport: - protoTransportDispatcher = None - addressType = AbstractTransportAddress - _cbFun = None - - @classmethod - def isCompatibleWithDispatcher(cls, transportDispatcher): - return isinstance(transportDispatcher, cls.protoTransportDispatcher) - - def registerCbFun(self, cbFun): - if self._cbFun: - raise error.CarrierError( - f'Callback function {self._cbFun} already registered at {self}' - ) - self._cbFun = cbFun - - def unregisterCbFun(self): - self._cbFun = None - - def closeTransport(self): - self.unregisterCbFun() - - # Public API - - def openClientMode(self, iface=None): - raise error.CarrierError('Method not implemented') - - def openServerMode(self, iface): - raise error.CarrierError('Method not implemented') - - def sendMessage(self, outgoingMessage, transportAddress): - raise error.CarrierError('Method not implemented')