Mercurial
diff love/poppy/apis/router.py @ 38:cf9caa4abc3e
[Love] FE and BE. Can chat and render images. Also created MCP for powerpoint generations.
| author | MrJuneJune <me@mrjunejune.com> |
|---|---|
| date | Mon, 01 Dec 2025 20:35:56 -0800 |
| parents | |
| children | |
line wrap: on
line diff
"""Chat REST and WebSocket routes (love/poppy/apis/router.py).

Exposes chat listing/creation, message retrieval, and a WebSocket endpoint
that stores each exchange in the database and streams replies either from
Grok (xAI SDK) or from a local mock echo.
"""

import asyncio
import json
import os
import sys
from typing import Optional
from uuid import UUID

from fastapi import (
    APIRouter,
    Depends,
    HTTPException,
    Query,
    WebSocket,
    WebSocketDisconnect,
    status,
)
from sqlalchemy.orm import selectinload
from sqlmodel import desc, exists, select
from xai_sdk import Client
from xai_sdk.chat import assistant, system, user
from xai_sdk.tools import code_execution, mcp, web_search

from db.models import Chat, Message, MessageAsset
from utils.database import get_session
from utils.logger import logger
# from utils.redis import append_message, get_all_messages_from_redis

from .schemas import ChatOut, ChatWithMessages, Chats, MessageOut


GROK_API_KEY = os.getenv("XAI_API_KEY", "NO-KEY")
CHAT_MODEL = "grok-4-fast"
IMAGE_MODEL = "grok-2-image"

# Title this module compares against to decide a chat has not been titled yet.
DEFAULT_CHAT_TITLE = "New Chat"
# Placeholder title used when replies come from the mock echo instead of Grok.
MOCK_CHAT_TITLE = "Automatic title"
MAX_TITLE_LENGTH = 100

TOOLS = [
    web_search(),
    code_execution(),
    # mcp(
    #     server_url="https://mcp.babocoder.com/mcp",
    #     server_label="powerpoint-generator",
    #     server_description="This will create powerpoitn slides and files for you."
    # )
]

# xAI SDK Client
client = Client(
    api_key=GROK_API_KEY,
    timeout=3600,  # Take this out when you are deploying to your local server.
)

router = APIRouter(tags=["chat"])

# TODO: Make this into something more useful
SYSTEM_PROMPT = """
You are a dog lover and everytime someone mentioned about their dog. You should be the most excited person alive.
"""

TITLE_GENERATION_PROMPT = """Based on the following user message, generate a short, descriptive title (max 6 words) for this chat conversation.
Only return the title text, nothing else.

User message: {user_message}"""

# Substrings that route a user message to image generation instead of chat.
# Can use clarifying agent model for this, but that would be overkill.
IMAGE_TRIGGERS = (
    "/img ",
    "generate image ",
    "create image ",
    "draw ",
    "make an image of ",
    "show me a picture of ",
    "image: ",
    "img: ",
    "🖼️",
    "picture of ",
    "generate a photo of ",
)


async def generate_chat_title(user_message: str) -> str:
    """Ask Grok for a short title describing *user_message*.

    The title is truncated to ``MAX_TITLE_LENGTH`` characters (ending in
    "..."). Returns ``DEFAULT_CHAT_TITLE`` when the model returns an empty
    title or when the request fails for any reason.
    """
    try:
        title_client = Client(api_key=GROK_API_KEY, timeout=30)
        title_chat = title_client.chat.create(model=CHAT_MODEL)
        title_chat.append(
            system(TITLE_GENERATION_PROMPT.format(user_message=user_message))
        )

        response = title_chat.sample()
        title = response.content.strip()
        logger.info(f"Title: {title}")

        # Ensure title isn't too long
        if len(title) > MAX_TITLE_LENGTH:
            title = title[: MAX_TITLE_LENGTH - 3] + "..."

        return title or DEFAULT_CHAT_TITLE
    except Exception as e:
        # Title generation is best-effort; the chat keeps working without it.
        logger.error(f"Failed to generate title: {e}")
        return DEFAULT_CHAT_TITLE


async def generate_image(user_message: str) -> str:
    """Generate an image from *user_message* and return its URL.

    On any failure the error is logged and the literal string
    ``"failed image url"`` is returned instead of raising.
    """
    try:
        image_client = Client(api_key=GROK_API_KEY, timeout=300)
        response = image_client.image.sample(
            model=IMAGE_MODEL,
            prompt=user_message,
            image_format="url",
        )
        logger.info(f"Image URL: {response.url}")
        return response.url
    except Exception as e:
        logger.error(f"Failed to generate image: {e}")
        return "failed image url"


async def should_update_title(chat: Chat, session) -> bool:
    """Return True when *chat* has exactly two messages and the default title."""
    stmt = select(Message).where(Message.chat_id == chat.id)
    result = await session.execute(stmt)
    message_count = len(result.scalars().all())
    logger.info(f"message_count: {message_count}")

    # Update title if we have exactly 2 messages and title is still default
    return message_count == 2 and chat.title == DEFAULT_CHAT_TITLE


def _parse_chat_id(raw_chat_id: str) -> Optional[UUID]:
    """Return *raw_chat_id* as a UUID, or None when it is not a valid UUID."""
    try:
        return UUID(raw_chat_id)
    except ValueError:
        return None


async def _load_history(session, chat_id: UUID) -> list:
    """Return the chat's stored messages, oldest first, as role/content dicts."""
    stmt = (
        select(Message)
        .where(Message.chat_id == chat_id)
        .order_by(Message.created_at)
    )
    result = await session.execute(stmt)
    return [{"role": m.role, "content": m.content} for m in result.scalars().all()]


def _is_image_request(user_content: str) -> bool:
    """Return True when *user_content* contains any image trigger phrase."""
    return any(trigger in user_content for trigger in IMAGE_TRIGGERS)


async def _send_append(websocket: WebSocket, chat_id_str: str, token: str) -> None:
    """Send one reply fragment to the client as an "append" action."""
    await websocket.send_json(
        {
            "chatId": chat_id_str,
            "content": token,
            "action": "append",
        }
    )


async def _stream_grok_reply(websocket: WebSocket, chat, chat_id_str: str) -> str:
    """Stream Grok's reply for *chat* to the client and return the full text."""
    full_response = ""
    response = None
    is_thinking = True

    for response, chunk in chat.stream():
        for tool_call in chunk.tool_calls:
            logger.info(
                f"\nCalling tool: {tool_call.function.name} "
                f"with arguments: {tool_call.function.arguments}"
            )
        token = chunk.content
        if not token:
            continue
        if is_thinking:
            # First content chunk: the model has started answering.
            logger.info("\n\nAnalysis Results:")
            is_thinking = False
        # `response` is read as the reply accumulated so far.
        full_response = response.content
        await _send_append(websocket, chat_id_str, token)
        await asyncio.sleep(0.015)

    # `response` stays None when the stream produced nothing at all.
    if response is not None:
        logger.info("\n\nCitations:")
        logger.info(response.citations)
        logger.info("\n\nUsage:")
        logger.info(response.usage)
        logger.info(response.server_side_tool_usage)
        logger.info("\n\nServer Side Tool Calls:")
        logger.info(response.tool_calls)

    return full_response


async def _stream_mock_reply(
    websocket: WebSocket, chat_id_str: str, user_content: str
) -> str:
    """Echo *user_content* back word by word and return the full text sent."""
    full_response = ""
    for word in f"[Mock] Echo: {user_content}".split():
        token = word + " "
        full_response += token
        await _send_append(websocket, chat_id_str, token)
        await asyncio.sleep(0.2)
    return full_response


async def _set_title(
    websocket: WebSocket, session, db_chat: Chat, chat_id_str: str, title: str
) -> None:
    """Persist *title* on *db_chat* and notify the client."""
    db_chat.title = title
    session.add(db_chat)
    await session.commit()
    await websocket.send_json(
        {
            "chatId": chat_id_str,
            "title": title,
            "action": "title_updated",
        }
    )


# --- REST ---
# NOTE(review): the decorator verbs in this file (GET/POST/GET/WEBSOCKET) were
# reconstructed from an obfuscated listing; confirm them against the repository.
@router.get("/chats", response_model=Chats)
async def get_chats(
    session=Depends(get_session),
    limit: int = Query(50, ge=1, le=100),
    cursor: Optional[str] = Query(None, description="Chat ID to paginate from"),
):
    """List chats that have at least one message, newest first.

    When *cursor* is the id of an existing chat, only chats created before
    that chat are returned. An invalid or unknown cursor is ignored.
    """
    has_messages = exists(Message).where(Message.chat_id == Chat.id)
    stmt = select(Chat).where(has_messages).order_by(desc(Chat.created_at))

    # If cursor provided, filter to chats older than cursor chat
    cursor_chat_id = _parse_chat_id(cursor) if cursor else None
    if cursor_chat_id is not None:
        cursor_chat = await session.get(Chat, cursor_chat_id)
        if cursor_chat:
            stmt = stmt.where(Chat.created_at < cursor_chat.created_at)

    result = await session.execute(stmt.limit(limit))
    return Chats(chats=result.scalars().all())


@router.post("/chats", response_model=ChatOut, status_code=status.HTTP_201_CREATED)
async def create_chat(session=Depends(get_session)):
    """Create an empty chat and return it."""
    chat = Chat()
    session.add(chat)
    await session.commit()
    await session.refresh(chat)
    return chat


@router.get("/chats/{chat_id_str}/messages", response_model=ChatWithMessages)
async def get_chat_messages(chat_id_str: str, session=Depends(get_session)):
    """Return a chat together with its messages (and their assets), oldest first.

    Raises:
        HTTPException: 404 when the id is malformed or no such chat exists.
    """
    chat_id = _parse_chat_id(chat_id_str)
    chat = await session.get(Chat, chat_id) if chat_id else None
    if not chat:
        raise HTTPException(404, "Chat not found")

    # TODO: serve from the Redis cache (utils.redis) once it also stores
    # message ids and image URLs.
    stmt = (
        select(Message)
        .options(selectinload(Message.assets))
        .where(Message.chat_id == chat_id)
        .order_by(Message.created_at)
    )
    result = await session.execute(stmt)
    messages = result.scalars().all()
    return ChatWithMessages(
        id=chat.id,
        title=chat.title,
        created_at=chat.created_at,
        messages=[MessageOut.from_orm_with_assets(m) for m in messages],
    )


@router.websocket("/chats/{chat_id_str}/ws")
async def chat_stream(
    websocket: WebSocket,
    chat_id_str: str,
    session=Depends(get_session),
    use_grok: bool = Query(True),  # TODO: use url for this, but for now this is fine.
):
    """Serve one chat over a WebSocket.

    Each incoming frame must be JSON with ``role == "user"`` and a
    ``content`` field. The user message is stored, a reply is streamed back
    as "append" actions (from Grok, or a mock echo when *use_grok* is false),
    the reply is stored, and a final "done" action is sent. The connection is
    closed with code 1008 when the chat id is invalid, the chat does not
    exist, or a frame has a role other than "user".
    """
    await websocket.accept()

    # Validate after accept() so a bad id gets a clean policy-violation close.
    chat_id = _parse_chat_id(chat_id_str)
    db_chat = await session.get(Chat, chat_id) if chat_id else None
    if not db_chat:
        await websocket.close(code=1008)
        return

    # Read the title once while the row is freshly loaded and track it
    # locally, so later checks do not depend on re-reading the ORM object.
    current_title = db_chat.title

    chat = None
    history = []
    if use_grok:
        # TODO: read from the Redis cache first once utils.redis is wired in.
        try:
            history = await _load_history(session, chat_id)
        except Exception as e:
            # Best effort: a chat without replayed history is still usable.
            logger.error(f"History load failed: {e}")
            history = []

        chat = client.chat.create(model=CHAT_MODEL, tools=TOOLS)
        chat.append(system(SYSTEM_PROMPT))
        for msg in history:
            if not msg["content"]:
                continue  # nothing to replay for an empty stored message
            if msg["role"] == "user":
                chat.append(user(msg["content"]))
            elif msg["role"] == "assistant":
                chat.append(assistant(msg["content"]))

    logger.info(f"This should only be called once: {history}")

    # Listen to websocket
    try:
        while True:
            try:
                data = await websocket.receive_text()
                # TODO: Create serializer for this.
                user_msg = json.loads(data)
                if user_msg.get("role") != "user":
                    await websocket.close(code=1008)
                    return
                user_content = user_msg["content"]

                session.add(
                    Message(chat_id=chat_id, role="user", content=user_content)
                )
                await session.commit()

                # Defined for both modes; the asset step below reads them.
                is_image_call = False
                image_url = ""

                if use_grok:
                    is_image_call = _is_image_request(user_content)
                    if is_image_call:
                        image_url = await generate_image(user_content)
                        logger.info(
                            "Image got generated and we are sending assistant prompt?"
                        )
                        chat.append(
                            assistant(f"Your image is here. Please view {image_url}!")
                        )
                        await websocket.send_json(
                            {
                                "chatId": chat_id_str,
                                "url": f"{image_url}",
                                "action": "image",
                            }
                        )
                    else:
                        logger.info(f"User content sent: {user_content}")
                        chat.append(user(user_content))

                    full_response = await _stream_grok_reply(
                        websocket, chat, chat_id_str
                    )
                    if full_response:
                        chat.append(assistant(full_response))
                else:
                    full_response = await _stream_mock_reply(
                        websocket, chat_id_str, user_content
                    )

                assistant_msg = Message(
                    chat_id=chat_id, role="assistant", content=full_response
                )
                session.add(assistant_msg)
                await session.commit()
                # Reload so assistant_msg.id is available for the asset link.
                await session.refresh(assistant_msg)

                # Now create assets linked to the message
                if is_image_call and image_url:
                    session.add(
                        MessageAsset(
                            message_id=assistant_msg.id,
                            asset_type="image",
                            url=image_url,
                        )
                    )
                    await session.commit()

                # Title the chat once, after the first exchange. Skipping the
                # check for already-titled chats keeps a generated title from
                # being overwritten on later turns.
                if current_title == DEFAULT_CHAT_TITLE:
                    if not use_grok:
                        current_title = MOCK_CHAT_TITLE
                        await _set_title(
                            websocket, session, db_chat, chat_id_str, current_title
                        )
                    elif await should_update_title(db_chat, session):
                        current_title = await generate_chat_title(user_content)
                        await _set_title(
                            websocket, session, db_chat, chat_id_str, current_title
                        )
                        logger.info(f"Updated chat title to: {current_title}")

                await websocket.send_json(
                    {"chatId": chat_id_str, "content": "", "action": "done"}
                )
            except WebSocketDisconnect:
                break
            except json.JSONDecodeError:
                await websocket.send_json({"error": "Invalid JSON"})
            except Exception as e:
                tb = e.__traceback__
                line_number = tb.tb_lineno if tb else "?"
                logger.error(
                    f"An error occurred on line {line_number}: {type(e).__name__} - {e}"
                )
                await websocket.send_json({"error": str(e)})
    except Exception as e:
        # Reached when reporting an error to the client itself fails
        # (for example the socket is already closed).
        logger.error(f"Error: {e}")