from kusibot.database.db import db
from kusibot.database.models import Conversation, Message, Assessment, AssessmentQuestion, User
from sqlalchemy import func
from datetime import datetime, timezone
class UserRepository:
    """Manages all data access logic for the User model."""

    def get_user_by_username(self, username):
        """
        Retrieve a user by their username.

        Args:
            username (str): The username of the user to retrieve.

        Returns:
            User: The first matching user object, or None if no user matches
            or the query raised (the error is printed and the session is
            rolled back).
        """
        try:
            return db.session.query(User).filter_by(username=username).first()
        except Exception as e:
            print(f"Error retrieving user by username: {e}")
            db.session.rollback()
            return None

    def get_user_by_email(self, email):
        """
        Retrieve a user by their email.

        Args:
            email (str): The email of the user to retrieve.

        Returns:
            User: The first matching user object, or None if no user matches
            or the query raised (the error is printed and the session is
            rolled back).
        """
        try:
            return db.session.query(User).filter_by(email=email).first()
        except Exception as e:
            print(f"Error retrieving user by email: {e}")
            db.session.rollback()
            return None

    def add_user(self, username, email, hashed_password, is_professional):
        """
        Add a new user (professional or not) to the database and commit.

        Args:
            username (str): The username of the new user.
            email (str): The email of the new user.
            hashed_password (str): The already-hashed password; it is stored
                as-is in the user's ``password`` field.
            is_professional (bool): Whether the user is a professional user.

        Returns:
            User: The newly created user object if successful, otherwise None
            (on any error the message is printed and the session is rolled
            back).
        """
        try:
            # One constructor serves both roles; the role is carried solely by
            # the is_professional flag, so the object is just "a user".
            user = User(username=username,
                        email=email,
                        password=hashed_password,
                        created_at=datetime.now(timezone.utc),
                        is_professional=is_professional)
            db.session.add(user)
            db.session.commit()
            return user
        except Exception as e:
            print(f"Error adding user: {e}")
            db.session.rollback()
            return None

    def get_non_professional_users(self):
        """
        Retrieve all non-professional users from the database.

        Returns:
            list: A list of User objects with ``is_professional == False``;
            an empty list if the query raised.
        """
        try:
            return db.session.query(User).filter_by(is_professional=False).all()
        except Exception as e:
            print(f"Error retrieving non-professional users: {e}")
            db.session.rollback()
            return []
class ConversationRepository:
    """Manages all data access logic for the Conversation model."""

    def get_current_conversation_by_user_id(self, user_id):
        """
        Retrieve the current (not finished) conversation for a given user.

        If several unfinished conversations exist for the user, the most
        recently created one is returned. Without an explicit ORDER BY the row
        picked by ``first()`` would be database-dependent.

        Args:
            user_id: The ID of the user whose current conversation is to be retrieved.

        Returns:
            Conversation: The newest conversation with ``finished_at`` unset,
            or None if there is none or the query raised.
        """
        try:
            return db.session.query(Conversation)\
                .filter_by(user_id=user_id, finished_at=None)\
                .order_by(Conversation.created_at.desc())\
                .first()
        except Exception as e:
            print(f"Error retrieving current conversation: {e}")
            db.session.rollback()
            return None

    def get_last_conversation_by_user_id(self, user_id):
        """
        Retrieve the last conversation (finished or not) for a given user.

        Args:
            user_id: The ID of the user whose last conversation is to be retrieved.

        Returns:
            Conversation: The conversation with the latest ``created_at``, or
            None if the user has none or the query raised.
        """
        try:
            return db.session.query(Conversation)\
                .filter_by(user_id=user_id)\
                .order_by(Conversation.created_at.desc())\
                .first()
        except Exception as e:
            print(f"Error retrieving last conversation: {e}")
            db.session.rollback()
            return None

    def create_conversation(self, user_id):
        """
        Create and commit a new conversation for a given user.

        Args:
            user_id: The ID of the user for whom the conversation is to be created.

        Returns:
            Conversation: The newly created conversation object if successful,
            otherwise None (the error is printed and the session rolled back).
        """
        try:
            new_conversation = Conversation(user_id=user_id,
                                            created_at=datetime.now(timezone.utc))
            db.session.add(new_conversation)
            db.session.commit()
            return new_conversation
        except Exception as e:
            print(f"Error creating conversation: {e}")
            db.session.rollback()
            return None

    def get_conversation(self, conv_id):
        """
        Retrieve a conversation by its ID.

        Args:
            conv_id: The ID of the conversation to retrieve.

        Returns:
            Conversation: The conversation object if found, otherwise None.
        """
        try:
            return db.session.query(Conversation).filter_by(id=conv_id).first()
        except Exception as e:
            print(f"Error retrieving conversation: {e}")
            db.session.rollback()
            return None

    def end_conversation(self, conv_id):
        """
        End the given conversation by setting its finished_at timestamp (UTC).

        Does nothing if no conversation with ``conv_id`` exists.

        Args:
            conv_id: The ID of the conversation to end.
        """
        try:
            conversation = self.get_conversation(conv_id)
            if conversation:
                conversation.finished_at = datetime.now(timezone.utc)
                db.session.commit()
        except Exception as e:
            print(f"Error ending conversation: {e}")
            db.session.rollback()

    def get_all_conversations_by_user_id(self, user_id):
        """
        Retrieve all conversations for a given user, newest first.

        Args:
            user_id: The ID of the user whose conversations are to be retrieved.

        Returns:
            list: Conversation objects ordered by ``created_at`` descending;
            an empty list if the query raised.
        """
        try:
            return db.session.query(Conversation)\
                .filter_by(user_id=user_id)\
                .order_by(Conversation.created_at.desc())\
                .all()
        except Exception as e:
            print(f"Error retrieving conversations: {e}")
            db.session.rollback()
            return []
class MessageRepository:
    """Manages all data access logic for the Message model."""

    def save_chatbot_message(self, conv_id, msg, intent=None, agent_type="Conversation"):
        """
        Save a chatbot message (``is_user=False``) to a conversation and commit.

        Args:
            conv_id: The ID of the conversation to which the message belongs.
            msg: The text of the chatbot message.
            intent: The intent of the message, if applicable (generally None
                for chatbot messages).
            agent_type: The type of agent sending the message
                (default "Conversation").

        Errors are printed and the session is rolled back; nothing is raised.
        """
        try:
            message = Message(
                conversation_id=conv_id,
                text=msg,
                timestamp=datetime.now(timezone.utc),
                is_user=False,
                intent=intent,
                agent_type=agent_type
            )
            db.session.add(message)
            db.session.commit()
        except Exception as e:
            print(f"Error saving chatbot message: {e}")
            db.session.rollback()

    def save_user_message(self, conv_id, msg, intent=None):
        """
        Save a user message (``is_user=True``) to a conversation and commit.

        Args:
            conv_id: The ID of the conversation to which the message belongs.
            msg: The text of the user message.
            intent: The intent of the message, if applicable.

        Errors are printed and the session is rolled back; nothing is raised.
        """
        try:
            # NOTE(review): agent_type is not passed here, so it takes whatever
            # the Message model defaults to — confirm that is intended.
            message = Message(
                conversation_id=conv_id,
                text=msg,
                timestamp=datetime.now(timezone.utc),
                is_user=True,
                intent=intent,
            )
            db.session.add(message)
            db.session.commit()
        except Exception as e:
            print(f"Error saving user message: {e}")
            db.session.rollback()

    def get_limited_messages(self, conv_id, limit):
        """
        Retrieve the most recent ``limit`` messages from a conversation.

        Note the order: results are NEWEST FIRST (timestamp descending).
        Callers that need chronological order must reverse the list.

        Args:
            conv_id: The ID of the conversation from which to retrieve messages.
            limit: The maximum number of messages to retrieve.

        Returns:
            list: Up to ``limit`` Message objects, newest first; an empty list
            if the query raised.
        """
        try:
            return db.session.query(Message)\
                .filter_by(conversation_id=conv_id)\
                .order_by(Message.timestamp.desc())\
                .limit(limit)\
                .all()
        except Exception as e:
            print(f"Error retrieving messages: {e}")
            db.session.rollback()
            return []

    def get_messages_by_conversation_id(self, conv_id):
        """
        Retrieve all messages from a conversation in chronological order.

        Args:
            conv_id: The ID of the conversation from which to retrieve messages.

        Returns:
            list: Message objects ordered by timestamp ascending (oldest
            first); an empty list if the query raised.
        """
        try:
            return db.session.query(Message)\
                .filter_by(conversation_id=conv_id)\
                .order_by(Message.timestamp)\
                .all()
        except Exception as e:
            print(f"Error retrieving messages: {e}")
            db.session.rollback()
            return []
class AssessmentRepository:
    """Manages all data access logic for the Assessment model."""

    def get_current_assessment(self, user_id):
        """
        Retrieve the current (not finished) assessment for a given user.

        If several assessments without an ``end_time`` exist for the user, the
        most recently started one is returned. Without an explicit ORDER BY the
        row picked by ``first()`` would be database-dependent.

        Args:
            user_id: The ID of the user whose current assessment is to be retrieved.

        Returns:
            Assessment: The newest unfinished assessment, or None if there is
            none or the query raised.
        """
        try:
            return db.session.query(Assessment)\
                .filter_by(user_id=user_id, end_time=None)\
                .order_by(Assessment.start_time.desc())\
                .first()
        except Exception as e:
            print(f"Error retrieving current assessment: {e}")
            db.session.rollback()
            return None

    def is_assessment_active(self, user_id):
        """
        Check if there is an active (unfinished) assessment for the given user.

        Args:
            user_id: The ID of the user to check.

        Returns:
            bool: True if an active assessment exists, False otherwise
            (including when the lookup failed).
        """
        return self.get_current_assessment(user_id) is not None

    def get_assessment(self, assessment_id):
        """
        Retrieve an assessment by its ID.

        Args:
            assessment_id: The ID of the assessment to retrieve.

        Returns:
            Assessment: The assessment object if found, otherwise None.
        """
        try:
            return db.session.query(Assessment).filter_by(id=assessment_id).first()
        except Exception as e:
            print(f"Error retrieving assessment: {e}")
            db.session.rollback()
            return None

    def create_assessment(self, message_trigger, user_id, assessment_type, state):
        """
        Create and commit a new assessment for a given user.

        Args:
            message_trigger: Value stored in ``Assessment.message_trigger``
                (presumably the message that triggered the assessment —
                confirm with callers).
            user_id: The ID of the user for whom the assessment is to be created.
            assessment_type: The type of the assessment (e.g., "PHQ9", "GAD7").
            state: The initial state of the assessment (e.g., "AskingQuestion state").

        Returns:
            Assessment: The newly created assessment object if successful,
            otherwise None (the error is printed and the session rolled back).
        """
        try:
            new_assessment = Assessment(
                user_id=user_id,
                assessment_type=assessment_type,
                message_trigger=message_trigger,
                start_time=datetime.now(timezone.utc),
                current_state=state
            )
            db.session.add(new_assessment)
            db.session.commit()
            return new_assessment
        except Exception as e:
            print(f"Error creating assessment: {e}")
            db.session.rollback()
            return None

    def update_assessment(self, assessment_id, **kwargs):
        """
        Update an existing assessment with new values and commit.

        Does nothing if no assessment with ``assessment_id`` exists.

        Args:
            assessment_id: The ID of the assessment to update.
            **kwargs: Attribute names of the assessment and their new values.
        """
        try:
            assessment = self.get_assessment(assessment_id)
            if assessment:
                # NOTE(review): keys are not validated; a name that is not a
                # mapped column is set as a plain Python attribute and is
                # presumably never persisted — confirm callers pass only
                # column names.
                for key, value in kwargs.items():
                    setattr(assessment, key, value)
                db.session.commit()
        except Exception as e:
            print(f"Error updating assessment: {e}")
            db.session.rollback()

    def calculate_total_score(self, assessment_id):
        """
        Calculate the total score of an assessment.

        The total is the SQL SUM of ``categorized_value`` over the assessment's
        questions.

        Args:
            assessment_id: The ID of the assessment for which to calculate the total score.

        Returns:
            int: The total score, or 0 if no questions are found (SUM yields
            NULL) or an error occurred.
        """
        try:
            total_score = db.session.query(func.sum(AssessmentQuestion.categorized_value))\
                .filter(AssessmentQuestion.assessment_id == assessment_id)\
                .scalar()
            return total_score if total_score is not None else 0
        except Exception as e:
            print(f"Error calculating total score: {e}")
            db.session.rollback()
            return 0

    def get_assessments_by_user_id(self, user_id):
        """
        Retrieve all assessments for a given user, most recently started first.

        Args:
            user_id: The ID of the user whose assessments are to be retrieved.

        Returns:
            list: Assessment objects ordered by ``start_time`` descending;
            an empty list if the query raised.
        """
        try:
            return db.session.query(Assessment)\
                .filter_by(user_id=user_id)\
                .order_by(Assessment.start_time.desc())\
                .all()
        except Exception as e:
            print(f"Error retrieving assessments: {e}")
            db.session.rollback()
            return []

    def end_assessment(self, user_id):
        """
        End the current assessment for a given user by setting its end_time (UTC).

        Only the assessment returned by ``get_current_assessment`` is ended.
        Does nothing if the user has no active assessment.

        Args:
            user_id: The ID of the user whose assessment is to be ended.
        """
        try:
            assessment = self.get_current_assessment(user_id)
            if assessment:
                assessment.end_time = datetime.now(timezone.utc)
                db.session.commit()
        except Exception as e:
            print(f"Error ending assessment: {e}")
            db.session.rollback()
class AssessmentQuestionRepository:
    """Data-access helpers for the AssessmentQuestion model."""

    def get_question_by_assessment_id(self, assessment_id):
        """
        Fetch every question recorded for one assessment.

        Args:
            assessment_id: ID of the assessment whose questions are wanted.

        Returns:
            list: AssessmentQuestion objects sorted by ascending
            ``question_number``; an empty list when the query fails.
        """
        try:
            base_query = db.session.query(AssessmentQuestion).filter_by(
                assessment_id=assessment_id
            )
            ordered_query = base_query.order_by(AssessmentQuestion.question_number)
            return ordered_query.all()
        except Exception as e:
            print(f"Error retrieving questions: {e}")
            db.session.rollback()
            return []

    def save_assessment_question(self, assessment_id, question_number, question_text, user_response, categorized_value):
        """
        Persist one answered assessment question and commit it.

        Args:
            assessment_id: ID of the assessment the question belongs to.
            question_number: Position of the question within the assessment.
            question_text: Text of the question as asked.
            user_response: The user's free-text answer.
            categorized_value: Value the answer was mapped to (scale depends on
                the questionnaire).

        On any failure the error is printed and the session rolled back;
        nothing is raised to the caller.
        """
        try:
            fields = dict(
                assessment_id=assessment_id,
                question_number=question_number,
                question_text=question_text,
                user_response=user_response,
                categorized_value=categorized_value,
                timestamp=datetime.now(timezone.utc),
            )
            record = AssessmentQuestion(**fields)
            session = db.session
            session.add(record)
            session.commit()
        except Exception as e:
            print(f"Error saving assessment question: {e}")
            db.session.rollback()