Skip to content

Commit

Permalink
Propagate primary key type through openapi
Browse files Browse the repository at this point in the history
  • Loading branch information
nside committed Jul 28, 2023
1 parent 62e9691 commit 816fe5c
Show file tree
Hide file tree
Showing 3 changed files with 93 additions and 28 deletions.
10 changes: 5 additions & 5 deletions sqlite2rest/database.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,8 +14,8 @@ def get_primary_key(self, table_name):
columns = self.cursor.fetchall()
for column in columns:
if column[5]: # The 6th item in the tuple is 1 if the column is the primary key, 0 otherwise
return column[1] # The 2nd item in the tuple is the column name
return None
return column[1], column[2] # The 2nd item in the tuple is the column name, the 3rd item is the column type
return None, None

def get_records(self, table_name, page, per_page):
offset = (page - 1) * per_page
Expand All @@ -25,7 +25,7 @@ def get_records(self, table_name, page, per_page):
return records

def get_record(self, table_name, key):
primary_key = self.get_primary_key(table_name)
primary_key, _ = self.get_primary_key(table_name)
self.cursor.execute(f"SELECT * FROM {table_name} WHERE {primary_key} = ?;", (key,))
row = self.cursor.fetchone()
if row is None:
Expand All @@ -41,13 +41,13 @@ def create_record(self, table_name, data):
self.conn.commit()

def update_record(self, table_name, key, data):
primary_key = self.get_primary_key(table_name)
primary_key, _ = self.get_primary_key(table_name)
set_clause = ', '.join(f"{column} = ?" for column in data.keys())
self.cursor.execute(f"UPDATE {table_name} SET {set_clause} WHERE {primary_key} = ?;", tuple(data.values()) + (key,))
self.conn.commit()

def delete_record(self, table_name, key):
primary_key = self.get_primary_key(table_name)
primary_key, _ = self.get_primary_key(table_name)
self.cursor.execute(f"DELETE FROM {table_name} WHERE {primary_key} = ?;", (key,))
self.conn.commit()

Expand Down
109 changes: 87 additions & 22 deletions sqlite2rest/openapi.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,89 @@
from openapi_spec_validator import validate_spec
import yaml

def generate_openapi_spec():
def get_operation_summary(method):
return {
'GET': 'Retrieve all records from',
'POST': 'Create a new record in',
'PUT': 'Update a record in',
'DELETE': 'Delete a record from',
'PATCH': 'Partially update a record in',
'TRACE': 'Trace a request to'
}.get(method, 'Perform operation on')

def add_paging_parameters(operation_obj):
operation_obj["parameters"] = [
{
"name": "page",
"in": "query",
"description": "Page number to retrieve",
"required": False,
"schema": {
"type": "integer",
"default": 1
}
},
{
"name": "per_page",
"in": "query",
"description": "Number of records per page",
"required": False,
"schema": {
"type": "integer",
"default": 10
}
}
]

def add_operation_to_path(path_item, method, rule_str, primary_key_type):
operation = get_operation_summary(method)
table_name = rule_str.split('/')[1]
operation_obj = {
"summary": f"{operation} the {table_name} table",
"responses": {
"200": {
"description": "OK"
}
}
}
if method == 'GET':
if '<id>' in rule_str:
operation_obj["parameters"] = [
{
"name": "id",
"in": "path",
"description": "The ID of the record to retrieve",
"required": True,
"schema": {
"type": primary_key_type,
}
}
]
else:
add_paging_parameters(operation_obj)
path_item[method.lower()] = operation_obj

def sqlite_type_to_openapi_type(sqlite_type):
"""
Convert SQLite data types to OpenAPI data types.
"""
sqlite_type = sqlite_type.upper()
if sqlite_type in ["INT", "INTEGER", "TINYINT", "SMALLINT", "MEDIUMINT", "BIGINT", "UNSIGNED BIG INT", "INT2", "INT8"]:
return "integer"
elif sqlite_type in ["REAL", "DOUBLE", "DOUBLE PRECISION", "FLOAT"]:
return "number"
elif sqlite_type in ["TEXT", "CHARACTER", "VARCHAR", "VARYING CHARACTER", "NCHAR", "NATIVE CHARACTER", "NVARCHAR", "CLOB"]:
return "string"
elif sqlite_type in ["BLOB"]:
return "string", "byte"
elif sqlite_type in ["BOOLEAN"]:
return "boolean"
elif sqlite_type in ["DATE", "DATETIME"]:
return "string", "date-time"
else:
return "string"

def generate_openapi_spec(db):
# Basic OpenAPI spec
spec = {
"openapi": "3.0.0",
Expand All @@ -28,33 +110,16 @@ def generate_openapi_spec():
# Add an operation object for each method
for method in rule.methods:
if method in ['GET', 'POST', 'PUT', 'DELETE']:
operation = {
'GET': 'Retrieve all records from',
'POST': 'Create a new record in',
'PUT': 'Update a record in',
'DELETE': 'Delete a record from',
'PATCH': 'Partially update a record in',
'TRACE': 'Trace a request to'
}.get(method, 'Perform operation on')

table_name = str(rule).split('/')[1]

path_item[method.lower()] = {
"summary": f"{operation} the {table_name} table",
"responses": {
"200": {
"description": "OK"
}
}
}
_, primary_key_type = db.get_primary_key(table_name)
add_operation_to_path(path_item, method, str(rule), sqlite_type_to_openapi_type(primary_key_type))

# Validate the spec
validate_spec(spec)

# Return the spec as a dictionary
return spec

def get_openapi_spec():
spec = generate_openapi_spec()
def get_openapi_spec(db):
spec = generate_openapi_spec(db)
return yaml.dump(spec)

2 changes: 1 addition & 1 deletion sqlite2rest/routes.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,6 @@ def delete_record(id):
@app.route('/openapi.yaml', methods=['GET'])
def openapi():
app.logger.info('Getting OpenAPI specification')
spec = get_openapi_spec()
spec = get_openapi_spec(get_database())
return spec, 200, {'Content-Type': 'text/vnd.yaml'}

0 comments on commit 816fe5c

Please sign in to comment.