Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add sqlite support #12

Open
wants to merge 6 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
337 changes: 255 additions & 82 deletions source/repomd.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,13 @@
import datetime
import bz2
import gzip
import io
import defusedxml.lxml
import pathlib
import urllib.request
import urllib.parse
import sqlite3
import tempfile


_ns = {
Expand All @@ -14,70 +17,194 @@
}


def load(baseurl):
# parse baseurl to allow manipulating the path
base = urllib.parse.urlparse(baseurl)
path = pathlib.PurePosixPath(base.path)
def get_repomd_obj(baseurl):
    """Return a ``RepoMD`` helper for the repository rooted at *baseurl*.

    Parameters:
    baseurl - base URL of the repository (the directory holding ``repodata/``)
    """
    return RepoMD(baseurl)

# first we must get the repomd.xml file
repomd_path = path / 'repodata' / 'repomd.xml'
repomd_url = base._replace(path=str(repomd_path)).geturl()

# download and parse repomd.xml
with urllib.request.urlopen(repomd_url) as response:
repomd_xml = defusedxml.lxml.fromstring(response.read())
class RepoMD():
    """Accessor for a repository's ``repodata/repomd.xml`` index.

    Attributes:
    base - the parsed base URL, as returned by ``urllib.parse.urlparse``
    path - the path component of the base URL as a ``PurePosixPath``
    repomd_path - ``path`` extended with ``repodata/repomd.xml``
    repomd_url - full URL of ``repomd.xml``
    repomd_xml - the parsed ``repomd.xml`` document
    """

    def __init__(self, baseurl):
        # parse baseurl to allow manipulating the path
        self.base = urllib.parse.urlparse(baseurl)
        self.path = pathlib.PurePosixPath(self.base.path)

        # first we must get the repomd.xml file
        self.repomd_path = self.path / 'repodata' / 'repomd.xml'
        self.repomd_url = self.base._replace(path=str(self.repomd_path)).geturl()
        with urllib.request.urlopen(self.repomd_url) as response:
            self.repomd_xml = defusedxml.lxml.fromstring(response.read())

    def get_repo_file_url(self, href_name):
        """Return the URL of the metadata file registered under a data type.

        Parameters:
        href_name - value of the ``type`` attribute of the wanted
                    ``data`` node, e.g. 'primary' or 'primary_db'

        Returns:
        (str) - the base URL with its path replaced by the node's ``href``

        Common Exceptions:
        Will raise AttributeError when no node of that type is found
        """
        find_query = 'repo:data[@type="{}"]/repo:location'.format(href_name)
        location = self.repomd_xml.find(find_query, namespaces=_ns)
        if location is None:
            # Callers use AttributeError to detect a missing data type, so
            # keep that exception type but say what was actually missing.
            raise AttributeError(
                'no "{}" data entry found in {}'.format(href_name, self.repomd_url)
            )
        file_path = self.path / location.get('href')
        return self.base._replace(path=str(file_path)).geturl()

    def get_repo_file_contents(self, data_type):
        """Get the repometadata for the type in question.

        Parameters:
        data_type - The XML Node to look for, usually 'primary' or 'primary_db'

        Returns:
        (bytes) - An uncompressed bytes object from the URL referenced
                  in the found xml node.

        Common Exceptions:
        Will raise AttributeError when the DataType is not found
        """
        file_url = self.get_repo_file_url(data_type)
        with urllib.request.urlopen(file_url) as response:
            payload = response.read()
        return self._decompress(file_url, payload)

    @staticmethod
    def _decompress(url, payload):
        """Decompress *payload* according to the file suffix found in *url*.

        Only the path component of the URL is inspected, so a query string
        does not hide the suffix. ``.gz``, ``.bz2`` and ``.xz`` are
        decompressed; any other suffix is returned unchanged.
        """
        suffix = pathlib.PurePosixPath(urllib.parse.urlparse(url).path).suffix
        if suffix == '.gz':
            return gzip.decompress(payload)
        if suffix == '.bz2':
            return bz2.decompress(payload)
        if suffix == '.xz':
            # imported here because the module does not import lzma itself
            import lzma
            return lzma.decompress(payload)
        return payload

def __str__(self):
return self.baseurl

def __len__(self):
return int(self._metadata.get('packages'))
def load(baseurl):
    """Load the repository at *baseurl*, preferring its SQLite metadata.

    The ``primary_db`` (SQLite) metadata is tried first. When it cannot be
    obtained, the ``primary`` (XML) metadata is used instead.

    Parameters:
    baseurl - base URL of the repository

    Returns:
    A ``SQLiteRepo`` when primary_db metadata is available, otherwise an
    ``XmlRepo``.
    """
    repomd_obj = get_repomd_obj(baseurl)

    # Only the lookup is guarded: AttributeError is how a missing data type
    # is reported, and errors from building the repo object must not be
    # mistaken for "no primary_db available".
    try:
        db_contents = repomd_obj.get_repo_file_contents('primary_db')
    except AttributeError:
        db_contents = None

    # Compare against None explicitly. The repo object defines __len__, so
    # a truthiness test would treat an empty repository as missing.
    if db_contents is not None:
        return SQLiteRepo(baseurl, db_contents)

    xml_contents = repomd_obj.get_repo_file_contents('primary')
    return XmlRepo(baseurl, xml_contents)

def __iter__(self):
for element in self._metadata:
yield Package(element)

def find(self, name):
results = self._metadata.findall(f'common:package[common:name="{name}"]', namespaces=_ns)
if results:
return Package(results[-1])
class BasePackage:
    """Shared naming, comparison and hashing behaviour for package objects.

    Subclasses are expected to provide ``name``, ``epoch``, ``version``,
    ``release`` and ``arch`` attributes, plus a ``_version_info`` object
    whose ``get`` method answers the ``'epoch'``, ``'ver'`` and ``'rel'``
    keys.
    """

    @property
    def vr(self):
        """``version-release`` string."""
        version_info = self._version_info
        v = version_info.get('ver')
        r = version_info.get('rel')
        return f'{v}-{r}'

    @property
    def nvr(self):
        """``name-version-release`` string."""
        return f'{self.name}-{self.vr}'

    @property
    def evr(self):
        """``epoch:version-release`` string; the epoch is omitted when zero."""
        version_info = self._version_info
        e = version_info.get('epoch')
        v = version_info.get('ver')
        r = version_info.get('rel')
        # A missing or empty epoch is treated like epoch 0 rather than
        # letting int() fail on it.
        if e and int(e):
            return f'{e}:{v}-{r}'
        return f'{v}-{r}'

    @property
    def nevr(self):
        """``name-[epoch:]version-release`` string."""
        return f'{self.name}-{self.evr}'

    @property
    def nevra(self):
        """``name-[epoch:]version-release.arch`` string."""
        return f'{self.nevr}.{self.arch}'

    @property
    def _nevra_tuple(self):
        # identity key used for both equality and hashing
        return self.name, self.epoch, self.version, self.release, self.arch

    def __eq__(self, other):
        # Defer to the other operand for foreign types so that comparing
        # with e.g. None or a string yields False instead of raising.
        if not isinstance(other, BasePackage):
            return NotImplemented
        return self._nevra_tuple == other._nevra_tuple

    def __hash__(self):
        return hash(self._nevra_tuple)

    def __repr__(self):
        return f'<{self.__class__.__name__}: "{self.nevra}">'


class SQLitePackage(BasePackage):
    """A Class for inspecting packages in an SQLite-based repo.

    Most properties are generated based on primary.packages column headers and dynamically
    generated from the pkg_row constructor parameter.

    The properties that are auto generated from a CentOS 7-built primary sqlite DB are:

    pkgId TEXT
    name TEXT
    arch TEXT
    version TEXT
    epoch TEXT
    release TEXT
    summary TEXT
    description TEXT
    url TEXT
    time_file INTEGER
    time_build INTEGER
    license TEXT
    vendor TEXT
    group TEXT
    buildhost TEXT
    sourcerpm TEXT
    header_start INTEGER
    header_end INTEGER
    packager TEXT
    size_package INTEGER
    size_installed INTEGER
    size_archive INTEGER
    location_href TEXT
    location_base TEXT
    checksum_type TEXT

    These could vary between different yum repo versions but should be consistent with most major
    Red Hat-derived distros in 2019.

    There are some other properties as well that are not dynamically generated.

    Properties:
    location (str) - alias to location_href
    ver (str) - alias to version
    rel (str) - alias to release
    shasum (str) - alias to pkgId
    build_time (datetime) - datetime object from the time_build column

    Parameters:
    pkg_row - an sqlite3.Row object representing one row of the packages table.

    """

    def __init__(self, pkg_row):
        self.pkg_row = pkg_row
        # Copy every column of pkg_row into an attribute so the common rpm
        # headers are available directly on the object. A leading 'rpm_' in
        # a column name is dropped; only the prefix is removed, so an
        # 'rpm_' occurring later in the name is left intact.
        prefix = 'rpm_'
        for column in pkg_row.keys():
            if column.startswith(prefix):
                attr_name = column[len(prefix):]
            else:
                attr_name = column
            setattr(self, attr_name, pkg_row[column])

        # aliases matching the names used by the XML-backed package class
        self.location = self.location_href
        self.ver = self.version
        self.rel = self.release
        self.shasum = self.pkgId
        # dict stand-in for the XML <version> element read by BasePackage
        self._version_info = {
            'epoch': self.epoch,
            'ver': self.ver,
            'rel': self.rel,
        }

    @property
    def build_time(self):
        """The ``time_build`` column converted with ``datetime.fromtimestamp``."""
        return datetime.datetime.fromtimestamp(int(self.time_build))

class XmlPackage(BasePackage):
"""An RPM package from a repository."""

__slots__ = ['_element']
Expand All @@ -89,6 +216,14 @@ def __init__(self, element):
def name(self):
return self._element.findtext('common:name', namespaces=_ns)

@property
def pkgId(self):
return self._element.findtext('common:checksum', namespaces=_ns)

@property
def checksum_type(self):
return self._element.find('common:checksum', namespaces=_ns).get('type')

@property
def arch(self):
return self._element.findtext('common:arch', namespaces=_ns)
Expand Down Expand Up @@ -134,10 +269,6 @@ def location(self):
def _version_info(self):
return self._element.find('common:version', namespaces=_ns)

@property
def epoch(self):
return self._version_info.get('epoch')

@property
def version(self):
return self._version_info.get('ver')
Expand All @@ -147,44 +278,86 @@ def release(self):
return self._version_info.get('rel')

@property
def vr(self):
version_info = self._version_info
v = version_info.get('ver')
r = version_info.get('rel')
return f'{v}-{r}'
def epoch(self):
return self._version_info.get('epoch')

@property
def nvr(self):
return f'{self.name}-{self.vr}'

@property
def evr(self):
version_info = self._version_info
e = version_info.get('epoch')
v = version_info.get('ver')
r = version_info.get('rel')
if int(e):
return f'{e}:{v}-{r}'
else:
return f'{v}-{r}'
class BaseRepo():
    """Common base class for the repository implementations in this module."""

    # An empty __slots__ stops this base from adding an instance __dict__.
    # Without it, a subclass that declares its own __slots__ would still get
    # a __dict__ through inheritance and its declaration would have no effect.
    # Subclasses that declare no __slots__ still get a __dict__ as usual.
    __slots__ = ()

@property
def nevr(self):
return f'{self.name}-{self.evr}'

@property
def nevra(self):
return f'{self.nevr}.{self.arch}'
class XmlRepo(BaseRepo):
    """A dnf/yum repository backed by XML.

    Parameters:
    baseurl - base URL of the repository
    metadata - the uncompressed primary XML document, which is parsed on
               construction
    """

    __slots__ = ['baseurl', '_metadata']

    def __init__(self, baseurl, metadata):
        self.baseurl = baseurl
        self._metadata = defusedxml.lxml.fromstring(metadata)

    def __repr__(self):
        return f'<{self.__class__.__name__}: "{self.baseurl}">'

    def __str__(self):
        return self.baseurl

    def __len__(self):
        # the package count is read from the root element's 'packages' attribute
        return int(self._metadata.get('packages'))

    def __iter__(self):
        for element in self._metadata:
            yield XmlPackage(element)

    def _matching_elements(self, name):
        """Return the package elements whose ``common:name`` equals *name*."""
        # NOTE(review): name is interpolated into the path expression
        # unescaped, so a name containing a double quote will not match as
        # intended -- confirm whether such names can reach this method.
        return self._metadata.findall(f'common:package[common:name="{name}"]', namespaces=_ns)

    def find(self, name):
        """Return the last package named *name* in the metadata, or None."""
        results = self._matching_elements(name)
        if results:
            return XmlPackage(results[-1])
        return None

    def findall(self, name):
        """Return every package named *name* as a list (possibly empty)."""
        return [XmlPackage(element) for element in self._matching_elements(name)]


class SQLiteRepo(BaseRepo):
    """A yum/dnf repository backed by SQLite.

    Parameters:
    baseurl - base URL of the repository
    metadata - (bytes) the uncompressed primary SQLite database

    Attributes:
    db_file - the temporary file holding the database
    conn - the sqlite3 connection opened on that file
    """

    def __init__(self, baseurl, metadata):
        self.baseurl = baseurl
        # sqlite3 opens databases by path, so the bytes are spooled to a
        # named temporary file that stays open as long as this object
        # holds it.
        self.db_file = tempfile.NamedTemporaryFile()
        self.db_file.write(metadata)
        # Push buffered bytes to disk before sqlite opens the file by name;
        # otherwise it may see an empty or truncated database.
        self.db_file.flush()
        self.conn = sqlite3.connect(self.db_file.name)
        # rows behave like mappings so the package class can read columns by name
        self.conn.row_factory = sqlite3.Row

    def close(self):
        """Close the database connection and the temporary file."""
        self.conn.close()
        self.db_file.close()

    def __repr__(self):
        return f'<{self.__class__.__name__}: "{self.baseurl}">'

    def __str__(self):
        return self.baseurl

    def __len__(self):
        c = self.conn.cursor()
        row = c.execute('SELECT COUNT(ALL) FROM packages')
        return row.fetchone()[0]

    def __iter__(self):
        c = self.conn.cursor()
        c.execute('SELECT * FROM packages')
        for pkgrow in c:
            yield SQLitePackage(pkgrow)

    def findall(self, pkgname):
        """Return every package named *pkgname* as a list (possibly empty)."""
        c = self.conn.cursor()
        c.execute("SELECT * FROM packages WHERE name = ?", [pkgname])
        return [SQLitePackage(p) for p in c.fetchall()]

    def find(self, pkgname):
        """Return one package named *pkgname*, or None when there is none."""
        c = self.conn.cursor()
        # NOTE(review): ORDER BY time_build without DESC selects the row
        # with the smallest build time -- confirm whether the most recent
        # build was intended instead.
        c.execute("SELECT * FROM packages WHERE name = ? ORDER BY time_build LIMIT 1", [pkgname])
        res = c.fetchone()
        if res:
            return SQLitePackage(res)
        return None
Loading