Skip to content

Commit

Permalink
create sql tables only if sql stores
Browse files Browse the repository at this point in the history
  • Loading branch information
niklastheman committed Jan 8, 2025
1 parent 633d7cf commit 51c463c
Show file tree
Hide file tree
Showing 2 changed files with 67 additions and 46 deletions.
72 changes: 42 additions & 30 deletions fedn/network/api/shared.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,6 @@

import pymongo
from pymongo.database import Database
from sqlalchemy import MetaData, create_engine
from sqlalchemy.orm import DeclarativeBase, Mapped, mapped_column, sessionmaker
from werkzeug.security import safe_join

from fedn.common.config import get_modelstorage_config, get_network_config, get_statestore_config
Expand All @@ -20,39 +18,53 @@
from fedn.network.storage.statestore.stores.session_store import MongoDBSessionStore, SQLSessionStore
from fedn.network.storage.statestore.stores.shared import EntityNotFound
from fedn.network.storage.statestore.stores.status_store import MongoDBStatusStore, SQLStatusStore, StatusStore
from fedn.network.storage.statestore.stores.store import Base, MyAbstractBase, engine
from fedn.network.storage.statestore.stores.store import MyAbstractBase, engine
from fedn.network.storage.statestore.stores.validation_store import MongoDBValidationStore, SQLValidationStore, ValidationStore
from fedn.utils.checksum import sha

statestore_config = get_statestore_config()
modelstorage_config = get_modelstorage_config()
network_id = get_network_config()

mc = pymongo.MongoClient(**statestore_config["mongo_config"])
mc.server_info()
mdb: Database = mc[network_id]

MyAbstractBase.metadata.create_all(engine, checkfirst=True)


# client_store: ClientStore = MongoDBClientStore(mdb, "network.clients")
client_store: ClientStore = SQLClientStore()
# package_store: PackageStore = MongoDBPackageStore(mdb, "control.package")
package_store: PackageStore = SQLPackageStore()
# session_store = MongoDBSessionStore(mdb, "control.sessions")
session_store = SQLSessionStore()
# model_store = MongoDBModelStore(mdb, "control.model")
model_store = SQLModelStore()
# combiner_store: CombinerStore = MongoDBCombinerStore(mdb, "network.combiners")
combiner_store: CombinerStore = SQLCombinerStore()
# round_store: RoundStore = MongoDBRoundStore(mdb, "control.rounds")
round_store: RoundStore = SQLRoundStore()
# status_store: StatusStore = MongoDBStatusStore(mdb, "control.status")
status_store: StatusStore = SQLStatusStore()
# validation_store: ValidationStore = MongoDBValidationStore(mdb, "control.validations")
validation_store: ValidationStore = SQLValidationStore()
# prediction_store: PredictionStore = MongoDBPredictionStore(mdb, "control.predictions")
prediction_store: PredictionStore = SQLPredictionStore()

client_store: ClientStore = None
validation_store: ValidationStore = None
combiner_store: CombinerStore = None
status_store: StatusStore = None
prediction_store: PredictionStore = None
round_store: RoundStore = None
package_store: PackageStore = None
model_store: SQLModelStore = None
session_store: SQLSessionStore = None

if statestore_config["type"] == "MongoDB":
network_id = get_network_config()

mc = pymongo.MongoClient(**statestore_config["mongo_config"])
mc.server_info()
mdb: Database = mc[network_id]

client_store = MongoDBClientStore(mdb, "network.clients")
validation_store = MongoDBValidationStore(mdb, "control.validations")
combiner_store = MongoDBCombinerStore(mdb, "network.combiners")
status_store = MongoDBStatusStore(mdb, "control.status")
prediction_store = MongoDBPredictionStore(mdb, "control.predictions")
round_store = MongoDBRoundStore(mdb, "control.rounds")
package_store = MongoDBPackageStore(mdb, "control.packages")
model_store = MongoDBModelStore(mdb, "control.models")
session_store = MongoDBSessionStore(mdb, "control.sessions")

else:
MyAbstractBase.metadata.create_all(engine, checkfirst=True)

client_store = SQLClientStore()
validation_store = SQLValidationStore()
combiner_store = SQLCombinerStore()
status_store = SQLStatusStore()
prediction_store = SQLPredictionStore()
round_store = SQLRoundStore()
package_store = SQLPackageStore()
model_store = SQLModelStore()
session_store = SQLSessionStore()


repository = Repository(modelstorage_config["storage_config"])

Expand Down
41 changes: 25 additions & 16 deletions fedn/network/combiner/shared.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,32 +9,41 @@
from fedn.network.storage.statestore.stores.prediction_store import MongoDBPredictionStore, PredictionStore, SQLPredictionStore
from fedn.network.storage.statestore.stores.round_store import MongoDBRoundStore, RoundStore, SQLRoundStore
from fedn.network.storage.statestore.stores.status_store import MongoDBStatusStore, SQLStatusStore, StatusStore
from fedn.network.storage.statestore.stores.store import Base, MyAbstractBase, engine
from fedn.network.storage.statestore.stores.store import MyAbstractBase, engine
from fedn.network.storage.statestore.stores.validation_store import MongoDBValidationStore, SQLValidationStore, ValidationStore

statestore_config = get_statestore_config()
modelstorage_config = get_modelstorage_config()
network_id = get_network_config()

client_store: ClientStore = None
validation_store: ValidationStore = None
combiner_store: CombinerStore = None
status_store: StatusStore = None
prediction_store: PredictionStore = None
round_store: RoundStore = None

if statestore_config["type"] == "MongoDB":
network_id = get_network_config()

mc = pymongo.MongoClient(**statestore_config["mongo_config"])
mc.server_info()
mdb: Database = mc[network_id]

MyAbstractBase.metadata.create_all(engine, checkfirst=True)

# client_store: ClientStore = MongoDBClientStore(mdb, "network.clients")
client_store: ClientStore = SQLClientStore()
# validation_store: ValidationStore = MongoDBValidationStore(mdb, "control.validations")
validation_store: ValidationStore = SQLValidationStore()
# combiner_store: CombinerStore = MongoDBCombinerStore(mdb, "network.combiners")
combiner_store: CombinerStore = SQLCombinerStore()
# status_store: StatusStore = MongoDBStatusStore(mdb, "control.status")
status_store: StatusStore = SQLStatusStore()
# prediction_store: PredictionStore = MongoDBPredictionStore(mdb, "control.predictions")
prediction_store: PredictionStore = SQLPredictionStore()
# round_store: RoundStore = MongoDBRoundStore(mdb, "control.rounds")
round_store: RoundStore = SQLRoundStore()
client_store = MongoDBClientStore(mdb, "network.clients")
validation_store = MongoDBValidationStore(mdb, "control.validations")
combiner_store = MongoDBCombinerStore(mdb, "network.combiners")
status_store = MongoDBStatusStore(mdb, "control.status")
prediction_store = MongoDBPredictionStore(mdb, "control.predictions")
round_store = MongoDBRoundStore(mdb, "control.rounds")
else:
MyAbstractBase.metadata.create_all(engine, checkfirst=True)

client_store = SQLClientStore()
validation_store = SQLValidationStore()
combiner_store = SQLCombinerStore()
status_store = SQLStatusStore()
prediction_store = SQLPredictionStore()
round_store = SQLRoundStore()

repository = Repository(modelstorage_config["storage_config"], init_buckets=False)

Expand Down

0 comments on commit 51c463c

Please sign in to comment.