diff --git a/Dockerfile b/Dockerfile index 8552270..5dc5d21 100644 --- a/Dockerfile +++ b/Dockerfile @@ -2,7 +2,15 @@ FROM ubuntu:focal LABEL org.opencontainers.image.authors="benjamin.c.burns@gmail.com" -RUN apt-get update && apt-get install -y python2 python2-dev iptables dnsmasq uml-utilities net-tools build-essential curl && apt-get clean +RUN apt-get update && apt-get install -y \ + python2-minimal \ + iptables \ + dnsmasq \ + uml-utilities \ + net-tools \ + build-essential \ + curl \ + && apt-get clean RUN curl https://bootstrap.pypa.io/pip/2.7/get-pip.py --output get-pip.py && python2 get-pip.py && rm get-pip.py @@ -16,5 +24,3 @@ RUN pip2 install -r /opt/websockproxy/requirements.txt EXPOSE 80 CMD /opt/websockproxy/docker-startup.sh - - diff --git a/docker-image-config/docker-startup.sh b/docker-image-config/docker-startup.sh index b7a765a..24f873a 100755 --- a/docker-image-config/docker-startup.sh +++ b/docker-image-config/docker-startup.sh @@ -11,7 +11,7 @@ ifconfig tap0 up ###################### ## IP Forwarding config for TAP device ## -echo 1 > /proc/sys/net/ipv4/ip_forward +echo 1 >/proc/sys/net/ipv4/ip_forward /sbin/iptables -t nat -A POSTROUTING -o eth0 -j MASQUERADE /sbin/iptables -A FORWARD -i eth0 -o tap0 -m state --state RELATED,ESTABLISHED -j ACCEPT diff --git a/limiter.py b/limiter.py index 1abf8bb..6822c48 100644 --- a/limiter.py +++ b/limiter.py @@ -1,5 +1,6 @@ import time + class RateLimitingState(object): def __init__(self, rate, clientip, name): self.name = name @@ -16,10 +17,10 @@ def do_throttle(self, message): self.allowance += time_passed * self.rate if self.allowance > self.rate: - self.allowance = self.rate #throttle + self.allowance = self.rate # throttle if self.allowance > 1.0: self.allowance -= len(message) - return True; + return True return False diff --git a/requirements.txt b/requirements.txt index 4baecff..2165975 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,4 +1,3 @@ -argparse==1.2.1 -python-pytun==2.2.1 -tornado==3.1.1 -wsgiref==0.1.2 +argparse==1.4.0 +python-pytun==2.4.1 +tornado==5.1.1 diff --git a/switchedrelay.py b/switchedrelay.py index 6e68410..73eb06a 100644 --- a/switchedrelay.py +++ b/switchedrelay.py @@ -5,12 +5,10 @@ import traceback import functools -from select import poll -from select import POLLIN, POLLOUT, POLLHUP, POLLERR, POLLNVAL +from select import poll, POLLIN from pytun import TunTapDevice, IFF_TAP, IFF_NO_PI - from limiter import RateLimitingState import tornado.ioloop @@ -22,16 +20,23 @@ from tornado import websocket -FORMAT = '%(asctime)-15s %(message)s' -RATE = 40980.0 #unit: bytes -BROADCAST = '%s%s%s%s%s%s' % (chr(0xff),chr(0xff),chr(0xff),chr(0xff),chr(0xff),chr(0xff)) +FORMAT = "%(asctime)-15s %(message)s" +RATE = 40980.0 # unit: bytes +BROADCAST = "%s%s%s%s%s%s" % ( + chr(0xFF), + chr(0xFF), + chr(0xFF), + chr(0xFF), + chr(0xFF), + chr(0xFF), +) PING_INTERVAL = 30 -logger = logging.getLogger('relay') - +logger = logging.getLogger("relay") macmap = {} + @return_future def delay_future(t, callback): timestamp = time.time() @@ -40,13 +45,14 @@ def delay_future(t, callback): else: callback(t) + class TunThread(threading.Thread): def __init__(self, *args, **kwargs): super(TunThread, self).__init__(*args, **kwargs) self.running = True - self.tun = TunTapDevice(name="tap0", flags= (IFF_TAP | IFF_NO_PI)) - self.tun.addr = '10.5.0.1' - self.tun.netmask = '255.255.0.0' + self.tun = TunTapDevice(name="tap0", flags=(IFF_TAP | IFF_NO_PI)) + self.tun.addr = "10.5.0.1" + self.tun.netmask = "255.255.0.0" self.tun.mtu = 1500 self.tun.up() @@ -57,16 +63,17 @@ def run(self): p = poll() p.register(self.tun, POLLIN) try: - while(self.running): - #TODO: log IP headers in the future + while self.running: + # TODO: log IP headers in the future pollret = p.poll(1000) - for (f,e) in pollret: + for f, e in pollret: if f == self.tun.fileno() and (e & POLLIN): - buf = self.tun.read(self.tun.mtu+18) #MTU doesn't include header or CRC32 + buf = self.tun.read(self.tun.mtu + 18) # MTU doesn't include header or CRC32 if len(buf): mac = buf[0:6] if mac == BROADCAST or (ord(mac[0]) & 0x1) == 1: for socket in macmap.values(): + def send_message(socket): try: socket.rate_limited_downstream(str(buf)) @@ -76,6 +83,7 @@ def send_message(socket): loop.add_callback(functools.partial(send_message, socket)) elif macmap.get(mac, False): + def send_message(): try: macmap[mac].rate_limited_downstream(str(buf)) @@ -84,30 +92,29 @@ def send_message(): loop.add_callback(send_message) except: - logger.error('closing due to tun error') + logger.error("closing due to tun error") finally: self.tun.close() class MainHandler(websocket.WebSocketHandler): def __init__(self, *args, **kwargs): - super(MainHandler, self).__init__(*args,**kwargs) - self.remote_ip = self.request.headers.get('X-Forwarded-For', self.request.remote_ip) - logger.info('%s: connected.' % self.remote_ip) + super(MainHandler, self).__init__(*args, **kwargs) + self.remote_ip = self.request.headers.get("X-Forwarded-For", self.request.remote_ip) + logger.info("%s: connected." % self.remote_ip) self.thread = None - self.mac = '' - self.allowance = RATE #unit: messages - self.last_check = time.time() #floating-point, e.g. usec accuracy. Unit: seconds - self.upstream = RateLimitingState(RATE, name='upstream', clientip=self.remote_ip) - self.downstream = RateLimitingState(RATE, name='downstream', clientip=self.remote_ip) - - ping_future = delay_future(time.time()+PING_INTERVAL, self.do_ping) + self.mac = "" + self.allowance = RATE # unit: messages + self.last_check = (time.time()) # floating-point, e.g. usec accuracy. Unit: seconds + self.upstream = RateLimitingState(RATE, name="upstream", clientip=self.remote_ip) + self.downstream = RateLimitingState(RATE, name="downstream", clientip=self.remote_ip) + ping_future = delay_future(time.time() + PING_INTERVAL, self.do_ping) loop.add_future(ping_future, lambda: None) def do_ping(self, timestamp): self.ping(str(timestamp)) - ping_future = delay_future(time.time()+PING_INTERVAL, self.do_ping) + ping_future = delay_future(time.time() + PING_INTERVAL, self.do_ping) loop.add_future(ping_future, lambda: None) def on_pong(self, data): @@ -121,16 +128,16 @@ def open(self): self.set_nodelay(True) def on_message(self, message): - #TODO: log IP headers in the future + # TODO: log IP headers in the future - #Logs which user is tied to which MAC so that we detect which user is acting maliciously + # Logs which user is tied to which MAC so that we detect which user is acting maliciously if self.mac != message[6:12]: if macmap.get(self.mac, False): del macmap[self.mac] self.mac = message[6:12] - formatted_mac = ':'.join('{0:02x}'.format(ord(a)) for a in message[6:12]) - logger.info('%s: using mac %s' % (self.remote_ip, formatted_mac)) + formatted_mac = ":".join("{0:02x}".format(ord(a)) for a in message[6:12]) + logger.info("%s: using mac %s" % (self.remote_ip, formatted_mac)) macmap[self.mac] = self @@ -140,7 +147,7 @@ def on_message(self, message): if self.upstream.do_throttle(message): for socket in macmap.values(): try: - socket.write_message(str(message),binary=True) + socket.write_message(str(message), binary=True) except: pass @@ -148,7 +155,7 @@ def on_message(self, message): elif macmap.get(dest, False): if self.upstream.do_throttle(message): try: - macmap[dest].write_message(str(message),binary=True) + macmap[dest].write_message(str(message), binary=True) except: pass else: @@ -157,14 +164,14 @@ def on_message(self, message): except: tb = traceback.format_exc() - logger.error('%s: error on receive. Closing\n%s' % (self.remote_ip, tb)) + logger.error("%s: error on receive. Closing\n%s" % (self.remote_ip, tb)) try: self.close() except: pass def on_close(self): - logger.info('%s: disconnected.' % self.remote_ip) + logger.info("%s: disconnected." % self.remote_ip) if self.thread is not None: self.thread.running = False @@ -174,13 +181,16 @@ def on_close(self): except: pass -application = tornado.web.Application([(r'/', MainHandler)]) + def check_origin(self, origin): + return True + -if __name__ == '__main__': +application = tornado.web.Application([(r"/", MainHandler)]) +if __name__ == "__main__": tunthread = TunThread() tunthread.start() - + args = sys.argv tornado.options.parse_command_line(args) application.listen(80) @@ -190,5 +200,4 @@ def on_close(self): except: pass - tunthread.running = False; - + tunthread.running = False