Skip to content

Commit

Permalink
Move post registration from the bot side to an agent side
Browse files Browse the repository at this point in the history
  • Loading branch information
borisevich-a-v committed Dec 25, 2024
1 parent 5d950d1 commit 0cfdae3
Show file tree
Hide file tree
Showing 4 changed files with 64 additions and 68 deletions.
21 changes: 2 additions & 19 deletions src/aggregator/bot/create_bot.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,14 +3,11 @@
from loguru import logger
from telethon import TelegramClient, events
from telethon.sessions import StringSession
from telethon.utils import get_display_name
from typing_extensions import NamedTuple

from aggregator.bot.warden.warden import NotAllowed, Warden
from aggregator.config import ADMIN, AGGREGATOR_CHANNEL, TELEGRAM_API_HASH, TELEGRAM_API_ID, TELEGRAM_BOT_TOKEN, \
BOT_SESSION
from aggregator.config import ADMIN, AGGREGATOR_CHANNEL, BOT_SESSION, TELEGRAM_API_HASH, TELEGRAM_API_ID
from aggregator.posts_storage import NoNewPosts, PostStorage
from aggregator.telegram_slow_client import TelegramSlowClient

ANY_CHANNEL_COMMAND = "next"

Expand Down Expand Up @@ -46,21 +43,7 @@ def get_request_pattern(post_storage: PostStorage) -> re.Pattern:

def create_bot(post_storage: PostStorage, warden: Warden) -> TelegramClient:
logger.info("Creating bot")
bot = TelegramSlowClient(StringSession(BOT_SESSION), TELEGRAM_API_ID, TELEGRAM_API_HASH, min_request_interval=0.005)

# TODO: should we parse album here? nea
@bot.on(events.NewMessage(chats=AGGREGATOR_CHANNEL))
async def aggregator_channel_listener(event) -> None:
if hasattr(event, "message"):
message = event.message
if message.fwd_from and message.fwd_from.from_id:
original_chat = await event.client.get_entity(message.fwd_from.from_id)
channel_name = get_display_name(original_chat)
post_storage.post(event.message, channel_name)
else:
logger.warning("Message is not forwarded, skipping")
else:
logger.critical("No message {}", event)
bot = TelegramClient(StringSession(BOT_SESSION), TELEGRAM_API_ID, TELEGRAM_API_HASH)

@bot.on(events.NewMessage(pattern=get_request_pattern(post_storage), from_users=ADMIN))
async def handle_posts_request_command(event) -> None:
Expand Down
15 changes: 8 additions & 7 deletions src/aggregator/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,17 +12,18 @@ class MessageModel(Base):
id = Column(Integer, primary_key=True)
message_id = Column(Integer, nullable=False) # message id in the aggr channel
grouped_id = Column(BigInteger, nullable=True) # grouped_id is a Telegram hack to group messages
channel_id = Column(BigInteger, ForeignKey("channel.id"), nullable=False) # source channel id.
sent = Column(DateTime, nullable=True) # have the message been sent to the user. NULL if it haven't been sent
original_message_id = Column(Integer, nullable=False) # Message id in a source channel. for deduplicate reposts

# id of the channel where the message is forwarded from
channel_id = Column(BigInteger, ForeignKey("channel.id"), nullable=False)

# Source channel and message id (in case the message was fwd into listening channel)
original_channel_id = Column(BigInteger, nullable=False)
original_message_id = Column(Integer, nullable=False)

channel = relationship("ChannelModel", back_populates="messages")

__table_args__ = (
# Every message in a telegram channel has unique id, so the pair original_message_id
# and (original) channel_id should be uniq
UniqueConstraint("original_message_id", "channel_id", name="source_message_uniq"),
)
__table_args__ = (UniqueConstraint("original_channel_id", "original_message_id", name="source_message_uniq"),)

def __repr__(self):
return f"MessageModel({self.id, self.message_id, self.grouped_id, self.channel_id, self.sent, self.original_message_id})"
Expand Down
31 changes: 12 additions & 19 deletions src/aggregator/posts_storage.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,30 +24,18 @@ def __init__(self, session_maker) -> None:
logger.info("Post storage is initializing...")
self.session_maker = session_maker

def post(self, message: Message, channel_name: str) -> None:
original_message_id = message.fwd_from.channel_post
original_peer_id = get_peer_id(message.fwd_from.from_id)

logger.info("Adding message {} from the {}...", original_message_id, original_peer_id)
def post(self, message_id: MESSAGE_ID, grouped_id: int , event_peer_id, original_channel_id, original_message_id) -> None:
with self.session_maker() as session:
orm_message = MessageModel(
message_id=message.id,
grouped_id=message.grouped_id,
channel=self.get_or_create_channel(original_peer_id, session, channel_name),
message_id=message_id,
grouped_id=grouped_id,
channel_id=event_peer_id,
original_channel_id=original_channel_id,
original_message_id=original_message_id,
)
session.add(orm_message)
session.commit()

def get_or_create_channel(self, channel_id: int, session: Session, channel_name: str) -> ChannelModel:
channel = session.query(ChannelModel).filter_by(id=channel_id).first()
if not channel:
logger.info(f"Adding the new channel...: {channel_name}")
channel = ChannelModel(id=channel_id, name=channel_name)
session.add(channel)
session.commit()
return channel

def _get_first_unsent_message(self, session: Session, channel_type: Any) -> MessageModel:
# My first code with SQLAlchemy, it's bad, but I'll fix it later
unsent_message_query = session.query(MessageModel).join(ChannelModel).filter(MessageModel.sent.is_(None))
Expand Down Expand Up @@ -86,7 +74,7 @@ def set_sent_multiple(self, message_ids: list[MESSAGE_ID]) -> None:
)
session.commit()

def is_original_msg_duplicate(self, msgs: list[Message]) -> bool:
def is_duplicate(self, msgs: list[Message]) -> bool:
with self.session_maker() as session:
for msg in msgs:
# If message is forwarded, then there is an obvious risk of duplication, if it is an original message
Expand Down Expand Up @@ -114,4 +102,9 @@ def get_all_custom_channel_types(self):
types = (
session.query(ChannelTypeModel.type_).filter(ChannelTypeModel.type_ != NOT_SPECIFIED_CHANNEL_TYPE).all()
)
return [t[0] for t in types]
return [t[0] for t in types]

def get_whitelisted_channel_ids(self) -> list[int]:
with self.session_maker() as session:
channel_ids = session.query(ChannelModel.id).all()
return [t[0] for t in channel_ids]
65 changes: 42 additions & 23 deletions src/aggregator/telegram_agent/create_agent.py
Original file line number Diff line number Diff line change
@@ -1,51 +1,70 @@
import asyncio

from loguru import logger
from telethon import TelegramClient, events
from telethon.sessions import StringSession
from telethon.tl.types import Message, PeerUser
from telethon.utils import get_peer_id

from aggregator.config import AGGREGATOR_CHANNEL, CLIENT_SESSION, TELEGRAM_API_HASH, TELEGRAM_API_ID
from aggregator.posts_storage import PostStorage
from aggregator.telegram_slow_client import TelegramSlowClient


def is_it_user(message: Message) -> bool:
from_id = message.from_id
if isinstance(from_id, (PeerUser,)):
return True
return False


def create_telegram_agent(post_storage: PostStorage) -> TelegramClient:
logger.info("Creating telegram agent")
client = TelegramSlowClient(
StringSession(CLIENT_SESSION), TELEGRAM_API_ID, TELEGRAM_API_HASH, min_request_interval=3
)
client = TelegramClient(StringSession(CLIENT_SESSION), TELEGRAM_API_ID, TELEGRAM_API_HASH)

whitelisted_channels = post_storage.get_whitelisted_channel_ids()
forwarding_message_lock = asyncio.Lock()

@client.on(events.NewMessage(AGGREGATOR_CHANNEL, blacklist_chats=True))
@client.on(events.Album(AGGREGATOR_CHANNEL, blacklist_chats=True))
@client.on(events.NewMessage(whitelisted_channels))
@client.on(events.Album(whitelisted_channels))
async def public_channel_listener(event) -> None:
"""
This handler just forward messages to the aggregation channel.
The bot can't access some posts from other public channel, so we forward posts to the place where bot can
access them.
"""
if hasattr(event, "messages") and event.grouped_id:
messages = event.messages
elif hasattr(event, "message") and not event.grouped_id:
if hasattr(event, "message"):
# probably it's a place for improvement.
# IDK how to forward grouped messages together if only process separated messages, but not album.
logger.debug("Processing a single message event...")
is_single_message = True
messages = [event.message]
await asyncio.sleep(2)

elif hasattr(event, "messages"):
logger.debug("Processing a multi message event...")
is_single_message = False
messages = event.messages

else:
logger.info("Got an update, that is not a message. Skipped.")
logger.debug("Not a message")
return

if is_it_user(messages[0]):
return
async with forwarding_message_lock:
if post_storage.is_duplicate(messages):
logger.warning("The messages have been saved previously: {}", messages)
return
logger.info("New messages {} will be forwarded into the aggregation channel", messages)
fwd_event = await event.forward_to(AGGREGATOR_CHANNEL)

if post_storage.is_original_msg_duplicate(messages):
logger.warning("The messages have been saved previously: {}", messages)
return
first_message = fwd_event if is_single_message else fwd_event[0]
fwd_messages = [fwd_event] if is_single_message else fwd_event

original_channel_id = get_peer_id(first_message.fwd_from.from_id)
forwarded_from_channel_id = get_peer_id(messages[0].peer_id)

await event.forward_to(AGGREGATOR_CHANNEL)
logger.info("Got a new post {} and added it to the aggregation channel", [m.id for m in event.messages])
for fwd_msg, msg in zip(fwd_messages, messages):
post_storage.post(
fwd_msg.id,
msg.grouped_id,
forwarded_from_channel_id,
original_channel_id,
fwd_msg.fwd_from.channel_post
)
logger.debug("Messages was successfully processed")

client.start()
logger.info("Client has been initialized")
Expand Down

0 comments on commit 0cfdae3

Please sign in to comment.