Skip to content

Commit

Permalink
Merge pull request #22 from yale-swe/slug-title
Browse files Browse the repository at this point in the history
add user profile
  • Loading branch information
plin349 authored Apr 15, 2024
2 parents d23c769 + ee81438 commit 14539d1
Show file tree
Hide file tree
Showing 3 changed files with 136 additions and 9 deletions.
Binary file removed backend/__pycache__/lib.cpython-312.pyc
Binary file not shown.
138 changes: 133 additions & 5 deletions backend/app.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,9 @@ def init_database(app):
if "MONGO_URI" in app.config:
client = MongoClient(app.config["MONGO_URI"])
db = client["course_db"]
app.config["collection"] = db["parsed_courses"]
app.config["courses"] = db["parsed_courses"]
app.config["profiles"] = db["user_profile"]

# else, set to None or Mock in case of testing

def create_app(test_config=None):
Expand Down Expand Up @@ -109,7 +111,133 @@ def validate_cas_ticket():
else:
print("response status is not 200")
return jsonify({"isAuthenticated": False}), 401

@app.route("/api/save_chat", methods=["POST"])
def save_chat():
try:
chat_data = request.get_json()
collection = app.config["courses"]
result = collection.insert_one(chat_data)
return jsonify({"status": "success", "id": str(result.inserted_id)}), 201
except Exception as e:
return jsonify({"status": "error", "message": str(e)}), 500

@app.route("/api/reload_chat", methods=["POST"])
def reload_chat():
try:
data = request.get_json()
chat_id = data.get("chat_id")
collection = app.config["courses"]
chat = collection.find_one({"_id": str(chat_id)})
if chat:
return jsonify({"status": "success", "chat": chat}), 200
else:
return jsonify({"status": "not found"}), 404
except Exception as e:
return jsonify({"status": "error", "message": str(e)}), 500

@app.route("/api/verify_course_code", methods=["POST"])
def verify_course_code():
try:
data = request.get_json()

uid = data.get("uid")
if not uid:
return jsonify({"error": "No uid provided"}), 400

course_code = data["search"]
course_collection = app.config["courses"]
user_collection = app.config["profiles"]
course = course_collection.find_one({"course_code": course_code})
if course:
# insert into database
result = user_collection.update_one({"uid": uid}, {"$addToSet": {"courses": course_code}})
return jsonify({"status": "success", "course": course['course_code']}), 200
else:
return jsonify({"status": "invalid course code"}), 404
except Exception as e:
return jsonify({"status": "error", "message": str(e)}), 500

@app.route("/api/delete_course_code", methods=["POST"])
def delete_course_code():
try:
data = request.get_json()

uid = data.get("uid")
if not uid:
return jsonify({"error": "No uid provided"}), 400

course_code = data["search"]
course_collection = app.config["courses"]
user_collection = app.config["profiles"]
course = course_collection.find_one({"course_code": course_code})
if course:
# insert into database
result = user_collection.update_one({"uid": uid}, {"$pull": {"courses": course_code}})
return jsonify({"status": "success", "course": course['course_code']}), 200
else:
return jsonify({"status": "invalid course code"}), 404
except Exception as e:
return jsonify({"status": "error", "message": str(e)}), 500

# must only call after checking if the user exists
@app.route("/api/user/create", methods=["POST"])
def create_user_profile():
data = request.get_json()
uid = data.get("uid", None)
if not uid:
return jsonify({"error": "No uid provided"})

collection = app.config["profiles"]

user_profile = {
"uid": uid,
"chat_history": [],
"courses": [],
"name": "",
}

result = collection.insert_one(user_profile)
user_profile['_id'] = str(result.inserted_id)

return jsonify(user_profile)

@app.route("/api/user/profile", methods=["POST"])
def get_user_profile():
data = request.get_json()
uid = data.get("uid", None)
if not uid:
return jsonify({"error": "No uid provided"})

collection = app.config["profiles"]
user_profile = collection.find_one({"uid": uid})
if not user_profile:
return jsonify({"error": "No user profile found"}), 404

# Convert ObjectId to string
user_profile['_id'] = str(user_profile['_id'])

return jsonify(user_profile)

def update_user_profile(uid, update):
collection = app.config["profiles"]
user_profile = collection.find_one({"uid": uid})
collection.update_one({"uid": uid}, {"$set": update})
user_profile = collection.find_one({"uid": uid})
user_profile['_id'] = str(user_profile['_id'])
return jsonify(user_profile)

@app.route("/api/user/update/name", methods=["POST"])
def update_user_name():
data = request.get_json()
uid = data.get("uid", None)
name = data.get("name", None)
if not uid:
return jsonify({"error": "No uid provided"})
if not name:
return jsonify({"error": "No name provided"})
return update_user_profile(uid, {"name": data.get("name")})

@app.route("/api/slug", methods=["POST"])
def get_chat_history_slug():
data = request.get_json()
Expand All @@ -136,7 +264,7 @@ def chat():
user_messages = data.get("message", None)

filter_season_codes = data.get("season_codes", None) # assume it is an array of season code
filter_subject = data.get("subject", None)
filter_subjects = data.get("subject", None)
filter_areas = data.get("areas", None)

if not user_messages:
Expand Down Expand Up @@ -224,7 +352,7 @@ def chat():

query_vector = create_embedding(response)

collection = app.config["collection"]
collection = app.config["courses"]

aggregate_pipeline = {
"$vectorSearch": {
Expand All @@ -243,10 +371,10 @@ def chat():
}
}

if filter_subject:
if filter_subjects:
aggregate_pipeline["$vectorSearch"]["filter"] = {
"subject": {
"$eq": filter_subject
"$in": filter_subjects
}
}

Expand Down
7 changes: 3 additions & 4 deletions backend/test_app.py
Original file line number Diff line number Diff line change
Expand Up @@ -101,8 +101,7 @@ def test_load_config_with_test_config(app):


def test_init_database_with_config(app):
assert "collection" in app.config

assert "courses" in app.config

@pytest.fixture
def client():
Expand All @@ -126,7 +125,7 @@ def client():
app = create_app(
{
"TESTING": True,
"collection": mock_collection,
"courses": mock_collection,
"MONGO_URL": "TEST_URL",
"COURSE_QUERY_LIMIT": 5,
"SAFETY_CHECK_ENABLED": True,
Expand Down Expand Up @@ -159,7 +158,7 @@ def client_all_disabled():
app = create_app(
{
"TESTING": True,
"collection": mock_collection,
"courses": mock_collection,
}
)
with app.test_client() as client:
Expand Down

0 comments on commit 14539d1

Please sign in to comment.