Skip to content
This repository has been archived by the owner on Dec 10, 2018. It is now read-only.

add: TProcessPoolServer #330

Open
wants to merge 1 commit into
base: develop
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
25 changes: 20 additions & 5 deletions thriftpy/rpc.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,11 +2,12 @@

from __future__ import absolute_import

import signal
import contextlib
import warnings

from thriftpy.protocol import TBinaryProtocolFactory
from thriftpy.server import TThreadedServer
from thriftpy.server import TThreadedServer, TProcessPoolServer
from thriftpy.thrift import TProcessor, TClient
from thriftpy.transport import (
TBufferedTransportFactory,
Expand Down Expand Up @@ -43,11 +44,18 @@ def make_client(service, host="localhost", port=9090, unix_socket=None,
return TClient(service, protocol)


def _init_handler():
signal.signal(signal.SIGINT, signal.SIG_DFL)
signal.signal(signal.SIGTERM, signal.SIG_DFL)
signal.signal(signal.SIGHUP, signal.SIG_DFL)


def make_server(service, handler,
host="localhost", port=9090, unix_socket=None,
proto_factory=TBinaryProtocolFactory(),
trans_factory=TBufferedTransportFactory(),
client_timeout=3000, certfile=None):
client_timeout=3000, certfile=None,
num_workers=None):
processor = TProcessor(service, handler)

if unix_socket:
Expand All @@ -65,9 +73,16 @@ def make_server(service, handler,
else:
raise ValueError("Either host/port or unix_socket must be provided.")

server = TThreadedServer(processor, server_socket,
iprot_factory=proto_factory,
itrans_factory=trans_factory)
if num_workers is None:
server = TThreadedServer(processor, server_socket,
iprot_factory=proto_factory,
itrans_factory=trans_factory)
else:
server = TProcessPoolServer(processor, server_socket,
iprot_factory=proto_factory,
itrans_factory=trans_factory)
server.setNumWorkers(num_workers)
server.setPostForkCallback(_init_handler)
return server


Expand Down
103 changes: 103 additions & 0 deletions thriftpy/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
TTransportException
)

from multiprocessing import Process, Value, Condition, reduction

logger = logging.getLogger(__name__)

Expand Down Expand Up @@ -103,3 +104,105 @@ def handle(self, client):

def close(self):
self.closed = True


class TProcessPoolServer(TServer):
"""Server with a fixed size pool of worker subprocesses to service requests

Note that if you need shared state between the handlers - it's up to you!
Written by Dvir Volk, doat.com
"""
def __init__(self, *args, **kwargs):
self.daemon = kwargs.pop("daemon", False)
TServer.__init__(self, *args, **kwargs)
self.closed = False

self.numWorkers = 1
self.workers = []
self.isRunning = Value('b', False)
self.stopCondition = Condition()
self.postForkCallback = None

def setPostForkCallback(self, callback):
if not callable(callback):
raise TypeError("This is not a callback!")
self.postForkCallback = callback

def setNumWorkers(self, num):
"""Set the number of worker threads that should be created"""
self.numWorkers = num

def workerProcess(self):
"""Loop getting clients from the shared queue and process them"""
if self.postForkCallback:
self.postForkCallback()

while self.isRunning.value:
try:
client = self.trans.accept()
if not client:
continue
self.serveClient(client)
except (KeyboardInterrupt, SystemExit):
return 0
except Exception as x:
logger.exception(x)

def serveClient(self, client):
"""Process input/output from a client for as long as possible"""

itrans = self.itrans_factory.get_transport(client)
otrans = self.otrans_factory.get_transport(client)
iprot = self.iprot_factory.get_protocol(itrans)
oprot = self.oprot_factory.get_protocol(otrans)

try:
while True:
self.processor.process(iprot, oprot)
except TTransportException as tx:
pass
except Exception as x:
logger.exception(x)

itrans.close()
otrans.close()

def serve(self):
"""Start workers and put into queue"""
# this is a shared state that can tell the workers to exit when False
self.isRunning.value = True

# first bind and listen to the port
self.trans.listen()

# fork the children
for i in range(self.numWorkers):
try:
w = Process(target=self.workerProcess)
w.daemon = True
w.start()
self.workers.append(w)
except Exception as x:
logger.exception(x)

# wait until the condition is set by stop()
while True:
self.stopCondition.acquire()
try:
self.stopCondition.wait()
break
except (SystemExit, KeyboardInterrupt):
break
except Exception as x:
logger.exception(x)

self.isRunning.value = False

def stop(self):
self.isRunning.value = False
self.stopCondition.acquire()
self.stopCondition.notify()
self.stopCondition.release()

def close(self):
self.closed = True