Skip to content

Commit

Permalink
[IMP] *: Refactoring
Browse files Browse the repository at this point in the history
Improve classes inheriteance & cleanup code

Signed-off-by: Julien Alardot (jual) <[email protected]>
  • Loading branch information
JulienAlardot committed Mar 9, 2024
1 parent 1c77649 commit a1d8dfe
Show file tree
Hide file tree
Showing 11 changed files with 371 additions and 256 deletions.
8 changes: 5 additions & 3 deletions core/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,11 +32,13 @@ def _initialize_nodes(self, size_input, size_output):
def _initialize_population(self, pop_size):
node_ids = set(self._input_node_ids + self._output_node_ids)
bias_node_id = self.get_node(NodeTypes.bias).id
connections = ({"in_node_id": bias_node_id, "out_node_id": out_id, "weight": ((random.random() * 2) - 1)}
for out_id in self._output_node_ids)
connections = (
{"in_node_id": bias_node_id, "out_node_id": out_id, "weight": ((random.random() * 2) - 1)}
for out_id in self._output_node_ids
)
connections = tuple(connections)
individual_dict = {'genotype_kwargs': {"node_ids": node_ids, "connection_dicts": connections}}
individual_dicts = (individual_dict for i in range(pop_size))
individual_dicts = (individual_dict for _ in range(pop_size))

self._population = self.get_population(
generation_id=self._generation.id, individual_dicts=tuple(
Expand Down
53 changes: 53 additions & 0 deletions core/orm/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,53 @@
from typing import Dict

from core.orm.database import Database


class AbstractModelElement:
_table: str = None
_columns: Dict[str, type] = {'id': int}

def __init__(self, db):
"""
:param core.orm.database.Database db:
"""
self._db: Database = db

def _cast_data_to_types(self, *args):
if not args:
return args
cleaned_args = []
for key, value in zip(self._columns, args):
cleaned_args.append(self._columns.get(key, str)(value))
return tuple(cleaned_args)

def _search_from_id(self, element_id: int, class_table=None):
class_table = class_table or self.__class__
res = self._db.execute(
f"""
SELECT {', '.join(class_table._columns)}
FROM {class_table._table}
WHERE id = {element_id}
ORDER BY id DESC
LIMIT 1
"""
)
if not res:
raise ValueError(f"No {__class__} exists with that id")
return self._cast_data_to_types(*res)

def _search_from_data(self, class_table=None, **kwargs):
class_table = class_table or self.__class__
where_clause = ' AND '.join((f"{key} = {value}" for key, value in kwargs.items()))

res = self._db.execute(
f"""
SELECT {', '.join(class_table._columns)}
FROM {class_table._table}
WHERE {where_clause}
ORDER BY id DESC
LIMIT 1
"""
)
return self._cast_data_to_types(*res)
139 changes: 80 additions & 59 deletions core/orm/connections.py
Original file line number Diff line number Diff line change
@@ -1,79 +1,97 @@
class HistoricalConnection:
from typing import Dict

from core.orm import AbstractModelElement


class HistoricalConnection(AbstractModelElement):
_table: str = 'connection_historical'
_columns: Dict[str, type] = {'id': int, 'in_node_id': int, 'out_node_id': int}
id: int
in_node: int
out_node: int

def __init__(self, db, historical_connection_id=None, in_node_id=None, out_node_id=None):
self._db = db
if not (historical_connection_id or (in_node_id and out_node_id)):
raise ValueError("Must be given either an existing historical_connection_id or in and out_node_id")

super().__init__(db)
res = None, None, None
if historical_connection_id:
res = self._db.execute(
f"""
SELECT id, in_node_id, out_node_id
FROM connection_historical
WHERE id = {historical_connection_id}
LIMIT 1
""")
res = self._search_from_id(
historical_connection_id,
class_table=HistoricalConnection,
)
if not res:
raise ValueError('No HistoricalConnection exists with that id')
self.id, self.in_node, self.out_node = res[0]
elif in_node_id and out_node_id:
if in_node_id == out_node_id:
raise ValueError("in_node_id and out_node_id must be different nodes")
res = self._search_from_data(
class_table=HistoricalConnection,
in_node_id=in_node_id,
out_node_id=out_node_id,
)
if not res:
res = self._create_historical_connection(in_node_id, out_node_id)
if not res:
raise ValueError("Incorrect values for in_node_id and/or out_node_id parameters")
self.id, self.in_node, self.out_node = res

res = self._db.execute(
f"""
SELECT id, in_node_id, out_node_id
FROM connection_historical
WHERE in_node_id = {in_node_id} AND out_node_id = {out_node_id}
LIMIT 1
""")
if res:
self.id, self.in_node, self.out_node = res[0]
else:
self._db.execute(
f"""
def _create_historical_connection(self, in_node_id, out_node_id):
self._db.execute(
f"""
INSERT INTO connection_historical (in_node_id, out_node_id)
VALUES ({in_node_id}, {out_node_id})
""")
res = self._db.execute(
f"""
SELECT id
FROM connection_historical
WHERE in_node_id = {in_node_id} AND out_node_id = {out_node_id}
ORDER BY id DESC
LIMIT 1
""")
if not res:
raise ValueError("Incorrect values for in_node_id and/or out_node_id parameters")
self.id = res[0][0]
self.in_node = in_node_id
self.out_node = out_node_id
else:
raise ValueError(
"A Connection must be given either an existing historical_connection_id or in and out_node_id")
"""
)
return self._search_from_data(class_table=HistoricalConnection, in_node_id=in_node_id, out_node_id=out_node_id)


class Connection(HistoricalConnection):
_table: str = 'connection'
_columns: Dict[str, type] = {
'id': int,
'genotype_id': int,
'weight': float,
'is_enabled': bool,
'historical_id': int,
}
genotype_id: int
is_enabled: bool
weight: float
in_node: int
out_node: int

def __init__(
self, db, historical_connection_id=None, in_node_id=None, out_node_id=None,
genotype_id=None, connection_id=None, weight=None, is_enabled=None):
self._db = db
conn_id = None
self,
db,
historical_connection_id=None,
in_node_id=None,
out_node_id=None,
genotype_id=None,
connection_id=None,
weight=None,
is_enabled=None,
):
if not connection_id and not genotype_id:
raise ValueError("Must specify a genotype_id or a connection_id")
if historical_connection_id or (in_node_id and out_node_id):
super().__init__(
db=db,
historical_connection_id=historical_connection_id,
in_node_id=in_node_id,
out_node_id=out_node_id,
)
self.historical_id = self.id
else:
AbstractModelElement.__init__(self, db=db)

if connection_id:
res = self._db.execute(
f"""
SELECT id, genotype_id, weight, is_enabled, historical_id
FROM connection
WHERE id = {connection_id}
ORDER BY id DESC
LIMIT 1
""")
if not res:
raise ValueError("Specified connection_id doesn't exist")
conn_id, self.genotype_id, self._weight, self._is_enabled, historical_connection_id = res[0]
super().__init__(db, historical_connection_id, in_node_id, out_node_id)
self.historical_id = int(self.id)
self.id = conn_id
res = self._search_from_id(connection_id)
self.id, self.genotype_id, self._weight, self._is_enabled, self.historical_id = res
historical_connection = HistoricalConnection(self._db, historical_connection_id=self.historical_id)
self.in_node = historical_connection.in_node
self.out_node = historical_connection.out_node

if not connection_id:
res = self._db.execute(
Expand All @@ -97,7 +115,7 @@ def __init__(
""")

if res:
self.id, self.genotype_id, self._is_enabled, self._weight = res[0]
self.id, self.genotype_id, self._is_enabled, self._weight = res
if is_enabled:
self.is_enabled = is_enabled
if weight:
Expand All @@ -123,10 +141,13 @@ def __init__(
LIMIT 1
""")

self.id = res[0][0]
self.id = res[0]
self._weight = weight
self._is_enabled = is_enabled
self.genotype_id = genotype_id
historical_connection = HistoricalConnection(self._db, historical_connection_id=self.historical_id)
self.in_node = historical_connection.in_node
self.out_node = historical_connection.out_node

@property
def is_enabled(self):
Expand Down
Loading

0 comments on commit a1d8dfe

Please sign in to comment.