-
Notifications
You must be signed in to change notification settings - Fork 3
/
main.py
150 lines (129 loc) · 5.71 KB
/
main.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
#!/usr/bin/env python3
import argparse
import asyncio
import logging
import os
import signal
import sys
import time
import traceback
from pathlib import Path
from aiohttp import ClientConnectionError, ServerDisconnectedError
from bison.errors import SchemeValidationError
from nio import InviteMemberEvent, JoinResponse, MegolmEvent, RoomMessageText, UnknownEvent, RoomMessageImage
from matrix_gpt import MatrixClientHelper
from matrix_gpt.callbacks import MatrixBotCallbacks
from matrix_gpt.config import global_config
SCRIPT_DIR = os.path.abspath(os.path.dirname(__file__))
logging.basicConfig()
logger = logging.getLogger('MatrixGPT')

# Accepted values for the `logging.log_level` config key. Any other value
# falls back to INFO (see _resolve_log_level).
_LOG_LEVELS = {
    'debug': logging.DEBUG,
    'info': logging.INFO,
    'warning': logging.WARNING,
    'error': logging.ERROR,
    'critical': logging.CRITICAL,
}

# Seconds to wait before retrying a login when no better delay is known.
_LOGIN_RETRY_SECONDS = 5


def _resolve_log_level(name):
    """Translate a config log-level string into a ``logging`` level (INFO if unknown)."""
    return _LOG_LEVELS.get(name, logging.INFO)


def _ratelimit_wait_seconds(login_response):
    """Return half of the delay advertised by an M_LIMIT_EXCEEDED login failure.

    The last space-separated token of ``str(login_response)``, minus its final
    two characters, is read as a millisecond count (presumably a ``...ms``
    suffix from the homeserver error text -- TODO confirm the exact format).

    Returns:
        Whole seconds to wait, or ``None`` if the token is not an integer.
    """
    try:
        millis = int(str(login_response).split(' ')[-1][:-2])
    except ValueError:
        return None
    # Only wait half the ratelimited time.
    return int((millis / 1000) / 2)


async def _login_until_success(client_helper):
    """Call ``client_helper.login()`` repeatedly until it reports success.

    Every failed attempt is logged and followed by a non-blocking pause: half
    the advertised delay for rate-limit errors, otherwise
    ``_LOGIN_RETRY_SECONDS``.
    """
    while True:
        login_success, login_response = await client_helper.login()
        if login_success:
            return
        if 'M_LIMIT_EXCEEDED' in str(login_response):
            wait = _ratelimit_wait_seconds(login_response)
            if wait is None:
                logger.error(f'Could not parse M_LIMIT_EXCEEDED: {login_response}')
                # Back off instead of retrying immediately in a tight loop.
                wait = _LOGIN_RETRY_SECONDS
            else:
                logger.error(f'Ratelimited, sleeping {wait}s...')
        else:
            logger.error(f'Failed to login, retrying: {login_response}')
            wait = _LOGIN_RETRY_SECONDS
        await asyncio.sleep(wait)


async def main(args):
    """Run the MatrixGPT bot, reconnecting after failures.

    Loads and validates the config file named by ``args.config`` and applies
    its log level. Builds the Matrix client and registers the bot's event
    callbacks. Then it loops forever: log in, auto-join configured rooms,
    sync, and pause before retrying when an exception escapes.

    Args:
        args: Parsed CLI namespace with a ``config`` attribute (str or path).
            ``args.config`` is replaced in place with a ``pathlib.Path``.

    Raises:
        SystemExit: With code 1 if the config file is missing or fails
            validation.
    """
    args.config = Path(args.config)
    if not args.config.exists():
        # logging formats `msg % args`, so the path needs a placeholder;
        # without one the record fails to format and the path is never shown.
        logger.critical('Config file does not exist: %s', args.config)
        sys.exit(1)
    global_config.load(args.config)
    try:
        global_config.validate()
    except SchemeValidationError as e:
        logger.critical(f'Config validation error: {e}')
        sys.exit(1)

    logger.setLevel(_resolve_log_level(global_config['logging']['log_level']))
    effective_level = logger.getEffectiveLevel()
    # Announce the level at that same level so the message is always emitted.
    logger.log(effective_level, f'Log level is {logging.getLevelName(effective_level)}')

    logger.debug(f'Command Prefixes: {list(global_config.command_prefixes)}')
    logger.info(f"OpenAI API key: {'yes' if global_config['openai'].get('api_key') else 'no'}")
    logger.info(f"Anthropic API key: {'yes' if global_config['anthropic'].get('api_key') else 'no'}")
    logger.info(f"Copilot API key: {'yes' if global_config['copilot'].get('api_key') else 'no'}")

    client_helper = MatrixClientHelper(
        user_id=global_config['auth']['username'],
        passwd=global_config['auth']['password'],
        homeserver=global_config['auth']['homeserver'],
        store_path=global_config['store_path'],
        device_id=global_config['auth']['device_id'],
    )
    client = client_helper.client

    if global_config['openai'].get('api_base'):
        logger.info(f'Set OpenAI API base URL to: {global_config["openai"].get("api_base")}')

    # Set up event callbacks
    callbacks = MatrixBotCallbacks(client=client_helper)
    client.add_event_callback(callbacks.handle_message, (RoomMessageText, RoomMessageImage))
    client.add_event_callback(callbacks.handle_invite, InviteMemberEvent)
    client.add_event_callback(callbacks.decryption_failure, MegolmEvent)
    client.add_event_callback(callbacks.unknown, UnknownEvent)

    # Keep trying to reconnect on failure (with some time in-between).
    # asyncio.sleep is used throughout so pauses do not block the event loop.
    while True:
        try:
            logger.info('Logging in...')
            await _login_until_success(client_helper)
            logger.info(f'Logged in as {client.user_id}')

            for room in global_config.get('autojoin_rooms') or ():
                r = await client.join(room)
                if not isinstance(r, JoinResponse):
                    logger.critical(f'Failed to join room {room}: {vars(r)}')
                # Pause between successive join requests.
                await asyncio.sleep(1.5)

            logger.info('Performing initial sync...')
            last_sync = (await client_helper.sync()).next_batch
            client_helper.run_sync_in_bg()  # record sync tokens in the background
            logger.info('Bot is active')
            await client.sync_forever(timeout=10000, full_state=True, since=last_sync)
        except (ClientConnectionError, ServerDisconnectedError):
            logger.warning("Unable to connect to homeserver, retrying in 15s...")
            await asyncio.sleep(15)
        except KeyboardInterrupt:
            # NOTE(review): on Python 3.11+ asyncio.run() turns Ctrl-C into task
            # cancellation, so this branch may be unreachable -- confirm.
            await client.close()
            os.kill(os.getpid(), signal.SIGTERM)
        except Exception:
            logger.critical(traceback.format_exc())
            logger.critical('Sleeping 5s...')
            await asyncio.sleep(5)
if __name__ == "__main__":
    # Command-line interface: the only option is where to find the config.
    cli = argparse.ArgumentParser(description='MatrixGPT Bot')
    cli.add_argument(
        '--config',
        default=Path(SCRIPT_DIR, 'config.yaml'),
        help='Path to config.yaml if it is not located next to this executable.',
    )
    args = cli.parse_args()

    # Supervisor loop: start the bot in a fresh event loop each time. On
    # Ctrl-C, signal SIGTERM to this process. On any other error, log the
    # traceback and restart after a short pause.
    while True:
        try:
            asyncio.run(main(args))
        except KeyboardInterrupt:
            os.kill(os.getpid(), signal.SIGTERM)
        except Exception:
            logger.critical(traceback.format_exc())
            time.sleep(5)