diff love/poppy/apis/router.py @ 38:cf9caa4abc3e

[Love] FE and BE. Can chat and render images. Also created MCP for powerpoint generations.
author MrJuneJune <me@mrjunejune.com>
date Mon, 01 Dec 2025 20:35:56 -0800
parents
children
line wrap: on
line diff
--- /dev/null	Thu Jan 01 00:00:00 1970 +0000
+++ b/love/poppy/apis/router.py	Mon Dec 01 20:35:56 2025 -0800
@@ -0,0 +1,423 @@
+from typing import Optional
+from uuid import UUID
+from fastapi import (
+    APIRouter,
+    Depends,
+    HTTPException,
+    status,
+    WebSocket,
+    WebSocketDisconnect,
+    Query,
+)
+from sqlalchemy.orm import selectinload
+from sqlmodel import select, desc, exists
+import json
+import asyncio
+import os
+
+from xai_sdk import Client
+from xai_sdk.chat import user, system, assistant
+from xai_sdk.tools import mcp, web_search, code_execution
+
+from .schemas import MessageOut, ChatOut, ChatWithMessages, Chats
+from db.models import Chat, Message, MessageAsset
+from utils.database import get_session
+# from utils.redis import append_message, get_all_messages_from_redis
+from utils.logger import logger
+
+import sys
+
+
+GROK_API_KEY = os.getenv("XAI_API_KEY", "NO-KEY")
+TOOLS = [
+    web_search(),
+    code_execution(),
+    # mcp(
+    #   server_url="https://mcp.babocoder.com/mcp",
+    #   server_label="powerpoint-generator",
+    #   server_description="This will create powerpoitn slides and files for you."
+    # )
+]
+# xAI SDK Client
+client = Client(
+    api_key=GROK_API_KEY,
+    timeout=3600, # Take this out when you are deploying to your local server.
+)
+
+router = APIRouter(tags=["chat"])
+
+# TODO: Make this into something more useful
+SYSTEM_PROMPT = """
+You are a dog lover and everytime someone mentioned about their dog. You should be the most excited person alive.
+"""
+
+TITLE_GENERATION_PROMPT = """Based on the following user message, generate a short, descriptive title (max 6 words) for this chat conversation. 
+Only return the title text, nothing else.
+
+User message: {user_message}"""
+
+
+async def generate_chat_title(user_message: str) -> str:
+    try:
+        title_client = Client(api_key=GROK_API_KEY, timeout=30)
+        title_chat = title_client.chat.create(model="grok-4-fast")
+        title_chat.append(system(TITLE_GENERATION_PROMPT.format(user_message=user_message)))
+        
+        response = title_chat.sample()
+        title = response.content.strip()
+
+        logger.info("Title: \n")
+        logger.info(title)
+        
+        # Ensure title isn't too long
+        if len(title) > 100:
+            title = title[:97] + "..."
+        
+        return title if title else "New Chat"
+    except Exception as e:
+        logger.error(f"Failed to generate title: {e}")
+        return "New Chat"
+
+async def generate_image(user_message: str) -> str:
+    try:
+        image_client = Client(api_key=GROK_API_KEY, timeout=300)
+        response = image_client.image.sample(
+           model="grok-2-image",
+           prompt=user_message,
+           image_format="url"
+       )
+        logger.info("Image URL: \n")
+        logger.info(response.url)
+        
+        return response.url
+    except Exception as e:
+        logger.error(f"Failed to generate title: {e}")
+        return "failed image url"
+
+
+async def should_update_title(chat: Chat, session) -> bool:
+    stmt = select(Message).where(Message.chat_id == chat.id)
+    result = await session.execute(stmt)
+    message_count = len(result.scalars().all())
+    logger.info(f"message_count: {message_count}")
+
+    # Update title if we have exactly 2 messages and title is still default
+    return message_count == 2 and chat.title == "New Chat"
+
+
+# --- REST ---
[email protected]("/chats", response_model=Chats)
+async def get_chats(
+    session=Depends(get_session),
+    limit: int = Query(50, ge=1, le=100),
+    cursor: Optional[str] = Query(None, description="Chat ID to paginate from")
+):
+    has_messages = exists(Message).where(Message.chat_id == Chat.id)
+
+    stmt = select(Chat).where(has_messages).order_by(desc(Chat.created_at))
+    
+    # If cursor provided, filter to chats older than cursor chat
+    if cursor:
+        try:
+            cursor_chat_id = UUID(cursor)
+            cursor_chat = await session.get(Chat, cursor_chat_id)
+            if cursor_chat:
+                stmt = stmt.where(Chat.created_at < cursor_chat.created_at)
+        except (ValueError, TypeError):
+            pass  # Invalid cursor, ignore it
+    
+    stmt = stmt.limit(limit)
+
+    result = await session.execute(stmt)
+    recent_chats = result.scalars().all()
+
+    return Chats(chats=recent_chats)
+
[email protected]("/chats", response_model=ChatOut, status_code=status.HTTP_201_CREATED)
+async def create_chat(session=Depends(get_session)):
+    chat = Chat()
+    session.add(chat)
+    await session.commit()
+    await session.refresh(chat)
+    return chat
+
+
[email protected]("/chats/{chat_id_str}/messages", response_model=ChatWithMessages)
+async def get_chat_messages(chat_id_str: str, session=Depends(get_session)):
+    chat_id = UUID(chat_id_str)
+    chat = await session.get(Chat, chat_id)
+    if not chat:
+        raise HTTPException(404, "Chat not found")
+
+    # try:
+    #     all_msgs = await get_all_messages_from_redis()
+    #     chat_msgs = [m for m in all_msgs if m["chatId"] == chat_id_str]
+    #     if chat_msgs:
+    #         messages = [
+    #             MessageOut(
+    #                 id=UUID(m.get("id", "00000000-0000-0000-0000-000000000000")),  # or store id in Redis!
+    #                 role=m["role"],
+    #                 content=m["content"],
+    #                 created_at=None,
+    #                 image_url=m.get("image_url"),    # ← make sure Redis also stores this
+    #             )
+    #             for m in chat_msgs
+    #         ]
+    #         return ChatWithMessages(
+    #             id=chat.id,
+    #             title=chat.title,
+    #             created_at=chat.created_at,
+    #             messages=messages,
+    #         )
+    # except Exception as e:
+    #     pass
+
+    stmt = (
+        select(Message)
+        .options(selectinload(Message.assets))
+        .where(Message.chat_id == chat_id)
+        .order_by(Message.created_at)
+    )
+    result = await session.execute(stmt)
+    messages = result.scalars().all()
+    return ChatWithMessages(
+        id=chat.id,
+        title=chat.title,
+        created_at=chat.created_at,
+        messages=[MessageOut.from_orm_with_assets(m) for m in messages],
+    )
+
+
[email protected]("/chats/{chat_id_str}/ws")
+async def chat_stream(
+    websocket: WebSocket,
+    chat_id_str: str,
+    session=Depends(get_session),
+    use_grok: bool = Query(True),  # TODO: use url for this, but for now this is fine. 
+):
+    chat_id = UUID(chat_id_str)
+
+    await websocket.accept()
+
+    chat = None
+    history = []
+
+    # Load chat from DB
+    db_chat = await session.get(Chat, chat_id)
+    if not db_chat:
+        await websocket.close(code=1008)
+        return
+
+    # Load Message
+    if use_grok:
+        try:
+            all_msgs = await get_all_messages_from_redis()
+            chat_msgs = [m for m in all_msgs if m["chatId"] == chat_id_str]
+            if chat_msgs:
+                history = [
+                    {"role": m["role"], "content": m["content"]} for m in chat_msgs
+                ]
+            else:
+                stmt = (
+                    select(Message)
+                    .where(Message.chat_id == chat_id)
+                    .order_by(Message.created_at)
+                )
+                result = await session.execute(stmt)
+                db_msgs = result.scalars().all()
+                history = [{"role": m.role, "content": m.content} for m in db_msgs]
+        except Exception as e:
+            logger.error("History load failed:", e)
+            history = []
+
+        chat = client.chat.create(
+            model="grok-4-fast",
+            tools=TOOLS
+        )
+        chat.append(system(SYSTEM_PROMPT))
+
+        # order does not matter so should be fine...
+        for msg in history:
+            if msg["role"] == "user":
+                chat.append(user(msg["content"]))
+            elif msg["role"] == "assistant":
+                chat.append(assistant(msg["content"]))
+
+    logger.info(f"This should only be called once: {history}")
+    # Listen to websocket
+    try:
+        while True:
+            full_response = ""
+            assistant_msg = None
+
+            try:
+                data = await websocket.receive_text()
+                # TODO: Create seralizer for this.
+                user_msg = json.loads(data)
+                if user_msg.get("role") != "user":
+                    await websocket.close(code=1008)
+                    return
+                user_content = user_msg["content"]
+
+                msg = Message(chat_id=chat_id, role="user", content=user_content)
+                session.add(msg)
+                await session.commit()
+                await session.refresh(msg)
+                # await append_message(chat_id_str, "user", user_content)
+
+                # Actual API calls
+                if use_grok:
+
+                    # Can use clarifying agent model for this, but I feel like that would be over kill..
+                    image_triggers = [
+                        "/img ",
+                        "generate image ",
+                        "create image ",
+                        "draw ",
+                        "make an image of ",
+                        "show me a picture of ",
+                        "image: ",
+                        "img: ",
+                        "🖼️",
+                        "picture of ",
+                        "generate a photo of ",
+                    ]
+                    is_image_call = False
+                    image_url = ""
+                    for trigger in image_triggers:
+                        if trigger in user_content:
+                            is_image_call = True
+                            break
+
+                    if is_image_call:
+                        image_url = await generate_image(user_content)
+                        logger.info(f"Image got generated and we are sending assistant prompt?")
+                        chat.append(assistant(f"Your image is here. Please view {image_url}!"))
+                        await websocket.send_json(
+                            {
+                                "chatId": chat_id_str,
+                                "url": f"{image_url}",
+                                "action": "image",
+                            }
+                        )
+                    else:
+                        logger.info(f"User content sent: {user_content}")
+                        chat.append(user(user_content))
+
+                # AI response....
+                # Maybe stream the token into redis?, but probably over kill.
+                is_thinking = True
+                if use_grok:
+                    for response, chunk in chat.stream():
+                        for tool_call in chunk.tool_calls:
+                            logger.info(f"\nCalling tool: {tool_call.function.name} with arguments: {tool_call.function.arguments}")
+                        # if response.usage.reasoning_tokens and is_thinking:
+                            # logger.info(f"\rThinking... ({response.usage.reasoning_tokens} tokens)")
+                        if chunk.content and is_thinking:
+                            logger.info("\n\nAnalysis Results:")
+                            is_thinking = False
+                        token = chunk.content
+                        if not token:
+                            continue
+                        full_response = response.content
+                        await websocket.send_json(
+                            {
+                                "chatId": chat_id_str,
+                                "content": token,
+                                "action": "append",
+                            }
+                        )
+                        await asyncio.sleep(0.015)
+                    logger.info("\n\nCitations:")
+                    logger.info(response.citations)
+                    logger.info("\n\nUsage:")
+                    logger.info(response.usage)
+                    logger.info(response.server_side_tool_usage)
+                    logger.info("\n\nServer Side Tool Calls:")
+                    logger.info(response.tool_calls)
+                else:
+                    mock_reply = f"[Mock] Echo: {user_content}"
+                    for token in mock_reply.split():
+                        token += " "
+                        full_response += token
+                        await websocket.send_json(
+                            {
+                                "chatId": chat_id_str,
+                                "content": token,
+                                "action": "append",
+                            }
+                        )
+                        await asyncio.sleep(0.2)
+
+                if use_grok and full_response:
+                    chat.append(assistant(full_response))
+
+                # Async update redis
+                # task = asyncio.create_task(
+                #     append_message(chat_id_str, "assistant", full_response)
+                # )
+                # asyncio.shield(task)
+
+                assistant_msg = Message(
+                    chat_id=chat_id, role="assistant", content=full_response
+                )
+                session.add(assistant_msg)
+                await session.commit()
+                await session.refresh(assistant_msg)
+
+                # Now create assets linked to the message
+                if is_image_call and image_url:                
+                    asset = MessageAsset(
+                        message_id=assistant_msg.id,
+                        asset_type="image",
+                        url=image_url,
+                    )
+                    session.add(asset)
+                    await session.commit()
+                    await session.refresh(assistant_msg)
+
+                # Getting titles
+                # Generate title after first exchange
+                if use_grok and await should_update_title(db_chat, session):
+                    new_title = await generate_chat_title(user_content)
+                    db_chat.title = new_title
+                    session.add(db_chat)
+                    await session.commit()
+                    logger.info(f"Updated chat title to: {new_title}")
+                    await websocket.send_json(
+                        {
+                            "chatId": chat_id_str,
+                            "title": new_title,
+                            "action": "title_updated"
+                        }
+                    )
+                else:
+                    db_chat.title = "Automatic title"
+                    session.add(db_chat)
+                    await session.commit()
+                    await websocket.send_json(
+                        {
+                            "chatId": chat_id_str,
+                            "title": "Automatic title",
+                            "action": "title_updated"
+                        }
+                    )
+
+                await websocket.send_json(
+                    {"chatId": chat_id_str, "content": "", "action": "done"}
+                )
+            except WebSocketDisconnect:
+                break
+            except json.JSONDecodeError:
+                await websocket.send_json({"error": "Invalid JSON"})
+            except Exception as e:
+                exc_type, exc_value, exc_traceback = sys.exc_info()
+                line_number = exc_traceback.tb_lineno
+                logger.error(
+                    f"An error occurred on line {line_number}: {type(e).__name__} - {e}"
+                )
+                await websocket.send_json({"error": str(e)})
+    except Exception as e:
+        logger.error("Error: ", e)
+    finally:
+        pass