Skip to content

Commit

Permalink
fix by inserting using on_conflict_do_nothing
Browse files Browse the repository at this point in the history
  • Loading branch information
peterdudfield committed Oct 1, 2024
1 parent 99b57a5 commit 47d752d
Showing 1 changed file with 20 additions and 19 deletions.
39 changes: 20 additions & 19 deletions pvsite_datamodel/read/user.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
from datetime import datetime
from typing import List, Optional

from sqlalchemy.dialects import postgresql
from sqlalchemy.orm import Session, contains_eager

from pvsite_datamodel.sqlmodels import APIRequestSQL, SiteGroupSQL, UserSQL
Expand All @@ -27,21 +28,16 @@ def get_user_by_email(session: Session, email: str, make_new_user_if_none: bool
logger.info(f"User with email {email} not found, so making one")

# checking for site_group
site_group = (
session.query(SiteGroupSQL)
.filter(SiteGroupSQL.site_group_name == f"site_group_for_{email}")
.first()
)
# making a new site group if one doesn't exist
if site_group is None:
site_group = SiteGroupSQL(site_group_name=f"site_group_for_{email}")
session.add(site_group)
session.commit()
site_group_name = f"site_group_for_{email}"
site_group = get_site_group_by_name(session=session, site_group_name=site_group_name)

# make a new user
user = UserSQL(email=email, site_group_uuid=site_group.site_group_uuid)
session.add(user)
session.commit()
stmt = postgresql.insert(UserSQL.__table__)
stmt = stmt.on_conflict_do_nothing()
session.execute(stmt, [{"site_group_uuid": site_group.site_group_uuid, "email": email}])

# get a new user
user = session.query(UserSQL).filter(UserSQL.email == email).first()

return user

Expand All @@ -62,7 +58,7 @@ def get_all_users(session: Session) -> List[UserSQL]:
return users


def get_site_group_by_name(session: Session, site_group_name: str):
def get_site_group_by_name(session: Session, site_group_name: str, create_if_none: bool = True):
"""
Get site group by name. If site group does not exist, make one.
Expand All @@ -75,13 +71,18 @@ def get_site_group_by_name(session: Session, site_group_name: str):
session.query(SiteGroupSQL).filter(SiteGroupSQL.site_group_name == site_group_name).first()
)

if site_group is None:
if (site_group is None) and (create_if_none is True):
logger.info(f"Site group with name {site_group_name} not found, so making one")

# make a new site group
site_group = SiteGroupSQL(site_group_name=site_group_name)
session.add(site_group)
session.commit()
stmt = postgresql.insert(SiteGroupSQL.__table__)
stmt = stmt.on_conflict_do_nothing()
session.execute(stmt, [{"site_group_name": site_group_name}])

site_group = (
session.query(SiteGroupSQL)
.filter(SiteGroupSQL.site_group_name == site_group_name)
.first()
)

return site_group

Expand Down

0 comments on commit 47d752d

Please sign in to comment.