view love/poppy/apis/router.py @ 71:75de5903355c

Gigantic changes that update the Dowa library to be more aligned with stb-style arrays and hashmaps. Updated Seobeo to cache on the server side instead of using file-level caching. Deleted a bunch of things I don't really use.
author June Park <parkjune1995@gmail.com>
date Sun, 28 Dec 2025 20:34:22 -0800
parents cf9caa4abc3e
children
line wrap: on
line source

from typing import Optional
from uuid import UUID
from fastapi import (
    APIRouter,
    Depends,
    HTTPException,
    status,
    WebSocket,
    WebSocketDisconnect,
    Query,
)
from sqlalchemy.orm import selectinload
from sqlmodel import select, desc, exists
import json
import asyncio
import os

from xai_sdk import Client
from xai_sdk.chat import user, system, assistant
from xai_sdk.tools import mcp, web_search, code_execution

from .schemas import MessageOut, ChatOut, ChatWithMessages, Chats
from db.models import Chat, Message, MessageAsset
from utils.database import get_session
# from utils.redis import append_message, get_all_messages_from_redis
from utils.logger import logger

import sys


# xAI API key, read from the XAI_API_KEY environment variable. Falls back to the
# placeholder "NO-KEY" so the module still imports without it; requests sent
# with the placeholder will presumably be rejected by the API.
GROK_API_KEY = os.getenv("XAI_API_KEY", "NO-KEY")
# Server-side tools made available to the main chat model (passed as `tools=`
# when the websocket handler creates a Grok chat session).
TOOLS = [
    web_search(),
    code_execution(),
    # Optional MCP tool server, currently disabled:
    # mcp(
    #   server_url="https://mcp.babocoder.com/mcp",
    #   server_label="powerpoint-generator",
    #   server_description="This will create PowerPoint slides and files for you."
    # )
]
# Shared synchronous xAI SDK client used by the streaming chat handler.
# Title and image generation build their own short-lived clients with
# shorter timeouts.
client = Client(
    api_key=GROK_API_KEY,
    timeout=3600, # Very long timeout for slow, tool-using replies. Take this out when you are deploying to your local server.
)

# Router for the chat REST endpoints and the chat websocket; every route is
# tagged "chat" in the generated OpenAPI docs.
router = APIRouter(tags=["chat"])

# System prompt prepended to every Grok chat session.
# TODO: Make this into something more useful
SYSTEM_PROMPT = """
You are a dog lover and everytime someone mentioned about their dog. You should be the most excited person alive.
"""

# Prompt template for chat-title generation; `{user_message}` is filled in
# with str.format() by generate_chat_title.
TITLE_GENERATION_PROMPT = """Based on the following user message, generate a short, descriptive title (max 6 words) for this chat conversation. 
Only return the title text, nothing else.

User message: {user_message}"""


async def generate_chat_title(user_message: str) -> str:
    """Generate a short title for a chat from its first user message.

    Sends ``TITLE_GENERATION_PROMPT`` (filled with *user_message*) to
    ``grok-4-fast`` and returns the stripped reply, truncated to at most
    100 characters.

    Returns:
        The generated title, or ``"New Chat"`` if the reply is empty or any
        error occurs (errors are logged, never raised).
    """

    def _request_title() -> str:
        # Dedicated client with a 30 timeout (the shared client uses 3600).
        title_client = Client(api_key=GROK_API_KEY, timeout=30)
        title_chat = title_client.chat.create(model="grok-4-fast")
        title_chat.append(system(TITLE_GENERATION_PROMPT.format(user_message=user_message)))
        return title_chat.sample().content.strip()

    try:
        # The xAI SDK ``Client`` is synchronous. Calling it directly inside this
        # coroutine would block the event loop (and every other websocket)
        # until the model answered, so run it in the default thread pool.
        loop = asyncio.get_running_loop()
        title = await loop.run_in_executor(None, _request_title)
        logger.info(f"Title: {title}")

        # Ensure title isn't too long.
        if len(title) > 100:
            title = title[:97] + "..."

        return title if title else "New Chat"
    except Exception as e:
        logger.error(f"Failed to generate title: {e}")
        return "New Chat"

async def generate_image(user_message: str) -> str:
    """Generate an image with ``grok-2-image`` using *user_message* as the prompt.

    Returns:
        The URL of the generated image, or the string ``"failed image url"``
        if generation fails (errors are logged, never raised).
    """

    def _request_image_url() -> str:
        # Dedicated client with a 300 timeout for the image endpoint.
        image_client = Client(api_key=GROK_API_KEY, timeout=300)
        response = image_client.image.sample(
            model="grok-2-image",
            prompt=user_message,
            image_format="url",
        )
        return response.url

    try:
        # The xAI SDK Client is synchronous; run the request in the default
        # thread pool so it does not block the event loop.
        loop = asyncio.get_running_loop()
        image_url = await loop.run_in_executor(None, _request_image_url)
        logger.info(f"Image URL: {image_url}")
        return image_url
    except Exception as e:
        logger.error(f"Failed to generate image: {e}")
        # NOTE(review): this fallback is a non-empty string, so callers that
        # test the result for truthiness will treat it like a real URL.
        return "failed image url"


async def should_update_title(chat: Chat, session) -> bool:
    """Return True when *chat* should get an auto-generated title now.

    That is the case when the chat still has the default ``"New Chat"`` title
    and exactly two messages are stored for it (one user/assistant exchange,
    given how the websocket handler persists messages).
    """
    from sqlalchemy import func

    # Cheap check first: a chat that already has a title needs no query.
    if chat.title != "New Chat":
        return False

    # Let the database count the rows instead of loading every message.
    stmt = select(func.count()).select_from(Message).where(Message.chat_id == chat.id)
    result = await session.execute(stmt)
    message_count = result.scalar_one()
    logger.info(f"message_count: {message_count}")

    return message_count == 2


# --- REST ---
@router.get("/chats", response_model=Chats)
async def get_chats(
    session=Depends(get_session),
    limit: int = Query(50, ge=1, le=100),
    cursor: Optional[str] = Query(None, description="Chat ID to paginate from")
):
    """List chats that contain at least one message, newest first.

    Keyset pagination: pass the id of the last chat of the previous page as
    *cursor* to get the next page. An unparsable or unknown cursor is ignored
    and the first page is returned.
    """
    from sqlalchemy import and_, or_

    has_messages = exists(Message).where(Message.chat_id == Chat.id)

    # Chat.id breaks ties on created_at so ordering is total; without it, chats
    # sharing the cursor's timestamp would be skipped between pages.
    stmt = (
        select(Chat)
        .where(has_messages)
        .order_by(desc(Chat.created_at), desc(Chat.id))
    )

    if cursor:
        try:
            cursor_chat_id = UUID(cursor)
        except (ValueError, TypeError):
            cursor_chat_id = None  # Invalid cursor, ignore it

        cursor_chat = await session.get(Chat, cursor_chat_id) if cursor_chat_id else None
        if cursor_chat:
            # Rows strictly after the cursor in (created_at DESC, id DESC) order.
            stmt = stmt.where(
                or_(
                    Chat.created_at < cursor_chat.created_at,
                    and_(
                        Chat.created_at == cursor_chat.created_at,
                        Chat.id < cursor_chat.id,
                    ),
                )
            )

    stmt = stmt.limit(limit)

    result = await session.execute(stmt)
    recent_chats = result.scalars().all()

    return Chats(chats=recent_chats)

@router.post("/chats", response_model=ChatOut, status_code=status.HTTP_201_CREATED)
async def create_chat(session=Depends(get_session)):
    """Create a new, empty chat and return it with status 201.

    The row is refreshed after the commit so database-populated fields are
    loaded before the response is serialized.
    """
    new_chat = Chat()
    session.add(new_chat)
    await session.commit()
    # Pull back whatever the database filled in on insert.
    await session.refresh(new_chat)
    return new_chat


@router.get("/chats/{chat_id_str}/messages", response_model=ChatWithMessages)
async def get_chat_messages(chat_id_str: str, session=Depends(get_session)):
    """Return a chat's metadata plus all of its messages (oldest first) with assets.

    Raises:
        HTTPException: 404 if *chat_id_str* is not a valid UUID or no chat
            with that id exists.
    """
    try:
        chat_id = UUID(chat_id_str)
    except ValueError:
        # A malformed id can never match a chat; answer 404 instead of a 500.
        raise HTTPException(404, "Chat not found") from None

    chat = await session.get(Chat, chat_id)
    if not chat:
        raise HTTPException(404, "Chat not found")

    # Messages are always read from the database (the Redis-backed fast path
    # from utils.redis is currently disabled). Assets are eager-loaded so
    # serialization does not trigger lazy loads.
    stmt = (
        select(Message)
        .options(selectinload(Message.assets))
        .where(Message.chat_id == chat_id)
        .order_by(Message.created_at)
    )
    result = await session.execute(stmt)
    messages = result.scalars().all()
    return ChatWithMessages(
        id=chat.id,
        title=chat.title,
        created_at=chat.created_at,
        messages=[MessageOut.from_orm_with_assets(m) for m in messages],
    )


# Substrings that mark a user message as an image-generation request.
# A plain keyword match; a classifier model would be more precise but was
# judged overkill.
_IMAGE_TRIGGERS = (
    "/img ",
    "generate image ",
    "create image ",
    "draw ",
    "make an image of ",
    "show me a picture of ",
    "image: ",
    "img: ",
    "🖼️",
    "picture of ",
    "generate a photo of ",
)

# Default returned by next() once the Grok stream is exhausted.
_STREAM_END = object()


def _is_image_request(text: str) -> bool:
    """Return True if *text* contains any of ``_IMAGE_TRIGGERS`` (case-sensitive)."""
    return any(trigger in text for trigger in _IMAGE_TRIGGERS)


async def _load_history(session, chat_id: UUID) -> list:
    """Return the chat's stored messages, oldest first, as role/content dicts."""
    stmt = (
        select(Message)
        .where(Message.chat_id == chat_id)
        .order_by(Message.created_at)
    )
    result = await session.execute(stmt)
    return [{"role": m.role, "content": m.content} for m in result.scalars().all()]


def _build_grok_chat(history: list):
    """Create a Grok chat session primed with SYSTEM_PROMPT and *history*.

    History entries whose role is neither "user" nor "assistant" are skipped.
    """
    chat = client.chat.create(
        model="grok-4-fast",
        tools=TOOLS
    )
    chat.append(system(SYSTEM_PROMPT))
    for msg in history:
        if msg["role"] == "user":
            chat.append(user(msg["content"]))
        elif msg["role"] == "assistant":
            chat.append(assistant(msg["content"]))
    return chat


async def _stream_grok_reply(websocket: WebSocket, chat, chat_id_str: str) -> str:
    """Stream Grok's next reply to the client as "append" messages.

    Returns the full reply text, or "" if the model produced no content.
    """
    loop = asyncio.get_running_loop()
    # chat.stream() is a synchronous iterator doing blocking network I/O.
    # Pull each item in the default thread pool so the event loop keeps
    # serving other connections while the model is generating.
    stream = iter(chat.stream())
    full_response = ""
    response = None
    is_thinking = True
    while True:
        item = await loop.run_in_executor(None, next, stream, _STREAM_END)
        if item is _STREAM_END:
            break
        response, chunk = item
        for tool_call in chunk.tool_calls:
            logger.info(f"\nCalling tool: {tool_call.function.name} with arguments: {tool_call.function.arguments}")
        if chunk.content and is_thinking:
            logger.info("\n\nAnalysis Results:")
            is_thinking = False
        token = chunk.content
        if not token:
            continue
        # `response` is the aggregate reply so far; `chunk` is the new delta.
        full_response = response.content
        await websocket.send_json(
            {
                "chatId": chat_id_str,
                "content": token,
                "action": "append",
            }
        )
        # Small pause between tokens (presumably to pace the client-side rendering).
        await asyncio.sleep(0.015)

    # Diagnostics; skipped when the stream yielded nothing, so we never read
    # an unbound or stale `response`.
    if response is not None:
        logger.info("\n\nCitations:")
        logger.info(response.citations)
        logger.info("\n\nUsage:")
        logger.info(response.usage)
        logger.info(response.server_side_tool_usage)
        logger.info("\n\nServer Side Tool Calls:")
        logger.info(response.tool_calls)
    return full_response


async def _stream_mock_reply(websocket: WebSocket, user_content: str, chat_id_str: str) -> str:
    """Echo *user_content* back word by word without calling the API; return the full text."""
    full_response = ""
    for token in f"[Mock] Echo: {user_content}".split():
        token += " "
        full_response += token
        await websocket.send_json(
            {
                "chatId": chat_id_str,
                "content": token,
                "action": "append",
            }
        )
        await asyncio.sleep(0.2)
    return full_response


async def _maybe_update_title(
    websocket: WebSocket,
    session,
    db_chat: Chat,
    chat_id_str: str,
    user_content: str,
    use_grok: bool,
) -> None:
    """Title the chat after its first exchange and notify the client.

    With Grok the title is generated from *user_content*; in mock mode the
    placeholder "Automatic title" is used. Chats that already have a title
    are left untouched.
    """
    if not await should_update_title(db_chat, session):
        return
    if use_grok:
        new_title = await generate_chat_title(user_content)
    else:
        new_title = "Automatic title"
    db_chat.title = new_title
    session.add(db_chat)
    await session.commit()
    logger.info(f"Updated chat title to: {new_title}")
    await websocket.send_json(
        {
            "chatId": chat_id_str,
            "title": new_title,
            "action": "title_updated"
        }
    )


@router.websocket("/chats/{chat_id_str}/ws")
async def chat_stream(
    websocket: WebSocket,
    chat_id_str: str,
    session=Depends(get_session),
    use_grok: bool = Query(True),  # TODO: use url for this, but for now this is fine.
):
    """Chat websocket: receive user messages and stream assistant replies.

    Client -> server: JSON text frames ``{"role": "user", "content": ...}``.
    Any other role closes the socket with code 1008.

    Server -> client, per user message:
    - ``"append"`` token messages;
    - an ``"image"`` message for image requests;
    - a ``"title_updated"`` message after the first exchange;
    - a final ``"done"``.

    On failure the server sends ``{"error": ...}`` and keeps the connection open.
    Both the user message and the assistant reply are persisted.

    With ``use_grok=False`` replies are a local mock echo and no API calls are
    made. A malformed or unknown chat id closes the socket with code 1008.
    """
    try:
        chat_id = UUID(chat_id_str)
    except ValueError:
        # Malformed id: reject the handshake (1008 = policy violation) instead
        # of letting the handler crash.
        await websocket.close(code=1008)
        return

    await websocket.accept()

    # Load chat from DB
    db_chat = await session.get(Chat, chat_id)
    if not db_chat:
        await websocket.close(code=1008)
        return

    chat = None
    history = []
    if use_grok:
        # Best effort: if loading history fails, start the model without
        # context rather than refusing the connection.
        try:
            history = await _load_history(session, chat_id)
        except Exception as e:
            logger.error(f"History load failed: {e}")
            history = []
        chat = _build_grok_chat(history)

    logger.info(f"This should only be called once: {history}")

    # Listen to websocket
    try:
        while True:
            try:
                data = await websocket.receive_text()
                # TODO: Create seralizer for this.
                user_msg = json.loads(data)
                if user_msg.get("role") != "user":
                    await websocket.close(code=1008)
                    return
                user_content = user_msg["content"]

                msg = Message(chat_id=chat_id, role="user", content=user_content)
                session.add(msg)
                await session.commit()
                await session.refresh(msg)

                # Reset per message so the asset logic below never reads stale
                # or unbound values (mock mode never sets them otherwise).
                is_image_call = False
                image_url = ""

                if use_grok:
                    # Always give the model the user's turn, image requests
                    # included, so the live context matches what is persisted.
                    logger.info(f"User content sent: {user_content}")
                    chat.append(user(user_content))

                    if _is_image_request(user_content):
                        is_image_call = True
                        image_url = await generate_image(user_content)
                        logger.info("Image got generated and we are sending assistant prompt?")
                        chat.append(assistant(f"Your image is here. Please view {image_url}!"))
                        await websocket.send_json(
                            {
                                "chatId": chat_id_str,
                                "url": f"{image_url}",
                                "action": "image",
                            }
                        )

                    full_response = await _stream_grok_reply(websocket, chat, chat_id_str)
                    if full_response:
                        chat.append(assistant(full_response))
                else:
                    full_response = await _stream_mock_reply(websocket, user_content, chat_id_str)

                assistant_msg = Message(
                    chat_id=chat_id, role="assistant", content=full_response
                )
                session.add(assistant_msg)
                await session.commit()
                await session.refresh(assistant_msg)

                # Link the generated image (if any) to the assistant message.
                if is_image_call and image_url:
                    asset = MessageAsset(
                        message_id=assistant_msg.id,
                        asset_type="image",
                        url=image_url,
                    )
                    session.add(asset)
                    await session.commit()
                    await session.refresh(assistant_msg)

                await _maybe_update_title(
                    websocket, session, db_chat, chat_id_str, user_content, use_grok
                )

                await websocket.send_json(
                    {"chatId": chat_id_str, "content": "", "action": "done"}
                )
            except WebSocketDisconnect:
                break
            except json.JSONDecodeError:
                await websocket.send_json({"error": "Invalid JSON"})
            except Exception as e:
                # Report the failure for this message but keep the socket open.
                line_number = e.__traceback__.tb_lineno
                logger.error(
                    f"An error occurred on line {line_number}: {type(e).__name__} - {e}"
                )
                await websocket.send_json({"error": str(e)})
    except Exception as e:
        logger.error(f"Error: {e}")