diff --git a/docker-compose.yml b/docker-compose.yml
index f9fa810..1e21649 100644
--- a/docker-compose.yml
+++ b/docker-compose.yml
@@ -23,14 +23,16 @@ services:
       - SPARK_SSL_ENABLED=no
     depends_on:
       - spark-master
-#  postgres:
-#    image: postgres:16
-#    environment:
-#      - POSTGRES_DB=metadata_db
-#      - POSTGRES_USER=user
-#      - POSTGRES_PASSWORD=password
-#    volumes:
-#      - postgres-data:/var/lib/postgresql/data
+  postgres:
+    image: postgres:16
+    environment:
+      - POSTGRES_DB=metadata_db
+      - POSTGRES_USER=postgres
+      - POSTGRES_PASSWORD=postgres
+    ports:
+      - '5433:5432'
+    volumes:
+      - postgres-data:/var/lib/postgresql/data
 # app:
 #   build: .
 #   environment:
@@ -44,5 +46,5 @@ services:
 #     - postgres
 #   command: pe_analyzer 1000 # Process 1000 files (500 clean, 500 malware)
 
-#volumes:
-#  postgres-data:
+volumes:
+  postgres-data:
diff --git a/pe_analyzer/__main__.py b/pe_analyzer/__main__.py
index e37e038..4a85af7 100644
--- a/pe_analyzer/__main__.py
+++ b/pe_analyzer/__main__.py
@@ -6,7 +6,7 @@
 from pe_analyzer.metadata_processor import MetadataProcessor
 from pe_analyzer.settings import settings
 
-logging.basicConfig(level=settings.logging_level, format="%(asctime)s - %(name)s - %(levelname)s - %(message)s")
+logging.basicConfig(level=settings.logging_level, format=settings.logging_format)
 logger = logging.getLogger(__name__)
 
 hostname = socket.gethostname()
@@ -21,7 +21,7 @@ def main():
     spark = (
         SparkSession.builder.appName("PE File Analyzer")
         # .master("local[*]")
-        .master("spark://localhost:7077")
+        # .master("spark://localhost:7077")
         .getOrCreate()
     )
 
diff --git a/pe_analyzer/db/db.py b/pe_analyzer/db/db.py
index 90d4ae1..f2c187d 100644
--- a/pe_analyzer/db/db.py
+++ b/pe_analyzer/db/db.py
@@ -1,21 +1,43 @@
+import logging
+from typing import TypeAlias
+
 from pyspark import Row
 from sqlalchemy import create_engine, select
 from sqlalchemy.orm import sessionmaker
+from sqlalchemy.dialects.postgresql import insert as pg_insert
+
 from pe_analyzer.db.models import FileMetadata
 
+logger = logging.getLogger(__name__)
+
+FilePath: TypeAlias = str
+
 
 class Database:
     def __init__(self, db_url):
         self.engine = create_engine(db_url)
         self.SessionLocal = sessionmaker(autocommit=False, autoflush=False, bind=self.engine)
 
-    def get_processed_files(self):
+    def get_not_processed_files(self, file_paths: list[FilePath]) -> list:
         with self.SessionLocal() as session:
-            return [row.path for row in session.execute(select(FileMetadata)).scalars()]
+            # Get all the processed files within the list of file_paths
+            stmt = select(FileMetadata.path).where(FileMetadata.path.in_(file_paths))
+            processed_files = session.scalars(stmt).all()
+            logger.info(f"Found {len(processed_files)} processed files.")
+            # Exclude the ones that are already processed
+            return list(set(file_paths).difference(set(processed_files)))
 
     def save_metadata(self, metadata_list: list[Row]):
         with self.SessionLocal() as session:
+            metadata_dicts = []
             for metadata in metadata_list:
-                file_metadata = FileMetadata(**metadata.asDict())
-                session.merge(file_metadata)
+                metadata_dict = metadata.asDict()
+                if metadata_dict["error"]:
+                    # TODO: output errors somewhere else
+                    logger.warning(f"Error processing {metadata_dict['path']}: {metadata_dict['error']}")
+                del metadata_dict["error"]
+                metadata_dicts.append(metadata_dict)
+
+            stmt = pg_insert(FileMetadata).values(metadata_dicts)
+            session.execute(stmt)
             session.commit()
diff --git a/pe_analyzer/db/migrations/versions/a17cab59bdaa_index_on_path.py b/pe_analyzer/db/migrations/versions/a17cab59bdaa_index_on_path.py
new file mode 100644
index 0000000..e14c89f
--- /dev/null
+++ b/pe_analyzer/db/migrations/versions/a17cab59bdaa_index_on_path.py
@@ -0,0 +1,31 @@
+"""index on path
+
+Revision ID: a17cab59bdaa
+Revises: 4b50f5396ee1
+Create Date: 2024-07-27 19:52:52.207756
+
+"""
+from typing import Sequence, Union
+
+from alembic import op
+
+
+# revision identifiers, used by Alembic.
+revision: str = "a17cab59bdaa"
+down_revision: Union[str, None] = "4b50f5396ee1"
+branch_labels: Union[str, Sequence[str], None] = None
+depends_on: Union[str, Sequence[str], None] = None
+
+
+def upgrade() -> None:
+    # ### commands auto generated by Alembic - please adjust! ###
+    op.drop_constraint("file_metadata_path_key", "file_metadata", type_="unique")
+    op.create_index(op.f("ix_file_metadata_path"), "file_metadata", ["path"], unique=True)
+    # ### end Alembic commands ###
+
+
+def downgrade() -> None:
+    # ### commands auto generated by Alembic - please adjust! ###
+    op.drop_index(op.f("ix_file_metadata_path"), table_name="file_metadata")
+    op.create_unique_constraint("file_metadata_path_key", "file_metadata", ["path"])
+    # ### end Alembic commands ###
diff --git a/pe_analyzer/db/models.py b/pe_analyzer/db/models.py
index cddbb1b..b724678 100644
--- a/pe_analyzer/db/models.py
+++ b/pe_analyzer/db/models.py
@@ -9,10 +9,12 @@ class Base(DeclarativeBase):
 class FileMetadata(Base):
     __tablename__ = "file_metadata"
 
+    # In case of horizontal scaling, we could use UUIDs instead of integers
     id: Mapped[int] = mapped_column(primary_key=True)
-    path: Mapped[str] = mapped_column(String, unique=True, nullable=False)
+    path: Mapped[str] = mapped_column(String, unique=True, nullable=False, index=True)
     size: Mapped[int] = mapped_column(Integer, nullable=False)
     file_type: Mapped[str] = mapped_column(String, nullable=True)
+    # TODO: use enum for x32, x64
     architecture: Mapped[str] = mapped_column(String, nullable=True)
     num_imports: Mapped[int] = mapped_column(Integer, nullable=True)
     num_exports: Mapped[int] = mapped_column(Integer, nullable=True)
diff --git a/pe_analyzer/metadata_processor.py b/pe_analyzer/metadata_processor.py
index 3109c5f..fded8a6 100644
--- a/pe_analyzer/metadata_processor.py
+++ b/pe_analyzer/metadata_processor.py
@@ -2,6 +2,7 @@
 import time
 
 import boto3
+import botocore.exceptions
 import pefile
 from pyspark.sql import SparkSession
 from pyspark.sql.functions import udf, struct
@@ -15,27 +16,34 @@
 
 
 class MetadataProcessor:
-    def __init__(self, spark: SparkSession, n: int, db_url: str):
+    def __init__(
+        self,
+        spark: SparkSession,
+        n: int,
+        db_url: str,
+        batch_size: int = 100,
+        partition_size: int = 10,
+    ):
         self.spark = spark
         self.n = n
         self.s3_handler = S3Handler()
         self.database = Database(db_url=db_url)
+        self.batch_size = batch_size
+        self.partition_size = partition_size
 
     def process(self):
         process_start = time.time()
         logger.info("Processing new files...")
 
-        # Problem: in the current implementation we get all the data from the database.
-        # If the number of processed files is very large, this can be a bottleneck.
-        processed_files = set(self.database.get_processed_files())
-        logger.info(f"Found {len(processed_files)} processed files.")
-        clean_files = self.s3_handler.list_files("0/", self.n // 2)
-        malware_files = self.s3_handler.list_files("1/", self.n // 2)
-        logger.info(f"Found {len(clean_files)} clean files and {len(malware_files)} malware files.")
+        clean_files = self.s3_handler.list_files("0/")
+        malware_files = self.s3_handler.list_files("1/")
+
+        clean_files = self.database.get_not_processed_files(clean_files)[: self.n // 2]
+        malware_files = self.database.get_not_processed_files(malware_files)[: self.n // 2]
 
-        all_files = [f for f in clean_files + malware_files if f not in processed_files]
+        logger.info(f"Found {len(clean_files)} clean files and {len(malware_files)} malware files.")
 
-        # TODO: If all_files < N, do we need to get more files and process them?
+        all_files = clean_files + malware_files
 
         if not all_files:
             logger.info("No new files to process.")
@@ -51,20 +59,23 @@ def process(self):
                 StructField("architecture", StringType(), True),
                 StructField("num_imports", IntegerType(), True),
                 StructField("num_exports", IntegerType(), True),
+                StructField("error", StringType(), True),
             ]
         )
 
         def analyze_pe_file(s3_path: str, s3_region: str, s3_bucket: str) -> tuple:
            file_type = s3_path.split(".")[-1].lower() if "." in s3_path else None
-            s3 = boto3.client("s3", region_name=s3_region)
+            s3 = boto3.client("s3", region_name=s3_region, aws_access_key_id="", aws_secret_access_key="")
+
+            s3._request_signer.sign = lambda *args, **kwargs: None
 
             try:
                 obj = s3.get_object(Bucket=s3_bucket, Key=s3_path)
                 file_content = obj["Body"].read()
                 file_size = len(file_content)
 
-                pe = pefile.PE(data=file_content, fast_load=False)
+                pe = pefile.PE(data=file_content)
 
                 arch = "x32" if pe.FILE_HEADER.Machine == pefile.MACHINE_TYPE["IMAGE_FILE_MACHINE_I386"] else "x64"
 
                 import_count = (
@@ -74,9 +85,12 @@ def analyze_pe_file(s3_path: str, s3_region: str, s3_bucket: str) -> tuple:
                 )
                 export_count = len(pe.DIRECTORY_ENTRY_EXPORT.symbols) if hasattr(pe, "DIRECTORY_ENTRY_EXPORT") else 0
 
-                return s3_path, file_size, file_type, arch, import_count, export_count
-            except Exception:
-                return s3_path, None, file_type, None, None, None
+                return s3_path, file_size, file_type, arch, import_count, export_count, None
+            except botocore.exceptions.EndpointConnectionError as err:
+                # TODO: if the error is due to network issues, we need to retry
+                return s3_path, None, file_type, None, None, None, str(err)
+            except Exception as err:
+                return s3_path, None, file_type, None, None, None, str(err)
             finally:
                 if "pe" in locals():
                     pe.close()
@@ -90,18 +104,34 @@ def analyze_file_udf(row):
             [(f, settings.s3_region, settings.s3_bucket) for f in all_files], ["path", "s3_region", "s3_bucket"]
         )
 
-        logger.info(f"Processing {df.count()} files...")
-        start = time.time()
-        df = df.select(analyze_udf(struct("path", "s3_region", "s3_bucket")).alias("metadata")).select("metadata.*")
-        logger.info(f"Processing time: {time.time() - start:.2f} seconds.")
-
-        start = time.time()
-        metadata_list = df.collect()
-        logger.info(f"Collect time: {time.time() - start:.2f} seconds.")
-        self.database.save_metadata(metadata_list)
-
-        logger.info(f"Processed {df.count()} new files.")
+        df = df.select(analyze_udf(struct("path", "s3_region", "s3_bucket")).alias("metadata")).select("metadata.*")
+
+        def process_partition(iterator):
+            from sqlalchemy import create_engine, Table, MetaData
+            from sqlalchemy.dialects.postgresql import insert
+
+            engine = create_engine("postgresql://postgres:postgres@localhost:5432/metadata_db")
+            file_metadata_table = Table("file_metadata", MetaData(), autoload_with=engine)
+
+            batch = []
+            with engine.connect() as connection:
+                for row in iterator:
+                    row_dict = row.asDict()
+                    if row_dict["error"]:
+                        print(f"Error processing {row_dict['path']}: {row_dict['error']}")
+                    del row_dict["error"]
+                    batch.append(row_dict)
+                    if len(batch) >= 100:
+                        connection.execute(insert(file_metadata_table).values(batch))
+                        batch = []
+                if batch:
+                    connection.execute(insert(file_metadata_table).values(batch))
+                connection.commit()
+
+        start = time.time()
+        df.foreachPartition(process_partition)
+
+        logger.info(f"Processed {df.count()} new files in {time.time() - start:.2f} seconds.")
 
         logger.info(f"Total processing time: {time.time() - process_start:.2f} seconds.")
diff --git a/pe_analyzer/s3_handler.py b/pe_analyzer/s3_handler.py
index edc02c4..d8f61ce 100644
--- a/pe_analyzer/s3_handler.py
+++ b/pe_analyzer/s3_handler.py
@@ -7,9 +7,6 @@ def __init__(self, region_name: str = settings.s3_region, bucket_name: str = set
         self.s3_client = boto3.client("s3", region_name=region_name)
         self.bucket_name = bucket_name
 
-    def list_files(self, prefix: str, limit: int) -> list[str]:
+    def list_files(self, prefix: str) -> list[str]:
         response = self.s3_client.list_objects_v2(Bucket=self.bucket_name, Prefix=prefix)
-        return [obj["Key"] for obj in response.get("Contents", [])[:limit]]
-
-    def download_file(self, file_path: str, local_path: str):
-        self.s3_client.download_file(self.bucket_name, file_path, local_path)
+        return [obj["Key"] for obj in response.get("Contents", [])]
diff --git a/pe_analyzer/settings.py b/pe_analyzer/settings.py
index cccab67..dea99c9 100644
--- a/pe_analyzer/settings.py
+++ b/pe_analyzer/settings.py
@@ -20,6 +20,22 @@ class DatabaseSettings(BaseSettings):
     auto_flush: bool = False
     expire_on_commit: bool = False
 
+    def get_pyspark_driver(self):
+        return {
+            "postgresql": "org.postgresql.Driver",
+            # Add more protocols here
+        }[self.db_protocol]
+
+    def get_pyspark_properties(self):
+        return {
+            "user": self.db_user,
+            "password": self.db_password,
+            "driver": self.get_pyspark_driver(),
+        }
+
+    def get_pyspark_db_url(self):
+        return f"jdbc:{self.db_protocol}://{self.db_host}:{self.db_port}/{self.db_name}"
+
     class Config:
         extra = "allow"
@@ -30,6 +46,7 @@ class AppSettings(BaseSettings):
     base_dir: DirectoryPath = os.path.dirname(os.path.abspath(__file__))
 
     logging_level: str = "INFO"
+    logging_format: str = "%(asctime)s - %(name)s - %(levelname)s - %(message)s"
 
     database: DatabaseSettings = DatabaseSettings(_env_file="db.env", _env_file_encoding="utf-8")
diff --git a/spark/Dockerfile b/spark/Dockerfile
index 14ec816..067283c 100644
--- a/spark/Dockerfile
+++ b/spark/Dockerfile
@@ -1,5 +1,5 @@
 FROM bitnami/spark:3.5.1
 
 USER root
-RUN pip install boto3 pefile
+RUN pip install boto3 pefile sqlalchemy psycopg2-binary
 USER 1001
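
Note on the write path: both Database.save_metadata and process_partition issue plain INSERTs, while migration a17cab59bdaa keeps path unique, so re-inserting an already-seen path would raise an IntegrityError. A minimal sketch of an idempotent variant, reusing the pg_insert import and the FileMetadata/metadata_dicts/session names from the patched save_metadata; this is an optional hardening, not part of the patch:

    # Sketch: skip rows whose path already exists instead of failing the whole batch.
    from sqlalchemy.dialects.postgresql import insert as pg_insert

    stmt = (
        pg_insert(FileMetadata)
        .values(metadata_dicts)
        .on_conflict_do_nothing(index_elements=["path"])  # relies on the unique index on path
    )
    session.execute(stmt)
    session.commit()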
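
Note on the anonymous S3 access: `s3._request_signer.sign = lambda ...` patches a private boto3 attribute. A sketch of the documented equivalent using botocore's UNSIGNED signature configuration, written against the same s3_region/s3_bucket/s3_path parameters as analyze_pe_file (assuming the bucket allows anonymous reads):

    # Sketch: unsigned (anonymous) S3 requests without touching private attributes.
    import boto3
    from botocore import UNSIGNED
    from botocore.config import Config

    s3 = boto3.client("s3", region_name=s3_region, config=Config(signature_version=UNSIGNED))
    obj = s3.get_object(Bucket=s3_bucket, Key=s3_path)
    file_content = obj["Body"].read()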
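
Note on the new settings helpers: get_pyspark_db_url(), get_pyspark_properties() and get_pyspark_driver() are added but not called anywhere in this patch. A minimal sketch of how they could drive Spark's built-in JDBC writer as an alternative to the hand-rolled batching in process_partition (assumes the analyzed DataFrame is named df, the PostgreSQL JDBC driver jar is on the Spark classpath, and already-processed paths have been filtered out beforehand):

    # Sketch: let Spark write the metadata via JDBC using the new settings helpers.
    from pe_analyzer.settings import settings

    db_settings = settings.database
    (
        df.drop("error")  # the file_metadata table has no error column
        .write.jdbc(
            url=db_settings.get_pyspark_db_url(),
            table="file_metadata",
            mode="append",  # plain INSERTs; no ON CONFLICT handling here
            properties=db_settings.get_pyspark_properties(),
        )
    )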