diff --git a/srudp/__init__.py b/srudp/__init__.py index c327511..1799fac 100644 --- a/srudp/__init__.py +++ b/srudp/__init__.py @@ -152,7 +152,7 @@ class SecureReliableSocket(socket): "timeout", "span", "address", "shared_key", "mtu_size", "sender_seq", "sender_buffer", "sender_signal", "sender_buffer_lock", "receiver_seq", "receiver_buffer", "receiver_signal", "receiver_buffer_lock", - "broadcast_hook_fnc", "loss", "try_connect", "established"] + "backend_lock", "broadcast_hook_fnc", "loss", "try_connect", "established"] def __init__(self, family: int = s.AF_INET, timeout: float = 21.0, span: float = 3.0) -> None: """ @@ -177,6 +177,7 @@ def __init__(self, family: int = s.AF_INET, timeout: float = 21.0, span: float = self.receiver_buffer = BytesIO() self.receiver_signal = threading.Event() self.receiver_buffer_lock = threading.Lock() + self.backend_lock = threading.Lock() # to wait packet processing finish # broadcast hook self.broadcast_hook_fnc: Optional[_BroadcastHook] = None # status @@ -383,37 +384,41 @@ def _backend(self) -> None: last_ack_time = time() while not self.is_closed: - r, _w, _x = select([self], [], [], self.span) + with self.backend_lock: + r, _w, _x = select([self], [], [], self.span) - # re-transmit - if 0 < len(self.sender_buffer): - with self.sender_buffer_lock: - now = time() - self.span * 2 - transmit_limit = MAX_RETRANSMIT_LIMIT # max transmit at once - for i, p in enumerate(self.sender_buffer): - if transmit_limit == 0: - break - if p.time < now: - self.loss += 1 - re_packet = Packet(p.control, p.sequence, p.retry+1, time(), p.data) - self.sender_buffer[i] = re_packet - self.sendto(self._encrypt(packet2bin(re_packet)), self.address) - transmit_limit -= 1 - - # send ack as ping (stream may be free) - if self.span < time() - last_ack_time: - p = Packet(CONTROL_ACK, self.receiver_seq - 1, 0, time(), b'as ping') - self.sendto(self._encrypt(packet2bin(p)), self.address) - last_ack_time = time() - - # connection may be broken - if self.timeout < time() - last_receive_time: - p = Packet(CONTROL_FIN, CYC_INT0, 0, time(), b'stream may be broken') - self.sendto(self._encrypt(packet2bin(p)), self.address) - break + # re-transmit + if 0 < len(self.sender_buffer): + with self.sender_buffer_lock: + now = time() - self.span * 2 + transmit_limit = MAX_RETRANSMIT_LIMIT # max transmit at once + for i, p in enumerate(self.sender_buffer): + if transmit_limit == 0: + break + if p.time < now: + self.loss += 1 + re_packet = Packet(p.control, p.sequence, p.retry+1, time(), p.data) + self.sender_buffer[i] = re_packet + self.sendto(self._encrypt(packet2bin(re_packet)), self.address) + transmit_limit -= 1 - # received packet - if r: + # send ack as ping (stream may be free) + if self.span < time() - last_ack_time: + p = Packet(CONTROL_ACK, self.receiver_seq - 1, 0, time(), b'as ping') + self.sendto(self._encrypt(packet2bin(p)), self.address) + last_ack_time = time() + + # connection may be broken + if self.timeout < time() - last_receive_time: + p = Packet(CONTROL_FIN, CYC_INT0, 0, time(), b'stream may be broken') + self.sendto(self._encrypt(packet2bin(p)), self.address) + break + + # just socket select timeout (no data received yet) + if len(r) == 0: + continue + + # received a packet data try: data, _addr = self.recvfrom(65536) packet = bin2packet(self._decrypt(data)) @@ -640,12 +645,15 @@ def recv(self, buflen: int = 1024, flags: int = 0) -> bytes: # check data exist if timeout is None: # blocking forever - if not self.receiver_signal.wait(): + if not self.receiver_signal.wait(1.0): continue elif timeout == 0.0: # non-blocking - if not self.receiver_signal.is_set(): - raise BlockingIOError("not data found in socket") + with self.backend_lock: + # note: wait backend to avoid check signal before backend's process finish + if not self.receiver_signal.is_set(): + raise BlockingIOError("you should ignore this error " + "because it caused by inner packet") else: # blocking for some Secs if not self.receiver_signal.wait(timeout):