diff --git a/core/model.py b/core/model.py index 18efa11..55f2be6 100644 --- a/core/model.py +++ b/core/model.py @@ -32,11 +32,13 @@ def _initialize_nodes(self, size_input, size_output): def _initialize_population(self, pop_size): node_ids = set(self._input_node_ids + self._output_node_ids) bias_node_id = self.get_node(NodeTypes.bias).id - connections = ({"in_node_id": bias_node_id, "out_node_id": out_id, "weight": ((random.random() * 2) - 1)} - for out_id in self._output_node_ids) + connections = ( + {"in_node_id": bias_node_id, "out_node_id": out_id, "weight": ((random.random() * 2) - 1)} + for out_id in self._output_node_ids + ) connections = tuple(connections) individual_dict = {'genotype_kwargs': {"node_ids": node_ids, "connection_dicts": connections}} - individual_dicts = (individual_dict for i in range(pop_size)) + individual_dicts = (individual_dict for _ in range(pop_size)) self._population = self.get_population( generation_id=self._generation.id, individual_dicts=tuple( diff --git a/core/orm/__init__.py b/core/orm/__init__.py index e69de29..9783df3 100644 --- a/core/orm/__init__.py +++ b/core/orm/__init__.py @@ -0,0 +1,53 @@ +from typing import Dict + +from core.orm.database import Database + + +class AbstractModelElement: + _table: str = None + _columns: Dict[str, type] = {'id': int} + + def __init__(self, db): + """ + + :param core.orm.database.Database db: + """ + self._db: Database = db + + def _cast_data_to_types(self, *args): + if not args: + return args + cleaned_args = [] + for key, value in zip(self._columns, args): + cleaned_args.append(self._columns.get(key, str)(value)) + return tuple(cleaned_args) + + def _search_from_id(self, element_id: int, class_table=None): + class_table = class_table or self.__class__ + res = self._db.execute( + f""" + SELECT {', '.join(class_table._columns)} + FROM {class_table._table} + WHERE id = {element_id} + ORDER BY id DESC + LIMIT 1 + """ + ) + if not res: + raise ValueError(f"No {__class__} exists with that id") + return self._cast_data_to_types(*res) + + def _search_from_data(self, class_table=None, **kwargs): + class_table = class_table or self.__class__ + where_clause = ' AND '.join((f"{key} = {value}" for key, value in kwargs.items())) + + res = self._db.execute( + f""" + SELECT {', '.join(class_table._columns)} + FROM {class_table._table} + WHERE {where_clause} + ORDER BY id DESC + LIMIT 1 + """ + ) + return self._cast_data_to_types(*res) diff --git a/core/orm/connections.py b/core/orm/connections.py index ce6d0e0..4a314e8 100644 --- a/core/orm/connections.py +++ b/core/orm/connections.py @@ -1,79 +1,97 @@ -class HistoricalConnection: +from typing import Dict + +from core.orm import AbstractModelElement + + +class HistoricalConnection(AbstractModelElement): + _table: str = 'connection_historical' + _columns: Dict[str, type] = {'id': int, 'in_node_id': int, 'out_node_id': int} + id: int + in_node: int + out_node: int + def __init__(self, db, historical_connection_id=None, in_node_id=None, out_node_id=None): - self._db = db + if not (historical_connection_id or (in_node_id and out_node_id)): + raise ValueError("Must be given either an existing historical_connection_id or in and out_node_id") + super().__init__(db) + res = None, None, None if historical_connection_id: - res = self._db.execute( - f""" - SELECT id, in_node_id, out_node_id - FROM connection_historical - WHERE id = {historical_connection_id} - LIMIT 1 - """) + res = self._search_from_id( + historical_connection_id, + class_table=HistoricalConnection, + ) if not res: raise ValueError('No HistoricalConnection exists with that id') - self.id, self.in_node, self.out_node = res[0] elif in_node_id and out_node_id: if in_node_id == out_node_id: raise ValueError("in_node_id and out_node_id must be different nodes") + res = self._search_from_data( + class_table=HistoricalConnection, + in_node_id=in_node_id, + out_node_id=out_node_id, + ) + if not res: + res = self._create_historical_connection(in_node_id, out_node_id) + if not res: + raise ValueError("Incorrect values for in_node_id and/or out_node_id parameters") + self.id, self.in_node, self.out_node = res - res = self._db.execute( - f""" - SELECT id, in_node_id, out_node_id - FROM connection_historical - WHERE in_node_id = {in_node_id} AND out_node_id = {out_node_id} - LIMIT 1 - """) - if res: - self.id, self.in_node, self.out_node = res[0] - else: - self._db.execute( - f""" + def _create_historical_connection(self, in_node_id, out_node_id): + self._db.execute( + f""" INSERT INTO connection_historical (in_node_id, out_node_id) VALUES ({in_node_id}, {out_node_id}) - """) - res = self._db.execute( - f""" - SELECT id - FROM connection_historical - WHERE in_node_id = {in_node_id} AND out_node_id = {out_node_id} - ORDER BY id DESC - LIMIT 1 - """) - if not res: - raise ValueError("Incorrect values for in_node_id and/or out_node_id parameters") - self.id = res[0][0] - self.in_node = in_node_id - self.out_node = out_node_id - else: - raise ValueError( - "A Connection must be given either an existing historical_connection_id or in and out_node_id") + """ + ) + return self._search_from_data(class_table=HistoricalConnection, in_node_id=in_node_id, out_node_id=out_node_id) class Connection(HistoricalConnection): + _table: str = 'connection' + _columns: Dict[str, type] = { + 'id': int, + 'genotype_id': int, + 'weight': float, + 'is_enabled': bool, + 'historical_id': int, + } + genotype_id: int + is_enabled: bool + weight: float + in_node: int + out_node: int + def __init__( - self, db, historical_connection_id=None, in_node_id=None, out_node_id=None, - genotype_id=None, connection_id=None, weight=None, is_enabled=None): - self._db = db - conn_id = None + self, + db, + historical_connection_id=None, + in_node_id=None, + out_node_id=None, + genotype_id=None, + connection_id=None, + weight=None, + is_enabled=None, + ): if not connection_id and not genotype_id: raise ValueError("Must specify a genotype_id or a connection_id") + if historical_connection_id or (in_node_id and out_node_id): + super().__init__( + db=db, + historical_connection_id=historical_connection_id, + in_node_id=in_node_id, + out_node_id=out_node_id, + ) + self.historical_id = self.id + else: + AbstractModelElement.__init__(self, db=db) if connection_id: - res = self._db.execute( - f""" - SELECT id, genotype_id, weight, is_enabled, historical_id - FROM connection - WHERE id = {connection_id} - ORDER BY id DESC - LIMIT 1 - """) - if not res: - raise ValueError("Specified connection_id doesn't exist") - conn_id, self.genotype_id, self._weight, self._is_enabled, historical_connection_id = res[0] - super().__init__(db, historical_connection_id, in_node_id, out_node_id) - self.historical_id = int(self.id) - self.id = conn_id + res = self._search_from_id(connection_id) + self.id, self.genotype_id, self._weight, self._is_enabled, self.historical_id = res + historical_connection = HistoricalConnection(self._db, historical_connection_id=self.historical_id) + self.in_node = historical_connection.in_node + self.out_node = historical_connection.out_node if not connection_id: res = self._db.execute( @@ -97,7 +115,7 @@ def __init__( """) if res: - self.id, self.genotype_id, self._is_enabled, self._weight = res[0] + self.id, self.genotype_id, self._is_enabled, self._weight = res if is_enabled: self.is_enabled = is_enabled if weight: @@ -123,10 +141,13 @@ def __init__( LIMIT 1 """) - self.id = res[0][0] + self.id = res[0] self._weight = weight self._is_enabled = is_enabled self.genotype_id = genotype_id + historical_connection = HistoricalConnection(self._db, historical_connection_id=self.historical_id) + self.in_node = historical_connection.in_node + self.out_node = historical_connection.out_node @property def is_enabled(self): diff --git a/core/orm/database.py b/core/orm/database.py index 50050e3..04058a4 100644 --- a/core/orm/database.py +++ b/core/orm/database.py @@ -33,170 +33,172 @@ def _create_cursor(self): return self._con.cursor() def _clear(self): - self._cursor.executescript( - """ - PRAGMA writable_schema = 1; - delete from sqlite_master where type in ('table', 'index', 'trigger'); - PRAGMA writable_schema = 0; - VACUUM; - PRAGMA INTEGRITY_CHECK; - PRAGMA foreign_keys = ON; - """) + query = """ + PRAGMA writable_schema = 1; + delete from sqlite_master where type in ('table', 'index', 'trigger'); + PRAGMA writable_schema = 0; + VACUUM; + PRAGMA INTEGRITY_CHECK; + PRAGMA foreign_keys = ON; + """ + + self._cursor.executescript(query) self._con.commit() def init_db(self): self._cursor.executescript( """ - CREATE TABLE node_type ( - id INTEGER PRIMARY KEY, - name VARCHAR(6) NOT NULL UNIQUE - ); - - INSERT INTO node_type (name) - VALUES ('Bias'), - ('Input'), - ('Hidden'), - ('Output'); - - CREATE TABLE mutation_type ( - id INTEGER PRIMARY KEY, - name VARCHAR(8) NOT NULL UNIQUE - ); - - INSERT INTO mutation_type (name) - VALUES ('Weight'), - ('Enabling'), - ('Split'); + CREATE TABLE node_type ( + id INTEGER PRIMARY KEY, + name VARCHAR(6) NOT NULL UNIQUE + ); - CREATE TABLE node ( - id INTEGER PRIMARY KEY, - node_type_id INTEGER, - connection_historical_id INTEGER UNIQUE, - FOREIGN KEY (node_type_id) - REFERENCES node_type (id) - ON DELETE RESTRICT - ON UPDATE RESTRICT - ); - INSERT INTO node (node_type_id) - VALUES (1); - - CREATE TABLE connection_historical ( - id INTEGER PRIMARY KEY, - in_node_id int NOT NULL, - out_node_id int NOT NULL, - FOREIGN KEY (in_node_id) - REFERENCES node (id) - ON DELETE RESTRICT - ON UPDATE RESTRICT, - FOREIGN KEY (out_node_id) - REFERENCES node (id) - ON DELETE RESTRICT - ON UPDATE RESTRICT, - CHECK (in_node_id != out_node_id) - ); - - CREATE TABLE genotype ( - id INTEGER PRIMARY KEY, - parent_1_id INTEGER, - parent_2_id INTEGER, - FOREIGN KEY (parent_1_id) - REFERENCES genotype (id) - ON DELETE RESTRICT - ON UPDATE CASCADE, - FOREIGN KEY (parent_2_id) - REFERENCES genotype (id) - ON DELETE RESTRICT - ON UPDATE CASCADE, - CHECK (id != genotype.parent_1_id), - CHECK (id != genotype.parent_2_id) - ); - - CREATE TABLE connection ( - id INTEGER PRIMARY KEY, - historical_id INTEGER NOT NULL, - genotype_id INTEGER NOT NULL, - is_enabled BOOLEAN DEFAULT TRUE NOT NULL, - weight FLOAT DEFAULT 1.0 NOT NULL, - FOREIGN KEY (historical_id) - REFERENCES connection_historical (id) - ON DELETE CASCADE - ON UPDATE CASCADE , - FOREIGN KEY (genotype_id) - REFERENCES genotype (id) - ON DELETE CASCADE - ON UPDATE CASCADE, - UNIQUE (historical_id, genotype_id) - ); + INSERT INTO node_type (name) + VALUES ('Bias'), + ('Input'), + ('Hidden'), + ('Output'); + + CREATE TABLE mutation_type ( + id INTEGER PRIMARY KEY, + name VARCHAR(8) NOT NULL UNIQUE + ); + + INSERT INTO mutation_type (name) + VALUES ('Weight'), + ('Enabling'), + ('Split'); + + CREATE TABLE node ( + id INTEGER PRIMARY KEY, + node_type_id INTEGER, + connection_historical_id INTEGER UNIQUE, + FOREIGN KEY (node_type_id) + REFERENCES node_type (id) + ON DELETE RESTRICT + ON UPDATE RESTRICT + ); + INSERT INTO node (node_type_id) + VALUES (1); - CREATE TABLE genotype_node_rel ( + CREATE TABLE connection_historical ( id INTEGER PRIMARY KEY, - genotype_id INTEGER NOT NULL, - node_id INTEGER NOT NULL, - UNIQUE (node_id, genotype_id), - FOREIGN KEY (genotype_id) - REFERENCES genotype (id) - ON DELETE CASCADE - ON UPDATE CASCADE, - FOREIGN KEY (node_id) + in_node_id int NOT NULL, + out_node_id int NOT NULL, + FOREIGN KEY (in_node_id) REFERENCES node (id) - ON DELETE CASCADE - ON UPDATE CASCADE - ); - - CREATE TABLE generation ( - id INTEGER PRIMARY KEY - ); - - CREATE TABLE specie ( - id INTEGER PRIMARY KEY - ); + ON DELETE RESTRICT + ON UPDATE RESTRICT, + FOREIGN KEY (out_node_id) + REFERENCES node (id) + ON DELETE RESTRICT + ON UPDATE RESTRICT, + CHECK (in_node_id != out_node_id) + ); - CREATE TABLE population ( - id INTEGER PRIMARY KEY, - generation_id INTEGER NOT NULL, - FOREIGN KEY (generation_id) - REFERENCES generation (id) - ON DELETE CASCADE - ON UPDATE CASCADE - ); + CREATE TABLE genotype ( + id INTEGER PRIMARY KEY, + parent_1_id INTEGER, + parent_2_id INTEGER, + FOREIGN KEY (parent_1_id) + REFERENCES genotype (id) + ON DELETE RESTRICT + ON UPDATE CASCADE, + FOREIGN KEY (parent_2_id) + REFERENCES genotype (id) + ON DELETE RESTRICT + ON UPDATE CASCADE, + CHECK (id != genotype.parent_1_id), + CHECK (id != genotype.parent_2_id) + ); - CREATE TABLE individual ( - id INTEGER PRIMARY KEY, - genotype_id INTEGER NOT NULL, - specie_id INTEGER NOT NULL, - score INTEGER DEFAULT 0 NOT NULL, - population_id INTEGER NOT NULL, - FOREIGN KEY (genotype_id) - REFERENCES genotype(id) - ON DELETE CASCADE - ON UPDATE CASCADE , - FOREIGN KEY (specie_id) - REFERENCES specie (id) + CREATE TABLE connection ( + id INTEGER PRIMARY KEY, + historical_id INTEGER NOT NULL, + genotype_id INTEGER NOT NULL, + is_enabled BOOLEAN DEFAULT TRUE NOT NULL, + weight FLOAT DEFAULT 1.0 NOT NULL, + FOREIGN KEY (historical_id) + REFERENCES connection_historical (id) + ON DELETE CASCADE + ON UPDATE CASCADE , + FOREIGN KEY (genotype_id) + REFERENCES genotype (id) + ON DELETE CASCADE + ON UPDATE CASCADE, + UNIQUE (historical_id, genotype_id) + ); + + CREATE TABLE genotype_node_rel ( + id INTEGER PRIMARY KEY, + genotype_id INTEGER NOT NULL, + node_id INTEGER NOT NULL, + UNIQUE (node_id, genotype_id), + FOREIGN KEY (genotype_id) + REFERENCES genotype (id) ON DELETE CASCADE ON UPDATE CASCADE, - FOREIGN KEY (population_id) - REFERENCES population (id) + FOREIGN KEY (node_id) + REFERENCES node (id) ON DELETE CASCADE ON UPDATE CASCADE - ); - - CREATE TABLE model_metadata ( - id INTEGER PRIMARY KEY, - speciation_tresh FLOAT DEFAULT 0.25 NOT NULL, - specie_cull_rate FLOAT DEFAULT 0.5 NOT NULL, - reproduction_cloning_rate FLOAT DEFAULT 0.25 NOT NULL, - reproduction_interspecie_rate FLOAT DEFAULT 0.001 NOT NULL, - population_size INTEGER DEFAULT 100 NOT NULL, - mutation_split_rate FLOAT DEFAULT 0.01 NOT NULL, - mutation_weight_rate FLOAT DEFAULT 0.05 NOT NULL, - mutation_switch_rate FLOAT DEFAULT 0.02 NOT NULL, - mutation_add_rate FLOAT DEFAULT 0.02 NOT NULL, - mutation_rate FLOAT GENERATED ALWAYS AS ( - mutation_split_rate + mutation_weight_rate + mutation_switch_rate + mutation_add_rate - ) STORED, - mutation_weight_std FLOAT DEFAULT 0.01 NOT NULL ); - """) + + CREATE TABLE generation ( + id INTEGER PRIMARY KEY + ); + + CREATE TABLE specie ( + id INTEGER PRIMARY KEY + ); + + CREATE TABLE population ( + id INTEGER PRIMARY KEY, + generation_id INTEGER NOT NULL, + FOREIGN KEY (generation_id) + REFERENCES generation (id) + ON DELETE CASCADE + ON UPDATE CASCADE + ); + + CREATE TABLE individual ( + id INTEGER PRIMARY KEY, + genotype_id INTEGER NOT NULL, + specie_id INTEGER NOT NULL, + score INTEGER DEFAULT 0 NOT NULL, + population_id INTEGER NOT NULL, + FOREIGN KEY (genotype_id) + REFERENCES genotype(id) + ON DELETE CASCADE + ON UPDATE CASCADE , + FOREIGN KEY (specie_id) + REFERENCES specie (id) + ON DELETE CASCADE + ON UPDATE CASCADE, + FOREIGN KEY (population_id) + REFERENCES population (id) + ON DELETE CASCADE + ON UPDATE CASCADE + ); + + CREATE TABLE model_metadata ( + id INTEGER PRIMARY KEY, + speciation_tresh FLOAT DEFAULT 0.25 NOT NULL, + specie_cull_rate FLOAT DEFAULT 0.5 NOT NULL, + reproduction_cloning_rate FLOAT DEFAULT 0.25 NOT NULL, + reproduction_interspecie_rate FLOAT DEFAULT 0.001 NOT NULL, + population_size INTEGER DEFAULT 100 NOT NULL, + mutation_split_rate FLOAT DEFAULT 0.01 NOT NULL, + mutation_weight_rate FLOAT DEFAULT 0.05 NOT NULL, + mutation_switch_rate FLOAT DEFAULT 0.02 NOT NULL, + mutation_add_rate FLOAT DEFAULT 0.02 NOT NULL, + mutation_rate FLOAT GENERATED ALWAYS AS ( + mutation_split_rate + mutation_weight_rate + mutation_switch_rate + mutation_add_rate + ) STORED, + mutation_weight_std FLOAT DEFAULT 0.01 NOT NULL + ); + """ + ) def execute(self, query): query += '' if query.endswith(';') else ';' @@ -205,10 +207,15 @@ def execute(self, query): try: res = self._cursor.execute(query) except sql.IntegrityError as sql_error: - raise ValueError from sql_error + raise ValueError(query) from sql_error + except sql.OperationalError as sql_error: + raise SyntaxError(query) from sql_error if "INSERT INTO" in query or "UPDATE" in query: self._con.commit() - return res.fetchall() + res = res.fetchall() + if len(res) == 1 and 'LIMIT 1' in query: + return res[0] + return res _db = Database(f'{PATH}/data/template', override=True) diff --git a/core/orm/generation.py b/core/orm/generation.py index b4f4f87..57ad6af 100644 --- a/core/orm/generation.py +++ b/core/orm/generation.py @@ -12,7 +12,7 @@ def __init__(self, db, generation_id=None): """) if not res: raise ValueError("Specified generation_id doesn't exist") - self.id = res[0][0] + self.id = res[0] else: self._db.execute("""INSERT INTO generation DEFAULT VALUES""") res = self._db.execute("""SELECT MAX(id) FROM generation""") diff --git a/core/orm/genotype.py b/core/orm/genotype.py index 43b2ffd..f9309e8 100644 --- a/core/orm/genotype.py +++ b/core/orm/genotype.py @@ -9,7 +9,7 @@ class Genotype: def __init__(self, db, genotype_id=None, node_ids=None, connection_dicts=None, parent_genotype_ids=None): self._db = db - parent_genotype_ids = [] if parent_genotype_ids is None else parent_genotype_ids + parent_genotype_ids = parent_genotype_ids or [] if not (genotype_id or (node_ids and connection_dicts)): raise ValueError("Must specify either an existing genotype_id or both node_ids and connection_dicts") @@ -24,7 +24,7 @@ def __init__(self, db, genotype_id=None, node_ids=None, connection_dicts=None, p """) if not res: raise ValueError("Specified genotype_id doesn't exist") - self.id, *self.parent_ids = res[0] + self.id, *self.parent_ids = res self.parent_ids = set((parent for parent in self.parent_ids if parent)) res = self._db.execute( f""" @@ -45,8 +45,9 @@ def __init__(self, db, genotype_id=None, node_ids=None, connection_dicts=None, p """ SELECT MAX(id) FROM genotype + LIMIT 1 """) - self.id = (res[0][0] or 0) + 1 + self.id = (res[0] or 0) + 1 self.parent_ids = set((parent for parent in parent_genotype_ids if parent)) parent_genotype_ids = list(sorted(self.parent_ids)) + ['NULL', 'NULL'] self._db.execute( @@ -94,10 +95,10 @@ def historical_connection_ids(self): """) return set((row[0] for row in res)) - def __xor__(self, other): + def __and__(self, other): if not isinstance(other, Genotype): raise TypeError( - "Cannot use xor operator between an instance of 'Genotype' and an instance of another class" + "Cannot use and operator between an instance of 'Genotype' and an instance of another class" ) diff_nodes = len(other.node_ids ^ self.node_ids) total_nodes = max(len(other.node_ids), len(self.node_ids)) @@ -113,7 +114,7 @@ def as_dict(self): 'node_ids': self.node_ids, 'connection_dicts': tuple( { - 'historical_connection_id': connection.historical_id, + 'connection_id': connection.historical_id, 'in_node_id': connection.in_node, 'out_node_id': connection.out_node, 'weight': connection.weight, @@ -149,9 +150,9 @@ def get_mutated(self): add_connection_count = 0 in_out_node_mapping = {} for connection in mutant.get('connection_dicts', []): - historical_id = connection['historical_connection_id'] + historical_id = connection['connection_id'] in_out_node_mapping.setdefault(connection['in_node_id'], []).append(connection['out_node_id']) - del connection['historical_connection_id'] + del connection['connection_id'] r = random.random() m_rate = weight_rate if r < m_rate: diff --git a/core/orm/individual.py b/core/orm/individual.py index bfd5ef2..ce971a6 100644 --- a/core/orm/individual.py +++ b/core/orm/individual.py @@ -8,6 +8,7 @@ class Individual: def __init__( self, db, individual_id=None, population_id=None, genotype_id=None, genotype_kwargs=None, specie_id=None, score=None): + genotype_kwargs = genotype_kwargs or {} self._db = db score = score or 0 if not individual_id and (not population_id or not (genotype_id or genotype_kwargs)): @@ -25,7 +26,7 @@ def __init__( """) if not res: raise ValueError("Specified individual_id doesn't exist") - self.id, self.genotype_id, self.specie_id, self._score, self.population_id = res[0] + self.id, self.genotype_id, self.specie_id, self._score, self.population_id = res else: test_exists = {"population": population_id} @@ -47,11 +48,10 @@ def __init__( if not res: raise ValueError(f"Specified {table}_id doesn't exist") - if not genotype_id and genotype_kwargs: - genotype_id = Genotype(self._db, **genotype_kwargs).id + genotype = Genotype(self._db, genotype_id=genotype_id, **genotype_kwargs) + genotype_id = genotype_id or genotype.id if not specie_id: # Find specie - genotype = Genotype(self._db, genotype_id=genotype_id) res = self._db.execute( f""" SELECT gen.id AS genotype_id, @@ -66,7 +66,7 @@ def __init__( best_result = min(1., max(0., 1. - self.speciation_threshold)) for other_id, other_specie_id in res: other_genotype = Genotype(self._db, genotype_id=other_id) - genotypes_similarity = genotype ^ other_genotype + genotypes_similarity = genotype & other_genotype if genotypes_similarity > best_result: best_specie_id = other_specie_id best_result = genotypes_similarity @@ -93,7 +93,7 @@ def __init__( LIMIT 1 """) - self.id = res[0][0] + self.id = res[0] self.population_id = population_id self._score = score self.specie_id = specie_id @@ -116,10 +116,10 @@ def __add__(self, other): other_genotype = Genotype(db, other.genotype_id) inner_hist_conn_ids = self_genotype.historical_connection_ids & other_genotype.historical_connection_ids - outer_hist_conn_ids = self_genotype.historical_connection_ids | other_genotype.historical_connection_ids + outer_hist_conn_ids = self_genotype.historical_connection_ids ^ other_genotype.historical_connection_ids hist_conn_ids = set(inner_hist_conn_ids) - for conn_id in (outer_hist_conn_ids - inner_hist_conn_ids): + for conn_id in outer_hist_conn_ids: if bool(round(random.random())): hist_conn_ids.add(conn_id) @@ -149,7 +149,7 @@ def __add__(self, other): { 'historical_connection_id': hist_conn_id, 'weight': weight, - 'is_enabled': is_enabled, + 'is_enabled': bool(is_enabled), }) return { @@ -177,7 +177,7 @@ def speciation_threshold(self): """) if not res: raise ValueError("There must be at least one row in model_metadata table to fetch data from") - return res[0][0] + return res[0] @property def score_raw(self): @@ -213,7 +213,7 @@ def __init__(self, db, specie_id=None): """) if not res: raise ValueError("Specified specie_id doesn't exist") - self.id = res[0][0] + self.id = res[0] else: self._db.execute("""INSERT INTO specie DEFAULT VALUES""") res = self._db.execute("""SELECT MAX(id) FROM specie""") @@ -256,7 +256,7 @@ def get_culled_individuals(self): FROM model_metadata ORDER BY id LIMIT 1 - """)[0][0] + """)[0] individuals, scores = self.get_sorted_individuals() individuals = individuals[:round(len(individuals) * cull_rate)] return individuals, scores[:len(individuals)] diff --git a/core/orm/node.py b/core/orm/node.py index bc0218f..9b51da9 100644 --- a/core/orm/node.py +++ b/core/orm/node.py @@ -18,7 +18,7 @@ def __init__(self, db, node_type=None, connection_historical_id=None, node_id=No """) if not res: raise ValueError('No node exists with that id') - self.id, self.node_type, self.connection_historical = res[0] + self.id, self.node_type, self.connection_historical = res else: if not connection_historical_id and not node_type: raise ValueError('A node must have a connection_historical_id or a node_type specified') @@ -37,7 +37,8 @@ def __init__(self, db, node_type=None, connection_historical_id=None, node_id=No FROM node LEFT JOIN node_type nt on node.node_type_id = nt.id WHERE nt.name = 'Bias' - """)[0][0] + LIMIT 1 + """)[0] self.node_type = node_type self.connection_historical = None else: @@ -57,7 +58,7 @@ def __init__(self, db, node_type=None, connection_historical_id=None, node_id=No """) if not res: raise ValueError('Incorrect node_type or connection_historical_id given') - self.id = res[0][0] + self.id = res[0] self.node_type = node_type self.connection_historical = connection_historical_id diff --git a/core/orm/population.py b/core/orm/population.py index 231cb0c..c7c9774 100644 --- a/core/orm/population.py +++ b/core/orm/population.py @@ -41,7 +41,7 @@ def __init__(self, db, population_id=None, generation_id=None, individual_dicts= f""" SELECT id from population WHERE generation_id={generation_id} ORDER BY id DESC LIMIT 1 """) - self.id = res[0][0] + self.id = res[0] self.generation_id = generation_id for individual_dict in individual_dicts: individual_dict["population_id"] = self.id @@ -61,7 +61,7 @@ def model_pop_size(self): """) if not res: raise ValueError("There must be at least one row in model_metadata table to fetch data from") - return res[0][0] + return res[0] @property def species(self): diff --git a/test/test_model.py b/test/test_model.py index 0cc717d..3e46fbf 100644 --- a/test/test_model.py +++ b/test/test_model.py @@ -30,10 +30,11 @@ def test_init(self): self.assertSequenceEqual(((1, "Bias"), (2, "Input"), (3, "Hidden"), (4, "Output"),), res) res = self._db.execute("""SELECT id, node_type_id, connection_historical_id FROM node""") - self.assertSequenceEqual([ - (i + 1, NodeTypes.bias if i < 1 else NodeTypes.input if i < 11 else NodeTypes.output, None) - for i in range(21) - ], + self.assertSequenceEqual( + [ + (i + 1, NodeTypes.bias if i < 1 else NodeTypes.input if i < 11 else NodeTypes.output, None) + for i in range(21) + ], res, ) diff --git a/test/test_orm.py b/test/test_orm.py index 19f0f30..44226f5 100644 --- a/test/test_orm.py +++ b/test/test_orm.py @@ -170,7 +170,7 @@ def test_init(self): Genotype(self._db, connection_dicts=connection_dicts) gen = Genotype(self._db, node_ids={1, 2, }, connection_dicts=connection_dicts) - self.assertEqual(1, HistoricalConnection(self._db, 1).id) + self.assertEqual(1, HistoricalConnection(self._db, historical_connection_id=1).id) self.assertEqual(1, Connection(self._db, historical_connection_id=1, genotype_id=gen.id).id) self.assertFalse(Connection(self._db, connection_id=1, genotype_id=gen.id).is_enabled) self.assertEqual(0.5, Connection(self._db, connection_id=1, genotype_id=gen.id).weight) @@ -195,19 +195,19 @@ def test_init(self): ) genode3 = Genotype(self._db, node_ids={1, 2, 3, 4, 5, 6, 7}, connection_dicts=connection_dicts_2) with self.assertRaises(TypeError): - _ = gen ^ {6, } - self.assertEqual(1.0, gen ^ genode2) - self.assertEqual(1 / 3, gen ^ genode3) + _ = gen & {6, } + self.assertEqual(1.0, gen & genode2) + self.assertEqual(1 / 3, gen & genode3) self.assertNotIn(new_node_id, gen.node_ids) def test_draw(self): - node_b = Node(self._db, 'bias') - node_i1 = Node(self._db, 'input') - node_i2 = Node(self._db, 'input') - node_h1 = Node(self._db, 'hidden') - node_h2 = Node(self._db, 'hidden') - node_o1 = Node(self._db, 'output') - node_o2 = Node(self._db, 'output') + node_b = Node(self._db, node_type='bias') + node_i1 = Node(self._db, node_type='input') + node_i2 = Node(self._db, node_type='input') + node_h1 = Node(self._db, node_type='hidden') + node_h2 = Node(self._db, node_type='hidden') + node_o1 = Node(self._db, node_type='output') + node_o2 = Node(self._db, node_type='output') connection_dicts = ( { 'in_node_id': node_i1.id, @@ -262,7 +262,7 @@ def test_draw(self): node_o1.id, node_o2.id, }, - connection_dicts=connection_dicts + connection_dicts=connection_dicts, ) path = os.path.join(os.path.dirname(__file__), 'test_genome.dot') with open(path, 'rt', encoding='utf-8') as test: @@ -391,17 +391,46 @@ def test_add(self): ) } ind2 = Individual(self._db, genotype_kwargs=genotype_kwargs_2, population_id=1) - ind3 = Individual(**(ind1 + ind2)) + + ind3_kwargs = ind1 + ind2 + self.assertDictEqual( + { + 'db': self._db, + 'genotype_id': None, + 'genotype_kwargs': { + 'connection_dicts': ( + { + 'historical_connection_id': 1, + 'is_enabled': False, + 'weight': 0.5, + }, + ), + 'node_ids': {2, 3}, + 'parent_genotype_ids': {1, 2}, + }, + 'individual_id': None, + 'population_id': 1, + 'score': None, + 'specie_id': None, + }, + ind3_kwargs, + ) + + ind3 = Individual(**ind3_kwargs) + self.assertNotEqual(ind1.id, ind3.id) self.assertNotEqual(ind2.id, ind3.id) + self.assertNotEqual(ind1.genotype_id, ind3.genotype_id) self.assertNotEqual(ind2.genotype_id, ind3.genotype_id) + self.assertEqual(ind1.population_id, ind3.population_id) self.assertEqual(ind2.population_id, ind3.population_id) + self.assertNotEqual(ind1.specie_id, ind3.specie_id) self.assertEqual(ind2.specie_id, ind3.specie_id) geno2 = Genotype(self._db, ind2.genotype_id) geno3 = Genotype(self._db, ind3.genotype_id) self.assertSetEqual(geno2.historical_connection_ids, geno3.historical_connection_ids) self.assertSetEqual(set(), geno2.connection_ids & geno3.connection_ids) self.assertSetEqual(geno2.node_ids, geno3.node_ids) - self.assertEqual(1, geno3 ^ geno2) + self.assertEqual(1, geno3 & geno2) class TestPopulation(NEATBaseTestCase):