Skip to content

Commit

Permalink
refactor some service stuff (#4)
Browse files Browse the repository at this point in the history
* refactor(service): make calls more generic

* fix(interfaces): make inbounds a list

* fix(marznode): backends should be a dictionary

* fixes and improvements

* feat(backends): version

* refactor(xray-config): make it swallow errors

* feat(backends): add backend stats to check if a backend is running

* feat(hysteria): send empty bytes in case of eof

* fix(service): give the backend configuration correctly

* fix(xray): close the stream on after reading eof/start when starting xray

* fix(hysteria): restarting

* fix(service): restarting a backend correctly

* improve(backend-restart): make backends save their config, and acquire a lock for restarting
  • Loading branch information
khodedawsh authored Aug 9, 2024
1 parent f8280e9 commit e44d93b
Show file tree
Hide file tree
Showing 17 changed files with 520 additions and 199 deletions.
2 changes: 1 addition & 1 deletion Makefile
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
__default__:
@echo "Please specify a target to make"

GEN=python3 -m grpc_tools.protoc -I. --python_out=. --grpclib_python_out=.
GEN=python3 -m grpc_tools.protoc -I. --python_out=. --grpclib_python_out=. --pyi_out=.
GENERATED=*{_pb2.py,_grpc.py,.pyi}

clean:
Expand Down
25 changes: 25 additions & 0 deletions marznode/backends/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,19 @@


class VPNBackend(ABC):
backend_type: str
config_format: int

@property
@abstractmethod
def version(self) -> str | None:
raise NotImplementedError

@property
@abstractmethod
def running(self) -> bool:
raise NotImplementedError

@abstractmethod
def contains_tag(self, tag: str) -> bool:
raise NotImplementedError
Expand Down Expand Up @@ -35,3 +48,15 @@ def get_logs(self, include_buffer: bool) -> AsyncIterator:
@abstractmethod
async def get_usages(self):
raise NotImplementedError

@abstractmethod
def list_inbounds(self):
raise NotImplementedError

@abstractmethod
def get_config(self):
raise NotImplementedError

@abstractmethod
def save_config(self, config: str):
raise NotImplementedError
10 changes: 7 additions & 3 deletions marznode/backends/hysteria2/_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,9 +48,13 @@ def __init__(
self._inbound.update({"path": obfs_password, "header_type": obfs_type})

def register_inbounds(self, storage: BaseStorage):
inbound = self._inbound
storage.register_inbound(
Inbound(tag=inbound["tag"], protocol=inbound["protocol"], config=inbound)
storage.register_inbound(self.get_inbound())

def get_inbound(self):
return Inbound(
tag=self._inbound["tag"],
protocol=self._inbound["protocol"],
config=self._inbound,
)

def render(self):
Expand Down
16 changes: 10 additions & 6 deletions marznode/backends/hysteria2/_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,8 @@
import yaml
from anyio import BrokenResourceError, ClosedResourceError, create_memory_object_stream

from marznode.backends.hysteria2._utils import get_version

logger = logging.getLogger(__name__)


Expand All @@ -17,7 +19,8 @@ def __init__(self, executable_path: str):
self._snd_streams = []
self._logs_buffer = deque(maxlen=100)
self._capture_task = None
atexit.register(lambda: self.stop() if self.started else None)
self.version = get_version(executable_path)
atexit.register(lambda: self.stop() if self.running else None)

async def start(self, config: dict):
with tempfile.NamedTemporaryFile(
Expand All @@ -36,11 +39,11 @@ async def start(self, config: dict):
asyncio.create_task(self.__capture_process_logs())

def stop(self):
if self.started:
if self.running:
self._process.terminate()

@property
def started(self):
def running(self):
return self._process and self._process.returncode is None

async def __capture_process_logs(self):
Expand All @@ -50,16 +53,17 @@ async def __capture_process_logs(self):
async def capture_stream(stream):
while True:
output = await stream.readline()
if output == b"":
"""break in case of eof"""
return
for stm in self._snd_streams:
try:
await stm.send(output)
except (ClosedResourceError, BrokenResourceError):
self._snd_streams.remove(stm)
continue
self._logs_buffer.append(output)
if output == b"":
"""break in case of eof"""
logger.warning("Hysteria has stopped")
return

await asyncio.gather(
capture_stream(self._process.stderr), capture_stream(self._process.stdout)
Expand Down
18 changes: 18 additions & 0 deletions marznode/backends/hysteria2/_utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
import re
import subprocess


def get_version(hysteria_path: str) -> str | None:
"""
get xray version by running its executable
:param hysteria_path:
:return: xray version
"""
cmd = [hysteria_path, "version"]
output = subprocess.check_output(cmd, stderr=subprocess.STDOUT).decode()
pattern = r"Version:\s*v(\d+\.\d+\.\d+)"
match = re.search(pattern, output)
if match:
return match.group(1)
else:
return None
72 changes: 55 additions & 17 deletions marznode/backends/hysteria2/interface.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,11 @@
import asyncio
import json
import logging
from secrets import token_hex
from typing import AsyncIterator, Any
from typing import AsyncIterator

import aiohttp
from aiohttp import web
from aiohttp import web, ClientConnectorError

from marznode.backends.base import VPNBackend
from marznode.backends.hysteria2._config import HysteriaConfig
Expand All @@ -18,47 +19,81 @@


class HysteriaBackend(VPNBackend):
def __init__(self, executable_path: str, storage: BaseStorage):
backend_type = "hysteria2"
config_format = 2

def __init__(self, executable_path: str, config_path: str, storage: BaseStorage):
self._app_runner = None
self._executable_path = executable_path
self._storage = storage
self._inbounds = ["hysteria2"]
self._inbound_tags = ["hysteria2"]
self._inbounds = list()
self._users = {}
self._auth_site = None
self._runner = Hysteria(self._executable_path)
self._stats_secret = None
self._stats_port = None
self._config_path = config_path
self._restart_lock = asyncio.Lock()

@property
def running(self) -> bool:
return self._runner.running

@property
def version(self):
return self._runner.version

def contains_tag(self, tag: str) -> bool:
return bool(tag == "hysteria2")

async def start(self, config_path: str) -> None:
def list_inbounds(self) -> list:
return self._inbounds

def get_config(self) -> str:
with open(self._config_path) as f:
return f.read()

def save_config(self, config: str) -> None:
with open(self._config_path, "w") as f:
f.write(config)

async def start(self, config: str | None = None) -> None:
if config is None:
with open(self._config_path) as f:
config = f.read()
else:
self.save_config(config)
api_port = find_free_port()
self._stats_port = find_free_port()
self._stats_secret = token_hex(16)
if self._auth_site:
await self._auth_site.stop()
await self._app_runner.cleanup()
app = web.Application()
app.router.add_post("/", self._auth_callback)
app_runner = web.AppRunner(app)
await app_runner.setup()
self._app_runner = web.AppRunner(app)
await self._app_runner.setup()

self._auth_site = web.TCPSite(app_runner, "127.0.0.1", api_port)
self._auth_site = web.TCPSite(self._app_runner, "127.0.0.1", api_port)

await self._auth_site.start()
with open(config_path) as f:
config = f.read()
cfg = HysteriaConfig(config, api_port, self._stats_port, self._stats_secret)
cfg.register_inbounds(self._storage)
self._inbounds = [cfg.get_inbound()]
await self._runner.start(cfg.render())

async def stop(self):
await self._auth_site.stop()
self._storage.remove_inbound("hysteria2")
self._runner.stop()

async def restart(self, backend_config: Any) -> None:
await self.stop()
await self.start(backend_config)
async def restart(self, backend_config: str | None) -> None:
await self._restart_lock.acquire()
try:
await self.stop()
await self.start(backend_config)
finally:
self._restart_lock.release()

async def add_user(self, user: User, inbound: Inbound) -> None:
password = generate_password(user.key)
Expand Down Expand Up @@ -88,9 +123,12 @@ async def get_usages(self):
url = "http://127.0.0.1:" + str(self._stats_port) + "/traffic?clear=1"
headers = {"Authorization": self._stats_secret}

async with aiohttp.ClientSession() as session:
async with session.get(url, headers=headers) as response:
data = await response.json()
try:
async with aiohttp.ClientSession() as session:
async with session.get(url, headers=headers) as response:
data = await response.json()
except ClientConnectorError:
data = {}
usages = {}
for user_identifier, usage in data.items():
uid = int(user_identifier.split(".")[0])
Expand Down
56 changes: 12 additions & 44 deletions marznode/backends/xray/_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,10 +29,8 @@ def __init__(
super().__init__(config)

self.inbounds = []
self.inbounds_by_protocol = {}
self.inbounds_by_tag = {}
# self._fallbacks_inbound = self.get_inbound(XRAY_FALLBACKS_INBOUND_TAG)
self._addr_clients_by_tag = {}
self._resolve_inbounds()

self._apply_api()
Expand Down Expand Up @@ -72,7 +70,7 @@ def _apply_api(self):
def _resolve_inbounds(self):
for inbound in self["inbounds"]:
if (
inbound["protocol"].lower()
inbound.get("protocol", "").lower()
not in {
"vmess",
"trojan",
Expand All @@ -83,12 +81,6 @@ def _resolve_inbounds(self):
):
continue

if not inbound.get("settings"):
inbound["settings"] = {}
if not inbound["settings"].get("clients"):
inbound["settings"]["clients"] = []
self._addr_clients_by_tag[inbound["tag"]] = inbound["settings"]["clients"]

settings = {
"tag": inbound["tag"],
"protocol": inbound["protocol"],
Expand Down Expand Up @@ -123,23 +115,12 @@ def _resolve_inbounds(self):
if inbound["protocol"] == "vless" and net == "tcp":
settings["flow"] = XRAY_VLESS_REALITY_FLOW

try:
settings["pbk"] = tls_settings["publicKey"]
except KeyError:
pvk = tls_settings.get("privateKey")
if not pvk:
raise ValueError(
f"You need to provide privateKey in realitySettings of {inbound['tag']}"
)
x25519 = get_x25519(XRAY_EXECUTABLE_PATH, pvk)
settings["pbk"] = x25519["public_key"]

try:
settings["sid"] = tls_settings.get("shortIds")[0]
except (IndexError, TypeError):
raise ValueError(
f"You need to define at least one shortID in realitySettings of {inbound['tag']}"
)
pvk = tls_settings.get("privateKey")

x25519 = get_x25519(XRAY_EXECUTABLE_PATH, pvk)
settings["pbk"] = x25519["public_key"]

settings["sid"] = tls_settings.get("shortIds", [""])[0]

if net == "tcp":
header = net_settings.get("header", {})
Expand Down Expand Up @@ -178,28 +159,15 @@ def _resolve_inbounds(self):
self.inbounds.append(settings)
self.inbounds_by_tag[inbound["tag"]] = settings

try:
self.inbounds_by_protocol[inbound["protocol"]].append(settings)
except KeyError:
self.inbounds_by_protocol[inbound["protocol"]] = [settings]

def get_inbound(self, tag) -> dict:
for inbound in self["inbounds"]:
if inbound["tag"] == tag:
return inbound

def get_outbound(self, tag) -> dict:
for outbound in self["outbounds"]:
if outbound["tag"] == tag:
return outbound

def register_inbounds(self, storage: BaseStorage):
inbounds = [
for inbound in self.list_inbounds():
storage.register_inbound(inbound)

def list_inbounds(self) -> list[Inbound]:
return [
Inbound(tag=i["tag"], protocol=i["protocol"], config=i)
for i in self.inbounds_by_tag.values()
]
for inbound in inbounds:
storage.register_inbound(inbound)

def to_json(self, **json_kwargs):
return json.dumps(self, **json_kwargs)
Loading

0 comments on commit e44d93b

Please sign in to comment.