From ef9d2f894f520c41a6f91ada1ef21623cbe51552 Mon Sep 17 00:00:00 2001 From: Dawsh Date: Wed, 31 Jul 2024 06:13:40 +0330 Subject: [PATCH] feat(backends): implement hysteria2 support (#3) * feat(backends): implement hysteria2 support * refactor(backends): make it sane * fix(storage): inbound removal in memory storage * refactor(service): getting stats from backends * fix(hysteria): stop hysteria on exit * feat(config): allow enabling and disabling both xray and hysteria2 * feat(config): parsing obfuscation settings and port from hysteria config * feat(hysteria): restarts and hashing the key * fix(xray): reuse the same runner and also capture stderr * feat(hysteria): log capturing * feat(hysteria): set inbound tls correctly --- .env.example | 7 ++ hysteria.yaml | 16 ++++ marznode/backends/base.py | 2 +- marznode/backends/hysteria2/_config.py | 57 ++++++++++++ marznode/backends/hysteria2/_runner.py | 76 ++++++++++++++++ marznode/backends/hysteria2/interface.py | 106 +++++++++++++++++++++++ marznode/backends/xray/_config.py | 10 +++ marznode/backends/xray/_runner.py | 49 ++++++----- marznode/backends/xray/interface.py | 41 ++++----- marznode/config.py | 11 +++ marznode/marznode.py | 24 +++-- marznode/service/service.py | 24 ++--- marznode/storage/base.py | 13 ++- marznode/storage/memory.py | 9 +- requirements.txt | 3 + 15 files changed, 385 insertions(+), 63 deletions(-) create mode 100644 hysteria.yaml create mode 100644 marznode/backends/hysteria2/_config.py create mode 100644 marznode/backends/hysteria2/_runner.py create mode 100644 marznode/backends/hysteria2/interface.py diff --git a/.env.example b/.env.example index d25c2e0..23c7ed6 100644 --- a/.env.example +++ b/.env.example @@ -2,11 +2,18 @@ #SERVICE_PORT=53042 #INSECURE=False +#XRAY_ENABLED=True #XRAY_EXECUTABLE_PATH=/usr/bin/xray #XRAY_ASSETS_PATH=/usr/share/xray #XRAY_CONFIG_PATH=/etc/xray/xray_config.json #XRAY_VLESS_REALITY_FLOW=xtls-rprx-vision + +#HYSTERIA_ENABLED=False +#HYSTERIA_EXECUTABLE_PATH=/usr/bin/hysteria +#HYSTERIA_CONFIG_PATH=/etc/hysteria/config.yaml + + #SSL_KEY_FILE=./server.key #SSL_CERT_FILE=./server.cert #SSL_CLIENT_CERT_FILE=./client.cert diff --git a/hysteria.yaml b/hysteria.yaml new file mode 100644 index 0000000..c6f1c9d --- /dev/null +++ b/hysteria.yaml @@ -0,0 +1,16 @@ +listen: :4443 + +tls: + cert: ./ssl_cert.pem + key: ./ssl_key.pem + +auth: + type: command + command: echo + +masquerade: + type: proxy + proxy: + url: https://news.ycombinator.com/ + rewriteHost: true + diff --git a/marznode/backends/base.py b/marznode/backends/base.py index 3d72e58..c14ab96 100644 --- a/marznode/backends/base.py +++ b/marznode/backends/base.py @@ -13,7 +13,7 @@ def contains_tag(self, tag: str) -> bool: raise NotImplementedError @abstractmethod - async def start(self) -> None: + async def start(self, backend_config: Any) -> None: raise NotImplementedError @abstractmethod diff --git a/marznode/backends/hysteria2/_config.py b/marznode/backends/hysteria2/_config.py new file mode 100644 index 0000000..37359f5 --- /dev/null +++ b/marznode/backends/hysteria2/_config.py @@ -0,0 +1,57 @@ +import yaml + +from marznode.models import Inbound +from marznode.storage import BaseStorage + + +class HysteriaConfig: + def __init__( + self, + config: str, + api_port: int = 9090, + stats_port: int = 9999, + stats_secret: str = "pretty_secret", + ): + loaded_config = yaml.safe_load(config) + loaded_config["auth"] = { + "type": "http", + "http": {"url": "http://127.0.0.1:" + str(api_port)}, + } + loaded_config["trafficStats"] = { + "listen": "127.0.0.1:" + str(stats_port), + "secret": stats_secret, + } + self._config = loaded_config + + port = 443 + if "listen" in loaded_config: + try: + port = int(loaded_config.get("listen").split(":")[-1]) + except ValueError: + pass + obfs_type, obfs_password = None, None + + if "obfs" in loaded_config: + try: + obfs_type = loaded_config["obfs"]["type"] + obfs_password = loaded_config["obfs"][obfs_type]["password"] + except: + pass + + self._inbound = { + "tag": "hysteria2", + "protocol": "hysteria2", + "port": port, + "tls": "tls", + } + if obfs_type and obfs_password: + self._inbound.update({"path": obfs_password, "header_type": obfs_type}) + + def register_inbounds(self, storage: BaseStorage): + inbound = self._inbound + storage.register_inbound( + Inbound(tag=inbound["tag"], protocol=inbound["protocol"], config=inbound) + ) + + def render(self): + return self._config diff --git a/marznode/backends/hysteria2/_runner.py b/marznode/backends/hysteria2/_runner.py new file mode 100644 index 0000000..66d6357 --- /dev/null +++ b/marznode/backends/hysteria2/_runner.py @@ -0,0 +1,76 @@ +import asyncio +import atexit +import logging +import tempfile +from collections import deque + +import yaml +from anyio import BrokenResourceError, ClosedResourceError, create_memory_object_stream + +logger = logging.getLogger(__name__) + + +class Hysteria: + def __init__(self, executable_path: str): + self._executable_path = executable_path + self._process = None + self._snd_streams = [] + self._logs_buffer = deque(maxlen=100) + self._capture_task = None + atexit.register(lambda: self.stop() if self.started else None) + + async def start(self, config: dict): + with tempfile.NamedTemporaryFile( + mode="w", suffix=".yaml", delete=False + ) as temp_file: + yaml.dump(config, temp_file) + cmd = [self._executable_path, "server", "-c", temp_file.name] + + self._process = await asyncio.create_subprocess_shell( + " ".join(cmd), + stdin=asyncio.subprocess.PIPE, + stderr=asyncio.subprocess.PIPE, + stdout=asyncio.subprocess.PIPE, + ) + logger.info("Hysteria has started") + asyncio.create_task(self.__capture_process_logs()) + + def stop(self): + if self.started: + self._process.terminate() + + @property + def started(self): + return self._process and self._process.returncode is None + + async def __capture_process_logs(self): + """capture the logs, push it into the stream, and store it in the deck + note that the stream blocks sending if it's full, so a deck is necessary""" + + async def capture_stream(stream): + while True: + output = await stream.readline() + if output == b"": + """break in case of eof""" + return + for stm in self._snd_streams: + try: + await stm.send(output) + except (ClosedResourceError, BrokenResourceError): + self._snd_streams.remove(stm) + continue + self._logs_buffer.append(output) + + await asyncio.gather( + capture_stream(self._process.stderr), capture_stream(self._process.stdout) + ) + + def get_logs_stm(self): + new_snd_stm, new_rcv_stm = create_memory_object_stream() + self._snd_streams.append(new_snd_stm) + return new_rcv_stm + + def get_buffer(self): + """makes a copy of the buffer, so it could be read multiple times + the buffer is never cleared in case logs from xray's exit are useful""" + return self._logs_buffer.copy() diff --git a/marznode/backends/hysteria2/interface.py b/marznode/backends/hysteria2/interface.py new file mode 100644 index 0000000..6cf4c7b --- /dev/null +++ b/marznode/backends/hysteria2/interface.py @@ -0,0 +1,106 @@ +import json +import logging +from secrets import token_hex +from typing import AsyncIterator, Any + +import aiohttp +from aiohttp import web + +from marznode.backends.base import VPNBackend +from marznode.backends.hysteria2._config import HysteriaConfig +from marznode.backends.hysteria2._runner import Hysteria +from marznode.models import User, Inbound +from marznode.storage import BaseStorage +from marznode.utils.key_gen import generate_password +from marznode.utils.network import find_free_port + +logger = logging.getLogger(__name__) + + +class HysteriaBackend(VPNBackend): + def __init__(self, executable_path: str, storage: BaseStorage): + self._executable_path = executable_path + self._storage = storage + self._inbounds = ["hysteria2"] + self._users = {} + self._auth_site = None + self._runner = Hysteria(self._executable_path) + self._stats_secret = None + self._stats_port = None + + def contains_tag(self, tag: str) -> bool: + return bool(tag == "hysteria2") + + async def start(self, config_path: str) -> None: + api_port = find_free_port() + self._stats_port = find_free_port() + self._stats_secret = token_hex(16) + if self._auth_site: + await self._auth_site.stop() + app = web.Application() + app.router.add_post("/", self._auth_callback) + app_runner = web.AppRunner(app) + await app_runner.setup() + + self._auth_site = web.TCPSite(app_runner, "127.0.0.1", api_port) + + await self._auth_site.start() + with open(config_path) as f: + config = f.read() + cfg = HysteriaConfig(config, api_port, self._stats_port, self._stats_secret) + cfg.register_inbounds(self._storage) + await self._runner.start(cfg.render()) + + async def stop(self): + await self._auth_site.stop() + self._storage.remove_inbound("hysteria2") + self._runner.stop() + + async def restart(self, backend_config: Any) -> None: + await self.stop() + await self.start(backend_config) + + async def add_user(self, user: User, inbound: Inbound) -> None: + password = generate_password(user.key) + self._users.update({password: user}) + + async def remove_user(self, user: User, inbound: Inbound) -> None: + self._users.pop(user.key) + url = "http://127.0.0.1:" + str(self._stats_port) + "/kick" + headers = {"Authorization": self._stats_secret} + + payload = json.dumps([str(user.id) + "." + user.username]) + async with aiohttp.ClientSession() as session: + async with session.post(url, data=payload, headers=headers): + pass + + async def get_logs(self, include_buffer: bool) -> AsyncIterator: + if include_buffer: + buffer = self._runner.get_buffer() + for line in buffer: + yield line + log_stm = self._runner.get_logs_stm() + async with log_stm: + async for line in log_stm: + yield line + + async def get_usages(self): + url = "http://127.0.0.1:" + str(self._stats_port) + "/traffic?clear=1" + headers = {"Authorization": self._stats_secret} + + async with aiohttp.ClientSession() as session: + async with session.get(url, headers=headers) as response: + data = await response.json() + usages = {} + for user_identifier, usage in data.items(): + uid = int(user_identifier.split(".")[0]) + usages[uid] = usage["tx"] + usage["rx"] + return usages + + async def _auth_callback(self, request: web.Request): + user_key = (await request.json())["auth"] + if user := self._users.get(user_key): + return web.Response( + body=json.dumps({"ok": True, "id": str(user.id) + "." + user.username}), + ) + return web.Response(status=404) diff --git a/marznode/backends/xray/_config.py b/marznode/backends/xray/_config.py index b7d9fbe..b813e76 100644 --- a/marznode/backends/xray/_config.py +++ b/marznode/backends/xray/_config.py @@ -4,6 +4,8 @@ from marznode.config import XRAY_EXECUTABLE_PATH, XRAY_VLESS_REALITY_FLOW from ._utils import get_x25519 +from ...models import Inbound +from ...storage import BaseStorage class XrayConfig(dict): @@ -191,5 +193,13 @@ def get_outbound(self, tag) -> dict: if outbound["tag"] == tag: return outbound + def register_inbounds(self, storage: BaseStorage): + inbounds = [ + Inbound(tag=i["tag"], protocol=i["protocol"], config=i) + for i in self.inbounds_by_tag.values() + ] + for inbound in inbounds: + storage.register_inbound(inbound) + def to_json(self, **json_kwargs): return json.dumps(self, **json_kwargs) diff --git a/marznode/backends/xray/_runner.py b/marznode/backends/xray/_runner.py index 990c578..13930fc 100644 --- a/marznode/backends/xray/_runner.py +++ b/marznode/backends/xray/_runner.py @@ -21,10 +21,10 @@ def __init__(self, executable_path: str, assets_path: str): self.assets_path = assets_path self.version = get_version(executable_path) - self.process = None + self._process = None self.restarting = False - self._snd_streams, self._rcv_streams = [], [] + self._snd_streams = [] self._logs_buffer = deque(maxlen=100) self._env = {"XRAY_LOCATION_ASSET": assets_path} @@ -38,17 +38,17 @@ async def start(self, config: XrayConfig): config["log"]["logLevel"] = "warning" cmd = [self.executable_path, "run", "-config", "stdin:"] - self.process = await asyncio.create_subprocess_shell( + self._process = await asyncio.create_subprocess_shell( " ".join(cmd), env=self._env, stdin=asyncio.subprocess.PIPE, stderr=asyncio.subprocess.PIPE, stdout=asyncio.subprocess.PIPE, ) - self.process.stdin.write(str.encode(config.to_json())) - await self.process.stdin.drain() - self.process.stdin.close() - await self.process.stdin.wait_closed() + self._process.stdin.write(str.encode(config.to_json())) + await self._process.stdin.drain() + self._process.stdin.close() + await self._process.stdin.wait_closed() logger.info("Xray core %s started", self.version) asyncio.create_task(self.__capture_process_logs()) @@ -58,8 +58,8 @@ def stop(self): if not self.started: return - self.process.terminate() - self.process = None + self._process.terminate() + self._process = None logger.warning("Xray core stopped") async def restart(self, config: XrayConfig): @@ -78,14 +78,24 @@ async def restart(self, config: XrayConfig): async def __capture_process_logs(self): """capture the logs, push it into the stream, and store it in the deck note that the stream blocks sending if it's full, so a deck is necessary""" - while output := await self.process.stdout.readline(): - for stm in self._snd_streams: - try: - await stm.send(output) - except (ClosedResourceError, BrokenResourceError): - self._snd_streams.remove(stm) - continue - self._logs_buffer.append(output) + + async def capture_stream(stream): + while True: + output = await stream.readline() + if output == b"": + """break in case of eof""" + return + for stm in self._snd_streams: + try: + await stm.send(output) + except (ClosedResourceError, BrokenResourceError): + self._snd_streams.remove(stm) + continue + self._logs_buffer.append(output) + + await asyncio.gather( + capture_stream(self._process.stderr), capture_stream(self._process.stdout) + ) def get_logs_stm(self): new_snd_stm, new_rcv_stm = create_memory_object_stream() @@ -96,9 +106,8 @@ def get_buffer(self): """makes a copy of the buffer, so it could be read multiple times the buffer is never cleared in case logs from xray's exit are useful""" return self._logs_buffer.copy() + # return [line for line in self._logs_buffer] @property def started(self): - if not self.process or self.process.returncode is not None: - return False - return True + return self._process and self._process.returncode is None diff --git a/marznode/backends/xray/interface.py b/marznode/backends/xray/interface.py index a57cc7a..5a8b8c4 100644 --- a/marznode/backends/xray/interface.py +++ b/marznode/backends/xray/interface.py @@ -24,39 +24,36 @@ class XrayBackend(VPNBackend): def __init__(self, storage: BaseStorage): - xray_api_port = find_free_port() - self._config = XrayConfig(config.XRAY_CONFIG_PATH, api_port=xray_api_port) - xray_inbounds = [ - Inbound(tag=i["tag"], protocol=i["protocol"], config=i) - for i in self._config.inbounds_by_tag.values() - ] - storage.set_inbounds(xray_inbounds) - self._inbound_tags = {i.tag for i in xray_inbounds} - self._api = XrayAPI("127.0.0.1", xray_api_port) + self._config = None + self._inbound_tags = set() + self._api = None self._runner = XrayCore(config.XRAY_EXECUTABLE_PATH, config.XRAY_ASSETS_PATH) + self._storage = storage def contains_tag(self, tag: str) -> bool: return tag in self._inbound_tags - async def start(self): + async def start(self, backend_config: str): + xray_api_port = find_free_port() + self._config = XrayConfig(backend_config, api_port=xray_api_port) + self._config.register_inbounds(self._storage) + self._inbound_tags = {i["tag"] for i in self._config.inbounds} + self._api = XrayAPI("127.0.0.1", xray_api_port) await self._runner.start(self._config) + await asyncio.sleep(0.15) + + def stop(self): + self._runner.stop() + for tag in self._inbound_tags: + self._storage.remove_inbound(tag) + self._inbound_tags = set() async def restart(self, backend_config: str | None) -> list[Inbound] | None: # xray_config = backend_config if backend_config else self._config if not backend_config: return await self._runner.restart(self._config) - api_port = find_free_port() - self._config = XrayConfig(backend_config, api_port=api_port) - xray_inbounds = [ - Inbound(tag=i["tag"], protocol=i["protocol"], config=i) - for i in self._config.inbounds_by_tag.values() - ] - await self._runner.restart(self._config) - self._api = XrayAPI("127.0.0.1", api_port) - self._inbound_tags = {i.tag for i in xray_inbounds} - await asyncio.sleep(0.1) # wait until xray api is up, - # I'd rather check if the port is open manually but this is lazier. for now. - return xray_inbounds + self.stop() + await self.start(backend_config) async def add_user(self, user: User, inbound: Inbound): email = f"{user.id}.{user.username}" diff --git a/marznode/config.py b/marznode/config.py index 9241e04..9ca5101 100644 --- a/marznode/config.py +++ b/marznode/config.py @@ -9,11 +9,22 @@ SERVICE_PORT = config("SERVICE_PORT", cast=int, default=53042) INSECURE = config("INSECURE", cast=bool, default=False) +XRAY_ENABLED = config("XRAY_ENABLED", cast=bool, default=True) XRAY_EXECUTABLE_PATH = config("XRAY_EXECUTABLE_PATH", default="/usr/bin/xray") XRAY_ASSETS_PATH = config("XRAY_ASSETS_PATH", default="/usr/share/xray") XRAY_CONFIG_PATH = config("XRAY_CONFIG_PATH", default="/etc/xray/config.json") XRAY_VLESS_REALITY_FLOW = config("XRAY_VLESS_REALITY_FLOW", default="xtls-rprx-vision") + +HYSTERIA_ENABLED = config("HYSTERIA_ENABLED", cast=bool, default=False) +HYSTERIA_EXECUTABLE_PATH = config( + "HYSTERIA_EXECUTABLE_PATH", default="/usr/bin/hysteria" +) +HYSTERIA_CONFIG_PATH = config( + "HYSTERIA_CONFIG_PATH", default="/etc/hysteria/config.yaml" +) + + SSL_CERT_FILE = config("SSL_CERT_FILE", default="./ssl_cert.pem") SSL_KEY_FILE = config("SSL_KEY_FILE", default="./ssl_key.pem") SSL_CLIENT_CERT_FILE = config("SSL_CLIENT_CERT_FILE", default="") diff --git a/marznode/marznode.py b/marznode/marznode.py index fdf8a1f..4e56066 100644 --- a/marznode/marznode.py +++ b/marznode/marznode.py @@ -9,12 +9,19 @@ from grpclib.utils import graceful_exit from marznode import config +from marznode.backends.hysteria2.interface import HysteriaBackend from marznode.backends.xray.interface import XrayBackend +from marznode.config import ( + HYSTERIA_EXECUTABLE_PATH, + HYSTERIA_CONFIG_PATH, + XRAY_CONFIG_PATH, + HYSTERIA_ENABLED, + XRAY_ENABLED, +) from marznode.service import MarzService from marznode.storage import MemoryStorage from marznode.utils.ssl import generate_keypair, create_secure_context - logger = logging.getLogger(__name__) @@ -26,7 +33,7 @@ async def main(): if not all( (os.path.isfile(config.SSL_CERT_FILE), os.path.isfile(config.SSL_KEY_FILE)) ): - logger.info("Generating a keypair for Marz-node.") + logger.info("Generating a keypair for Marznode.") generate_keypair(config.SSL_KEY_FILE, config.SSL_CERT_FILE) if not os.path.isfile(config.SSL_CLIENT_CERT_FILE): @@ -39,9 +46,16 @@ async def main(): ) storage = MemoryStorage() - xray_backend = XrayBackend(storage) - await xray_backend.start() - backends = [xray_backend] + backends = [] + if XRAY_ENABLED: + xray_backend = XrayBackend(storage) + await xray_backend.start(XRAY_CONFIG_PATH) + backends.append(xray_backend) + if HYSTERIA_ENABLED: + hysteria_backend = HysteriaBackend(HYSTERIA_EXECUTABLE_PATH, storage) + await hysteria_backend.start(HYSTERIA_CONFIG_PATH) + backends.append(hysteria_backend) + server = Server([MarzService(storage, backends), Health()]) with graceful_exit([server]): diff --git a/marznode/service/service.py b/marznode/service/service.py index 2c53c10..14aa080 100644 --- a/marznode/service/service.py +++ b/marznode/service/service.py @@ -5,6 +5,7 @@ import json import logging +from collections import defaultdict from grpclib.server import Stream @@ -77,7 +78,7 @@ async def _update_user(self, user_data: UserData): elif not user_data.inbounds and not storage_user: """we're asked to remove a user which we don't have, just pass.""" return - + """otherwise synchronize the user with what the client has sent us""" storage_tags = {i.tag for i in storage_user.inbounds} @@ -121,10 +122,18 @@ async def RepopulateUsers( async def FetchUsersStats(self, stream: Stream[Empty, UsersStats]) -> None: await stream.recv_message() - stats = await self._backends[0].get_usages() - logger.debug(stats) + all_stats = defaultdict(int) + + for backend in self._backends: + stats = await backend.get_usages() + + for user, usage in stats.items(): + all_stats[user] += usage + + logger.debug(all_stats) user_stats = [ - UsersStats.UserStats(uid=uid, usage=usage) for uid, usage in stats.items() + UsersStats.UserStats(uid=uid, usage=usage) + for uid, usage in all_stats.items() ] await stream.send_message(UsersStats(users_stats=user_stats)) @@ -147,11 +156,6 @@ async def RestartXray( await self._storage.flush_users() inbounds = await self._backends[0].restart(message.configuration) logger.debug(inbounds) - if inbounds: - self._storage.set_inbounds(inbounds) - pb2_inbounds = [ - Inbound(tag=i.tag, config=json.dumps(i.config)) for i in inbounds - ] - await stream.send_message(InboundsResponse(inbounds=pb2_inbounds)) + await stream.send_message(InboundsResponse(inbounds=[])) with open(config.XRAY_CONFIG_PATH, "w") as f: f.write(json.dumps(json.loads(message.configuration), indent=2)) diff --git a/marznode/storage/base.py b/marznode/storage/base.py index 3db0499..d7f8319 100644 --- a/marznode/storage/base.py +++ b/marznode/storage/base.py @@ -52,9 +52,16 @@ async def flush_users(self) -> None: """ @abstractmethod - def set_inbounds(self, inbounds: list[Inbound]) -> None: + def register_inbound(self, inbound: Inbound) -> None: """ - resets all inbounds - :param inbounds: inbounds + registers a new inbound + :param inbound: the inbound to register + :return: nothing + """ + + def remove_inbound(self, inbound: Inbound | str) -> None: + """ + removes an inbound + :param inbound: the inbound to remove :return: nothing """ diff --git a/marznode/storage/memory.py b/marznode/storage/memory.py index b91b425..6a5ee49 100644 --- a/marznode/storage/memory.py +++ b/marznode/storage/memory.py @@ -41,8 +41,13 @@ async def update_user_inbounds(self, user: User, inbounds: list[Inbound]) -> Non user.inbounds = inbounds self.storage["users"][user.id] = user - def set_inbounds(self, inbounds: list[Inbound]) -> None: - self.storage["inbounds"] = {i.tag: i for i in inbounds} + def register_inbound(self, inbound: Inbound) -> None: + self.storage["inbounds"][inbound.tag] = inbound + + def remove_inbound(self, inbound: Inbound | str) -> None: + tag = inbound if isinstance(inbound, str) else inbound.tag + if tag in self.storage["inbounds"]: + self.storage["inbounds"].pop(tag) async def flush_users(self): self.storage["users"] = {} diff --git a/requirements.txt b/requirements.txt index c9f2549..e96e621 100644 --- a/requirements.txt +++ b/requirements.txt @@ -8,3 +8,6 @@ python-decouple==3.8 python-dotenv==1.0.1 requests==2.31.0 xxhash==3.4.1 + +PyYAML~=6.0.1 +aiohttp~=3.9.5 \ No newline at end of file