view love/poppy/apis/router.py @ 149:f41ac17926d2

[Config] Added ctags scripts and actual tags.
author June Park <parkjune1995@gmail.com>
date Sat, 10 Jan 2026 07:07:10 -0800
parents cf9caa4abc3e
children
line wrap: on
line source

import asyncio
import json
import os
import sys
from typing import Optional
from uuid import UUID

from fastapi import (
    APIRouter,
    Depends,
    HTTPException,
    status,
    WebSocket,
    WebSocketDisconnect,
    Query,
)
from sqlalchemy import and_, func, or_
from sqlalchemy.orm import selectinload
from sqlmodel import select, desc, exists
from xai_sdk import Client
from xai_sdk.chat import user, system, assistant
from xai_sdk.tools import mcp, web_search, code_execution

from .schemas import MessageOut, ChatOut, ChatWithMessages, Chats
from db.models import Chat, Message, MessageAsset
from utils.database import get_session
# from utils.redis import append_message, get_all_messages_from_redis
from utils.logger import logger


# API key for the xAI (Grok) service, read from the XAI_API_KEY environment
# variable. When unset it falls back to the placeholder "NO-KEY", so a missing
# key is not detected at import time (presumably requests fail later instead).
GROK_API_KEY = os.getenv("XAI_API_KEY", "NO-KEY")
# Server-side tools offered to Grok; passed as ``tools=`` when chat_stream
# creates its chat session.
TOOLS = [
    web_search(),
    code_execution(),
    # Disabled MCP tool (PowerPoint generator) - uncomment to re-enable.
    # mcp(
    #   server_url="https://mcp.babocoder.com/mcp",
    #   server_label="powerpoint-generator",
    #   server_description="This will create powerpoint slides and files for you."
    # )
]
# Shared xAI SDK client, created once at import time and used for the
# websocket chat sessions.
client = Client(
    api_key=GROK_API_KEY,
    timeout=3600, # Long timeout (presumably seconds). NOTE(review): original note said to remove it when deploying to a local server - confirm.
)

# All routes in this module are registered on this router under the "chat" tag.
router = APIRouter(tags=["chat"])

# System prompt appended first to every Grok chat session in chat_stream.
# TODO: Make this into something more useful
SYSTEM_PROMPT = """
You are a dog lover and everytime someone mentioned about their dog. You should be the most excited person alive.
"""

# Template used by generate_chat_title; ``{user_message}`` is filled with
# str.format before the prompt is sent.
TITLE_GENERATION_PROMPT = """Based on the following user message, generate a short, descriptive title (max 6 words) for this chat conversation. 
Only return the title text, nothing else.

User message: {user_message}"""


async def generate_chat_title(user_message: str) -> str:
    """Ask Grok for a short title (prompted for at most 6 words) for a chat.

    The xAI SDK ``sample()`` call is synchronous. It is therefore run in the
    event loop's default executor; calling it inline from this coroutine would
    block the loop, stalling every other request and websocket while the model
    answers.

    Args:
        user_message: The user message the title should summarize.

    Returns:
        The stripped title, cut to at most 100 characters (97 + "..."), or
        "New Chat" when the model returns an empty title or anything raises.
    """

    def _sample_title() -> str:
        # Dedicated client with a short timeout (30) instead of the shared
        # module client's 3600.
        title_client = Client(api_key=GROK_API_KEY, timeout=30)
        title_chat = title_client.chat.create(model="grok-4-fast")
        title_chat.append(system(TITLE_GENERATION_PROMPT.format(user_message=user_message)))
        return title_chat.sample().content

    try:
        loop = asyncio.get_running_loop()
        raw_title = await loop.run_in_executor(None, _sample_title)
        title = raw_title.strip()

        logger.info("Title: \n")
        logger.info(title)

        # Ensure title isn't too long
        if len(title) > 100:
            title = title[:97] + "..."

        return title if title else "New Chat"
    except Exception as e:
        # Best effort: a missing title must never break the chat turn.
        logger.error(f"Failed to generate title: {e}")
        return "New Chat"

async def generate_image(user_message: str) -> str:
    """Generate an image for ``user_message`` with the grok-2-image model.

    The xAI SDK ``image.sample()`` call is synchronous, so it runs in the event
    loop's default executor rather than blocking the loop.

    Args:
        user_message: Prompt text for the image.

    Returns:
        The URL of the generated image. On any failure, returns the
        placeholder string "failed image url". NOTE(review): chat_stream treats
        any non-empty string as a real URL (it sends it to the client and
        stores it as an asset); consider returning "" on failure.
    """

    def _sample_image_url() -> str:
        image_client = Client(api_key=GROK_API_KEY, timeout=300)
        response = image_client.image.sample(
            model="grok-2-image",
            prompt=user_message,
            image_format="url",
        )
        return response.url

    try:
        loop = asyncio.get_running_loop()
        image_url = await loop.run_in_executor(None, _sample_image_url)
        logger.info("Image URL: \n")
        logger.info(image_url)

        return image_url
    except Exception as e:
        logger.error(f"Failed to generate image: {e}")
        return "failed image url"


async def should_update_title(chat: Chat, session) -> bool:
    """Return True when ``chat`` has just completed its first exchange.

    That is: the chat holds exactly two messages (first user message plus its
    reply) and still carries the default "New Chat" title.

    The messages are counted with a SQL ``COUNT`` instead of loading every row
    and taking ``len()`` of the result.

    Args:
        chat: The chat to inspect.
        session: Database session used for the count query.
    """
    stmt = select(func.count(Message.id)).where(Message.chat_id == chat.id)
    result = await session.execute(stmt)
    message_count = result.scalar_one()
    logger.info(f"message_count: {message_count}")

    # Update title if we have exactly 2 messages and title is still default
    return message_count == 2 and chat.title == "New Chat"


# --- REST ---
@router.get("/chats", response_model=Chats)
async def get_chats(
    session=Depends(get_session),
    limit: int = Query(50, ge=1, le=100),
    cursor: Optional[str] = Query(None, description="Chat ID to paginate from")
):
    """List chats that contain at least one message, newest first.

    Pagination is keyset based: pass the id of the last chat from the previous
    page as ``cursor`` to get the chats that sort after it. Rows are ordered by
    ``(created_at, id)`` descending. Chats sharing the cursor's ``created_at``
    are therefore neither skipped nor repeated across pages.

    Args:
        session: Database session injected by ``get_session``.
        limit: Maximum number of chats to return (1-100).
        cursor: Id of the chat to paginate from. An unparseable or unknown
            cursor is ignored and the first page is returned.

    Returns:
        A ``Chats`` payload with up to ``limit`` chats.
    """
    has_messages = exists(Message).where(Message.chat_id == Chat.id)

    stmt = (
        select(Chat)
        .where(has_messages)
        # ``id`` breaks ties on ``created_at`` so the keyset below is a total order.
        .order_by(desc(Chat.created_at), desc(Chat.id))
    )

    cursor_chat_id = None
    if cursor:
        try:
            cursor_chat_id = UUID(cursor)
        except (ValueError, TypeError):
            cursor_chat_id = None  # Invalid cursor: deliberately ignored.

    if cursor_chat_id is not None:
        cursor_chat = await session.get(Chat, cursor_chat_id)
        if cursor_chat:
            # Keep only rows strictly after the cursor in (created_at DESC, id DESC) order.
            stmt = stmt.where(
                or_(
                    Chat.created_at < cursor_chat.created_at,
                    and_(
                        Chat.created_at == cursor_chat.created_at,
                        Chat.id < cursor_chat.id,
                    ),
                )
            )

    result = await session.execute(stmt.limit(limit))
    recent_chats = result.scalars().all()

    return Chats(chats=recent_chats)

@router.post("/chats", response_model=ChatOut, status_code=status.HTTP_201_CREATED)
async def create_chat(session=Depends(get_session)):
    """Create and persist a new, empty chat; respond 201 with the stored row."""
    new_chat = Chat()
    session.add(new_chat)
    await session.commit()
    # Re-read the committed row so values filled in on insert (presumably the
    # id and created_at) are present on the object that gets serialized.
    await session.refresh(new_chat)
    return new_chat


@router.get("/chats/{chat_id_str}/messages", response_model=ChatWithMessages)
async def get_chat_messages(chat_id_str: str, session=Depends(get_session)):
    """Return one chat together with all of its messages, oldest first.

    Each message is loaded with its ``assets`` relationship and serialized with
    ``MessageOut.from_orm_with_assets``.

    Args:
        chat_id_str: Chat id from the URL path, expected to be a UUID string.
        session: Database session injected by ``get_session``.

    Raises:
        HTTPException: 404 when ``chat_id_str`` is not a valid UUID or no chat
            with that id exists.
    """
    try:
        chat_id = UUID(chat_id_str)
    except ValueError:
        # A malformed id cannot name any chat; answer 404 instead of letting
        # the ValueError surface as a 500.
        raise HTTPException(404, "Chat not found") from None

    chat = await session.get(Chat, chat_id)
    if not chat:
        raise HTTPException(404, "Chat not found")

    stmt = (
        select(Message)
        # Eager-load assets so serialization does not trigger lazy loads.
        .options(selectinload(Message.assets))
        .where(Message.chat_id == chat_id)
        .order_by(Message.created_at)
    )
    result = await session.execute(stmt)
    messages = result.scalars().all()
    return ChatWithMessages(
        id=chat.id,
        title=chat.title,
        created_at=chat.created_at,
        messages=[MessageOut.from_orm_with_assets(m) for m in messages],
    )


# Substrings that mark a user message as an image request. Matching is a plain
# case-sensitive substring test, so e.g. "draw " also matches inside "withdraw ".
_IMAGE_TRIGGERS = (
    "/img ",
    "generate image ",
    "create image ",
    "draw ",
    "make an image of ",
    "show me a picture of ",
    "image: ",
    "img: ",
    "🖼️",
    "picture of ",
    "generate a photo of ",
)

# Default value handed to next() so the exhausted Grok stream can be detected.
_STREAM_END = object()


def _is_image_request(text: str) -> bool:
    """Return True if ``text`` contains any of ``_IMAGE_TRIGGERS``."""
    return any(trigger in text for trigger in _IMAGE_TRIGGERS)


async def _load_history(session, chat_id: UUID) -> list:
    """Return the chat's stored messages, oldest first, as role/content dicts."""
    stmt = (
        select(Message)
        .where(Message.chat_id == chat_id)
        .order_by(Message.created_at)
    )
    result = await session.execute(stmt)
    return [{"role": m.role, "content": m.content} for m in result.scalars().all()]


def _new_grok_chat(history: list):
    """Create a Grok chat primed with ``SYSTEM_PROMPT`` and ``history``.

    Only "user" and "assistant" entries are replayed; other roles are skipped.
    """
    grok_chat = client.chat.create(
        model="grok-4-fast",
        tools=TOOLS
    )
    grok_chat.append(system(SYSTEM_PROMPT))
    # ``history`` is ordered by created_at, so turns are replayed in order.
    for entry in history:
        if entry["role"] == "user":
            grok_chat.append(user(entry["content"]))
        elif entry["role"] == "assistant":
            grok_chat.append(assistant(entry["content"]))
    return grok_chat


async def _stream_grok_reply(websocket: WebSocket, grok_chat, chat_id_str: str) -> str:
    """Stream Grok's next reply to the client as "append" frames.

    ``grok_chat.stream()`` is a blocking iterator, so each ``next()`` runs in
    the default executor to keep the event loop free for other connections.

    Returns:
        The reply text ("" if the model produced no content).
    """
    loop = asyncio.get_running_loop()
    stream = iter(grok_chat.stream())
    full_response = ""
    response = None
    is_thinking = True
    while True:
        item = await loop.run_in_executor(None, next, stream, _STREAM_END)
        if item is _STREAM_END:
            break
        response, chunk = item
        for tool_call in chunk.tool_calls:
            logger.info(f"\nCalling tool: {tool_call.function.name} with arguments: {tool_call.function.arguments}")
        if chunk.content and is_thinking:
            logger.info("\n\nAnalysis Results:")
            is_thinking = False
        token = chunk.content
        if not token:
            continue
        # The full text is taken from ``response`` (the aggregate), not the chunk.
        full_response = response.content
        await websocket.send_json(
            {
                "chatId": chat_id_str,
                "content": token,
                "action": "append",
            }
        )
        await asyncio.sleep(0.015)

    # Only log the summary if the stream yielded at least one item.
    if response is not None:
        logger.info("\n\nCitations:")
        logger.info(response.citations)
        logger.info("\n\nUsage:")
        logger.info(response.usage)
        logger.info(response.server_side_tool_usage)
        logger.info("\n\nServer Side Tool Calls:")
        logger.info(response.tool_calls)
    return full_response


async def _stream_mock_reply(websocket: WebSocket, chat_id_str: str, user_content: str) -> str:
    """Echo ``user_content`` back word by word; used when Grok is disabled."""
    full_response = ""
    for word in f"[Mock] Echo: {user_content}".split():
        token = word + " "
        full_response += token
        await websocket.send_json(
            {
                "chatId": chat_id_str,
                "content": token,
                "action": "append",
            }
        )
        await asyncio.sleep(0.2)
    return full_response


async def _maybe_update_title(
    websocket: WebSocket, session, db_chat, chat_id_str: str, user_content: str, use_grok: bool
) -> None:
    """Set the chat title once, right after the first exchange.

    With Grok enabled the title is generated from ``user_content``; otherwise
    it is the fixed string "Automatic title". On later turns nothing changes.
    """
    if not await should_update_title(db_chat, session):
        return
    if use_grok:
        new_title = await generate_chat_title(user_content)
    else:
        new_title = "Automatic title"
    db_chat.title = new_title
    session.add(db_chat)
    await session.commit()
    logger.info(f"Updated chat title to: {new_title}")
    await websocket.send_json(
        {
            "chatId": chat_id_str,
            "title": new_title,
            "action": "title_updated"
        }
    )


@router.websocket("/chats/{chat_id_str}/ws")
async def chat_stream(
    websocket: WebSocket,
    chat_id_str: str,
    session=Depends(get_session),
    use_grok: bool = Query(True),  # TODO: use url for this, but for now this is fine.
):
    """Websocket endpoint that runs one chat conversation.

    Protocol (all frames are JSON):

    * Client -> server: ``{"role": "user", "content": <text>}``. Any other role
      closes the socket with code 1008.
    * Server -> client, per turn:
      - ``"append"`` frames while the reply streams.
      - An ``"image"`` frame with ``"url"`` for image requests.
      - A ``"title_updated"`` frame once, after the first exchange.
      - A ``"done"`` frame with empty content.
      Every frame carries ``"chatId"``.
    * ``{"error": ...}`` is sent for invalid JSON or a failed turn; the
      connection stays open for the next message.

    The socket is closed with code 1008 when ``chat_id_str`` is not a UUID or
    names no existing chat.

    Args:
        websocket: The client connection.
        chat_id_str: Chat id from the URL path.
        session: Database session injected by ``get_session``.
        use_grok: When False, replies are a local echo instead of Grok output.
    """
    await websocket.accept()

    try:
        chat_id = UUID(chat_id_str)
    except ValueError:
        await websocket.close(code=1008)
        return

    db_chat = await session.get(Chat, chat_id)
    if not db_chat:
        await websocket.close(code=1008)
        return

    history = []
    grok_chat = None
    if use_grok:
        try:
            history = await _load_history(session, chat_id)
        except Exception as e:
            # Best effort: start without context rather than refusing the socket.
            logger.error(f"History load failed: {e}")
            await session.rollback()
            history = []
        grok_chat = _new_grok_chat(history)

    logger.info(f"This should only be called once: {history}")
    # Listen to websocket
    try:
        while True:
            try:
                data = await websocket.receive_text()
                # TODO: Create serializer for this.
                user_msg = json.loads(data)
                if user_msg.get("role") != "user":
                    await websocket.close(code=1008)
                    return
                user_content = user_msg["content"]

                msg = Message(chat_id=chat_id, role="user", content=user_content)
                session.add(msg)
                await session.commit()
                await session.refresh(msg)

                # Defined for both modes so the asset check below never sees
                # an unbound name when Grok is disabled.
                is_image_call = False
                image_url = ""

                if use_grok:
                    if _is_image_request(user_content):
                        is_image_call = True
                        image_url = await generate_image(user_content)
                        logger.info("Image got generated and we are sending assistant prompt?")
                        # Record the user's request too, so the in-memory chat
                        # matches what was just stored in the DB.
                        grok_chat.append(user(user_content))
                        grok_chat.append(assistant(f"Your image is here. Please view {image_url}!"))
                        await websocket.send_json(
                            {
                                "chatId": chat_id_str,
                                "url": f"{image_url}",
                                "action": "image",
                            }
                        )
                    else:
                        logger.info(f"User content sent: {user_content}")
                        grok_chat.append(user(user_content))

                    full_response = await _stream_grok_reply(websocket, grok_chat, chat_id_str)
                    if full_response:
                        grok_chat.append(assistant(full_response))
                else:
                    full_response = await _stream_mock_reply(websocket, chat_id_str, user_content)

                assistant_msg = Message(
                    chat_id=chat_id, role="assistant", content=full_response
                )
                session.add(assistant_msg)
                await session.commit()
                await session.refresh(assistant_msg)

                # Link the generated image to the assistant message.
                if is_image_call and image_url:
                    asset = MessageAsset(
                        message_id=assistant_msg.id,
                        asset_type="image",
                        url=image_url,
                    )
                    session.add(asset)
                    await session.commit()
                    await session.refresh(assistant_msg)

                await _maybe_update_title(
                    websocket, session, db_chat, chat_id_str, user_content, use_grok
                )

                await websocket.send_json(
                    {"chatId": chat_id_str, "content": "", "action": "done"}
                )
            except WebSocketDisconnect:
                break
            except json.JSONDecodeError:
                await websocket.send_json({"error": "Invalid JSON"})
            except Exception as e:
                # logger.exception records the full traceback, including the
                # line where the error actually arose.
                logger.exception(f"An error occurred: {type(e).__name__} - {e}")
                # A failed commit leaves the session unusable until rolled back.
                await session.rollback()
                await websocket.send_json({"error": str(e)})
    except Exception as e:
        logger.exception(f"Error: {e}")