feat: batch insert to db using partitions, omitting processed files
abinba committed Jul 27, 2024
1 parent 677a948 commit c50ee5d
Showing 9 changed files with 149 additions and 49 deletions.
22 changes: 12 additions & 10 deletions docker-compose.yml
@@ -23,14 +23,16 @@ services:
- SPARK_SSL_ENABLED=no
depends_on:
- spark-master
# postgres:
# image: postgres:16
# environment:
# - POSTGRES_DB=metadata_db
# - POSTGRES_USER=user
# - POSTGRES_PASSWORD=password
# volumes:
# - postgres-data:/var/lib/postgresql/data
postgres:
image: postgres:16
environment:
- POSTGRES_DB=metadata_db
- POSTGRES_USER=postgres
- POSTGRES_PASSWORD=postgres
ports:
- '5433:5432'
volumes:
- postgres-data:/var/lib/postgresql/data
# app:
# build: .
# environment:
@@ -44,5 +46,5 @@ services:
# - postgres
# command: pe_analyzer 1000 # Process 1000 files (500 clean, 500 malware)

#volumes:
# postgres-data:
volumes:
postgres-data:
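
With the postgres service enabled and published on host port 5433, the matching connection string for running the analyzer from the host would presumably be (credentials as set in the compose file above):

# Assumed local DSN; host port 5433 maps to the container's 5432.
DATABASE_URL = "postgresql://postgres:postgres@localhost:5433/metadata_db"
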
4 changes: 2 additions & 2 deletions pe_analyzer/__main__.py
@@ -6,7 +6,7 @@
from pe_analyzer.metadata_processor import MetadataProcessor
from pe_analyzer.settings import settings

logging.basicConfig(level=settings.logging_level, format="%(asctime)s - %(name)s - %(levelname)s - %(message)s")
logging.basicConfig(level=settings.logging_level, format=settings.logging_format)
logger = logging.getLogger(__name__)

hostname = socket.gethostname()
@@ -21,7 +21,7 @@ def main():
spark = (
SparkSession.builder.appName("PE File Analyzer")
# .master("local[*]")
.master("spark://localhost:7077")
# .master("spark://localhost:7077")
.getOrCreate()
)

30 changes: 26 additions & 4 deletions pe_analyzer/db/db.py
@@ -1,21 +1,43 @@
import logging
from typing import TypeAlias

from pyspark.sql import Row
from sqlalchemy import create_engine, select
from sqlalchemy.orm import sessionmaker
from sqlalchemy.dialects.postgresql import insert as pg_insert

from pe_analyzer.db.models import FileMetadata

logger = logging.getLogger(__name__)

FilePath: TypeAlias = str


class Database:
def __init__(self, db_url):
self.engine = create_engine(db_url)
self.SessionLocal = sessionmaker(autocommit=False, autoflush=False, bind=self.engine)

def get_processed_files(self):
def get_not_processed_files(self, file_paths: list[FilePath]) -> list:
with self.SessionLocal() as session:
return [row.path for row in session.execute(select(FileMetadata)).scalars()]
# Get all the processed files within the list of file_paths
stmt = select(FileMetadata.path).where(FileMetadata.path.in_(file_paths))
processed_files = session.scalars(stmt).all()
logger.info(f"Found {len(processed_files)} processed files.")
# Exclude the ones that are already processed
return list(set(file_paths).difference(set(processed_files)))

def save_metadata(self, metadata_list: list[Row]):
with self.SessionLocal() as session:
metadata_dicts = []
for metadata in metadata_list:
file_metadata = FileMetadata(**metadata.asDict())
session.merge(file_metadata)
metadata_dict = metadata.asDict()
if metadata_dict["error"]:
# TODO: output errors somewhere else
logger.warning(f"Error processing {metadata_dict['path']}: {metadata_dict['error']}")
del metadata_dict["error"]
metadata_dicts.append(metadata_dict)

stmt = pg_insert(FileMetadata).values(metadata_dicts)
session.execute(stmt)
session.commit()
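
A minimal usage sketch of the reworked Database API, assuming the compose setup above (host port 5433) and hypothetical S3 keys; save_metadata expects pyspark Rows that carry the new error field:

from pyspark.sql import Row

from pe_analyzer.db.db import Database

db = Database(db_url="postgresql://postgres:postgres@localhost:5433/metadata_db")

# Drop keys that already have a row in file_metadata.
todo = db.get_not_processed_files(["0/a.exe", "0/b.dll", "1/c.exe"])

# Rows follow the UDF schema; the error field is logged and stripped before the bulk INSERT.
db.save_metadata([
    Row(path="0/a.exe", size=2048, file_type="exe", architecture="x64",
        num_imports=12, num_exports=0, error=None),
])
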
31 changes: 31 additions & 0 deletions pe_analyzer/db/migrations/versions/a17cab59bdaa_index_on_path.py
@@ -0,0 +1,31 @@
"""index on path
Revision ID: a17cab59bdaa
Revises: 4b50f5396ee1
Create Date: 2024-07-27 19:52:52.207756
"""
from typing import Sequence, Union

from alembic import op


# revision identifiers, used by Alembic.
revision: str = "a17cab59bdaa"
down_revision: Union[str, None] = "4b50f5396ee1"
branch_labels: Union[str, Sequence[str], None] = None
depends_on: Union[str, Sequence[str], None] = None


def upgrade() -> None:
# ### commands auto generated by Alembic - please adjust! ###
op.drop_constraint("file_metadata_path_key", "file_metadata", type_="unique")
op.create_index(op.f("ix_file_metadata_path"), "file_metadata", ["path"], unique=True)
# ### end Alembic commands ###


def downgrade() -> None:
# ### commands auto generated by Alembic - please adjust! ###
op.drop_index(op.f("ix_file_metadata_path"), table_name="file_metadata")
op.create_unique_constraint("file_metadata_path_key", "file_metadata", ["path"])
# ### end Alembic commands ###
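
The revision is applied with the usual alembic upgrade head; programmatically, a sketch assuming the project's alembic.ini sits at the repository root and points at metadata_db:

from alembic import command
from alembic.config import Config

command.upgrade(Config("alembic.ini"), "head")
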
4 changes: 3 additions & 1 deletion pe_analyzer/db/models.py
@@ -9,10 +9,12 @@ class Base(DeclarativeBase):
class FileMetadata(Base):
__tablename__ = "file_metadata"

# In case of horizontal scaling, we could use UUIDs instead of integers
id: Mapped[int] = mapped_column(primary_key=True)
path: Mapped[str] = mapped_column(String, unique=True, nullable=False)
path: Mapped[str] = mapped_column(String, unique=True, nullable=False, index=True)
size: Mapped[int] = mapped_column(Integer, nullable=False)
file_type: Mapped[str] = mapped_column(String, nullable=True)
# TODO: use enum for x32, x64
architecture: Mapped[str] = mapped_column(String, nullable=True)
num_imports: Mapped[int] = mapped_column(Integer, nullable=True)
num_exports: Mapped[int] = mapped_column(Integer, nullable=True)
81 changes: 55 additions & 26 deletions pe_analyzer/metadata_processor.py
@@ -2,6 +2,7 @@
import time

import boto3
import botocore.exceptions
import pefile
from pyspark.sql import SparkSession
from pyspark.sql.functions import udf, struct
@@ -15,27 +16,34 @@


class MetadataProcessor:
def __init__(self, spark: SparkSession, n: int, db_url: str):
def __init__(
self,
spark: SparkSession,
n: int,
db_url: str,
batch_size: int = 100,
partition_size: int = 10,
):
self.spark = spark
self.n = n
self.s3_handler = S3Handler()
self.database = Database(db_url=db_url)
self.batch_size = batch_size
self.partition_size = partition_size

def process(self):
process_start = time.time()
logger.info("Processing new files...")
# Problem: in the current implementation we get all the data from the database.
# If the number of processed files is very large, this can be a bottleneck.
processed_files = set(self.database.get_processed_files())
logger.info(f"Found {len(processed_files)} processed files.")

clean_files = self.s3_handler.list_files("0/", self.n // 2)
malware_files = self.s3_handler.list_files("1/", self.n // 2)
logger.info(f"Found {len(clean_files)} clean files and {len(malware_files)} malware files.")
clean_files = self.s3_handler.list_files("0/")
malware_files = self.s3_handler.list_files("1/")

clean_files = self.database.get_not_processed_files(clean_files)[: self.n // 2]
malware_files = self.database.get_not_processed_files(malware_files)[: self.n // 2]

all_files = [f for f in clean_files + malware_files if f not in processed_files]
logger.info(f"Found {len(clean_files)} clean files and {len(malware_files)} malware files.")

# TODO: If all_files < N, do we need to get more files and process them?
all_files = clean_files + malware_files

if not all_files:
logger.info("No new files to process.")
@@ -51,20 +59,23 @@ def process(self):
StructField("architecture", StringType(), True),
StructField("num_imports", IntegerType(), True),
StructField("num_exports", IntegerType(), True),
StructField("error", StringType(), True),
]
)

def analyze_pe_file(s3_path: str, s3_region: str, s3_bucket: str) -> tuple:
file_type = s3_path.split(".")[-1].lower() if "." in s3_path else None

s3 = boto3.client("s3", region_name=s3_region)
s3 = boto3.client("s3", region_name=s3_region, aws_access_key_id="", aws_secret_access_key="")

s3._request_signer.sign = lambda *args, **kwargs: None

try:
obj = s3.get_object(Bucket=s3_bucket, Key=s3_path)
file_content = obj["Body"].read()
file_size = len(file_content)

pe = pefile.PE(data=file_content, fast_load=False)
pe = pefile.PE(data=file_content)
arch = "x32" if pe.FILE_HEADER.Machine == pefile.MACHINE_TYPE["IMAGE_FILE_MACHINE_I386"] else "x64"

import_count = (
@@ -74,9 +85,12 @@ def analyze_pe_file(s3_path: str, s3_region: str, s3_bucket: str) -> tuple:
)
export_count = len(pe.DIRECTORY_ENTRY_EXPORT.symbols) if hasattr(pe, "DIRECTORY_ENTRY_EXPORT") else 0

return s3_path, file_size, file_type, arch, import_count, export_count
except Exception:
return s3_path, None, file_type, None, None, None
return s3_path, file_size, file_type, arch, import_count, export_count, None
except botocore.exceptions.EndpointConnectionError as err:
# TODO: if the error is due to network issues, we need to retry
return s3_path, None, file_type, None, None, None, str(err)
except Exception as err:
return s3_path, None, file_type, None, None, None, str(err)
finally:
if "pe" in locals():
pe.close()
@@ -90,18 +104,33 @@ def analyze_file_udf(row):
[(f, settings.s3_region, settings.s3_bucket) for f in all_files], ["path", "s3_region", "s3_bucket"]
)

logger.info(f"Processing {df.count()} files...")

start = time.time()
df = df.select(analyze_udf(struct("path", "s3_region", "s3_bucket")).alias("metadata")).select("metadata.*")
logger.info(f"Processing time: {time.time() - start:.2f} seconds.")

start = time.time()
metadata_list = df.collect()
logger.info(f"Collect time: {time.time() - start:.2f} seconds.")

self.database.save_metadata(metadata_list)

logger.info(f"Processed {df.count()} new files.")
df = df.select(analyze_udf(struct("path", "s3_region", "s3_bucket")).alias("metadata")).select("metadata.*")

def process_partition(iterator):
from sqlalchemy import create_engine, Table, MetaData
from sqlalchemy.dialects.postgresql import insert

engine = create_engine("postgresql://postgres:postgres@localhost:5432/metadata_db")
file_metadata_table = Table("file_metadata", MetaData(), autoload_with=engine)

batch = []
with engine.connect() as connection:
for row in iterator:
row_dict = row.asDict()
if row_dict["error"]:
print(f"Error processing {row_dict['path']}: {row_dict['error']}")
del row_dict["error"]
batch.append(row_dict)
if len(batch) >= 100:
connection.execute(insert(file_metadata_table).values(batch))
batch = []
if batch:
connection.execute(insert(file_metadata_table).values(batch))
connection.commit()

df.foreachPartition(process_partition)

logger.info(f"Processed {df.count()} new files in {time.time() - start:.2f} seconds.")
logger.info(f"Total processing time: {time.time() - process_start:.2f} seconds.")
7 changes: 2 additions & 5 deletions pe_analyzer/s3_handler.py
@@ -7,9 +7,6 @@ def __init__(self, region_name: str = settings.s3_region, bucket_name: str = set
self.s3_client = boto3.client("s3", region_name=region_name)
self.bucket_name = bucket_name

def list_files(self, prefix: str, limit: int) -> list[str]:
def list_files(self, prefix: str) -> list[str]:
response = self.s3_client.list_objects_v2(Bucket=self.bucket_name, Prefix=prefix)
return [obj["Key"] for obj in response.get("Contents", [])[:limit]]

def download_file(self, file_path: str, local_path: str):
self.s3_client.download_file(self.bucket_name, file_path, local_path)
return [obj["Key"] for obj in response.get("Contents", [])]
17 changes: 17 additions & 0 deletions pe_analyzer/settings.py
@@ -20,6 +20,22 @@ class DatabaseSettings(BaseSettings):
auto_flush: bool = False
expire_on_commit: bool = False

def get_pyspark_driver(self):
return {
"postgresql": "org.postgresql.Driver",
# Add more protocols here
}[self.db_protocol]

def get_pyspark_properties(self):
return {
"user": self.db_user,
"password": self.db_password,
"driver": self.get_pyspark_driver(),
}

def get_pyspark_db_url(self):
return f"jdbc:{self.db_protocol}://{self.db_host}:{self.db_port}/{self.db_name}"

class Config:
extra = "allow"

@@ -30,6 +46,7 @@ class AppSettings(BaseSettings):
base_dir: DirectoryPath = os.path.dirname(os.path.abspath(__file__))

logging_level: str = "INFO"
logging_format: str = "%(asctime)s - %(name)s - %(levelname)s - %(message)s"

database: DatabaseSettings = DatabaseSettings(_env_file="db.env", _env_file_encoding="utf-8")

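
Nothing in this commit calls the new get_pyspark_* helpers yet; they presumably target Spark's JDBC writer. A usage sketch under that assumption, where df is the metadata DataFrame, settings is the AppSettings instance, and the PostgreSQL JDBC driver jar is on the Spark classpath:

db = settings.database

# Append the collected metadata through Spark's JDBC sink instead of
# hand-rolled SQLAlchemy inserts.
df.write.jdbc(
    url=db.get_pyspark_db_url(),               # jdbc:postgresql://host:port/metadata_db
    table="file_metadata",
    mode="append",
    properties=db.get_pyspark_properties(),    # user, password, org.postgresql.Driver
)
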
2 changes: 1 addition & 1 deletion spark/Dockerfile
@@ -1,5 +1,5 @@
FROM bitnami/spark:3.5.1

USER root
RUN pip install boto3 pefile
RUN pip install boto3 pefile sqlalchemy psycopg2-binary
USER 1001
