Mercurial
view love/poppy/apis/router.py @ 71:75de5903355c
Gigantic changes that update the Dowa library to be more aligned with stb-style arrays and hashmaps. Updated Seobeo to cache on the server side instead of at the file level. Deleted a bunch of things I don't really use.
| author | June Park <parkjune1995@gmail.com> |
|---|---|
| date | Sun, 28 Dec 2025 20:34:22 -0800 |
| parents | cf9caa4abc3e |
| children |
line wrap: on
line source
from typing import Optional
from uuid import UUID

from fastapi import (
    APIRouter,
    Depends,
    HTTPException,
    status,
    WebSocket,
    WebSocketDisconnect,
    Query,
)
from sqlalchemy import func
from sqlalchemy.orm import selectinload
from sqlmodel import select, desc, exists
import json
import asyncio
import os
import traceback

from xai_sdk import Client
from xai_sdk.chat import user, system, assistant
from xai_sdk.tools import mcp, web_search, code_execution

from .schemas import MessageOut, ChatOut, ChatWithMessages, Chats
from db.models import Chat, Message, MessageAsset
from utils.database import get_session

# from utils.redis import append_message, get_all_messages_from_redis
from utils.logger import logger
import sys


GROK_API_KEY = os.getenv("XAI_API_KEY", "NO-KEY")

# Model names used by the xAI calls in this module.
CHAT_MODEL = "grok-4-fast"
IMAGE_MODEL = "grok-2-image"

# Title a chat carries until its first exchange has been titled.
DEFAULT_CHAT_TITLE = "New Chat"
# Title assigned in mock mode (use_grok=False), where no model is called.
MOCK_CHAT_TITLE = "Automatic title"
# Value generate_image() returns on failure (kept unchanged for existing callers).
IMAGE_FAILURE_URL = "failed image url"

# Substrings that route a message to image generation instead of chat.
# They are matched against the lower-cased message, so keep every entry lower-case.
IMAGE_TRIGGERS = (
    "/img ",
    "generate image ",
    "create image ",
    "draw ",
    "make an image of ",
    "show me a picture of ",
    "image: ",
    "img: ",
    "🖼️",
    "picture of ",
    "generate a photo of ",
)

TOOLS = [
    web_search(),
    code_execution(),
    # mcp(
    #     server_url="https://mcp.babocoder.com/mcp",
    #     server_label="powerpoint-generator",
    #     server_description="This will create powerpoitn slides and files for you."
    # )
]

# xAI SDK Client
client = Client(
    api_key=GROK_API_KEY,
    timeout=3600,  # Take this out when you are deploying to your local server.
)

router = APIRouter(tags=["chat"])

# TODO: Make this into something more useful
SYSTEM_PROMPT = """
You are a dog lover and everytime someone mentioned about their dog.
You should be the most excited person alive.
"""

TITLE_GENERATION_PROMPT = """Based on the following user message, generate a short, descriptive title (max 6 words) for this chat conversation.
Only return the title text, nothing else.

User message: {user_message}"""

# Marker returned by next() once the synchronous Grok stream is exhausted.
_STREAM_END = object()


def _parse_uuid(value: str) -> Optional[UUID]:
    """Return ``value`` parsed as a UUID, or None when it is not a valid UUID."""
    try:
        return UUID(value)
    except (ValueError, TypeError, AttributeError):
        return None


def _is_image_request(text: str) -> bool:
    """Return True when ``text`` contains any IMAGE_TRIGGERS entry (case-insensitive)."""
    lowered = text.lower()
    return any(trigger in lowered for trigger in IMAGE_TRIGGERS)


async def generate_chat_title(user_message: str) -> str:
    """Ask Grok for a short chat title derived from ``user_message``.

    The blocking SDK call runs in a worker thread so the event loop stays free.
    Titles longer than 100 characters are truncated with "...". Returns
    DEFAULT_CHAT_TITLE when the model returns nothing or the call fails.
    """

    def _sample_title() -> str:
        title_client = Client(api_key=GROK_API_KEY, timeout=30)
        title_chat = title_client.chat.create(model=CHAT_MODEL)
        title_chat.append(
            system(TITLE_GENERATION_PROMPT.format(user_message=user_message))
        )
        return (title_chat.sample().content or "").strip()

    try:
        title = await asyncio.to_thread(_sample_title)
    except Exception as e:
        logger.error(f"Failed to generate title: {e}")
        return DEFAULT_CHAT_TITLE

    logger.info(f"Title: {title}")
    # Ensure title isn't too long
    if len(title) > 100:
        title = title[:97] + "..."
    return title or DEFAULT_CHAT_TITLE


async def generate_image(user_message: str) -> str:
    """Generate an image for ``user_message`` and return its URL.

    The blocking SDK call runs in a worker thread. Returns IMAGE_FAILURE_URL
    ("failed image url") when generation fails.
    """

    def _sample_image_url() -> str:
        image_client = Client(api_key=GROK_API_KEY, timeout=300)
        response = image_client.image.sample(
            model=IMAGE_MODEL, prompt=user_message, image_format="url"
        )
        return response.url

    try:
        url = await asyncio.to_thread(_sample_image_url)
    except Exception as e:
        logger.error(f"Failed to generate image: {e}")
        return IMAGE_FAILURE_URL

    logger.info(f"Image URL: {url}")
    return url


async def should_update_title(chat: Chat, session) -> bool:
    """Return True when ``chat`` still has the default title and exactly 2 messages.

    Two messages means the first user/assistant exchange has just been stored.
    """
    if chat.title != DEFAULT_CHAT_TITLE:
        # Already titled: skip the count query entirely.
        return False
    stmt = (
        select(func.count())
        .select_from(Message)
        .where(Message.chat_id == chat.id)
    )
    result = await session.execute(stmt)
    message_count = result.scalar_one()
    logger.info(f"message_count: {message_count}")
    return message_count == 2


# --- REST ---


@router.get("/chats", response_model=Chats)
async def get_chats(
    session=Depends(get_session),
    limit: int = Query(50, ge=1, le=100),
    cursor: Optional[str] = Query(None, description="Chat ID to paginate from"),
):
    """List chats that have at least one message, newest first.

    When ``cursor`` names an existing chat, only chats created strictly before it
    are returned. An invalid or unknown cursor is ignored.
    """
    has_messages = exists(Message).where(Message.chat_id == Chat.id)
    stmt = select(Chat).where(has_messages).order_by(desc(Chat.created_at))

    if cursor:
        cursor_chat_id = _parse_uuid(cursor)
        cursor_chat = (
            await session.get(Chat, cursor_chat_id) if cursor_chat_id else None
        )
        if cursor_chat:
            stmt = stmt.where(Chat.created_at < cursor_chat.created_at)

    result = await session.execute(stmt.limit(limit))
    return Chats(chats=result.scalars().all())


@router.post("/chats", response_model=ChatOut, status_code=status.HTTP_201_CREATED)
async def create_chat(session=Depends(get_session)):
    """Create and return an empty chat."""
    chat = Chat()
    session.add(chat)
    await session.commit()
    await session.refresh(chat)
    return chat


@router.get("/chats/{chat_id_str}/messages", response_model=ChatWithMessages)
async def get_chat_messages(chat_id_str: str, session=Depends(get_session)):
    """Return a chat and its messages (with assets) in creation order.

    Raises:
        HTTPException: 404 when ``chat_id_str`` is not a valid UUID or no such chat exists.
    """
    chat_id = _parse_uuid(chat_id_str)
    chat = await session.get(Chat, chat_id) if chat_id else None
    if not chat:
        raise HTTPException(404, "Chat not found")

    # NOTE: a Redis-backed message cache used to be consulted here; it was removed.
    stmt = (
        select(Message)
        .options(selectinload(Message.assets))
        .where(Message.chat_id == chat_id)
        .order_by(Message.created_at)
    )
    result = await session.execute(stmt)
    messages = result.scalars().all()

    return ChatWithMessages(
        id=chat.id,
        title=chat.title,
        created_at=chat.created_at,
        messages=[MessageOut.from_orm_with_assets(m) for m in messages],
    )


# --- WebSocket helpers ---


async def _load_history(chat_id: UUID, session) -> list:
    """Return the stored messages of ``chat_id`` as role/content dicts, oldest first."""
    stmt = (
        select(Message)
        .where(Message.chat_id == chat_id)
        .order_by(Message.created_at)
    )
    result = await session.execute(stmt)
    return [{"role": m.role, "content": m.content} for m in result.scalars().all()]


async def _build_grok_chat(chat_id: UUID, session):
    """Create a Grok chat seeded with SYSTEM_PROMPT and the chat's stored history.

    If loading the history fails, the error is logged and the chat starts empty.
    """
    try:
        history = await _load_history(chat_id, session)
    except Exception as e:
        logger.error(f"History load failed: {e}")
        await session.rollback()
        history = []

    chat = client.chat.create(model=CHAT_MODEL, tools=TOOLS)
    chat.append(system(SYSTEM_PROMPT))
    for msg in history:
        # Image replies are stored with empty content; they add nothing to the context.
        if not msg["content"]:
            continue
        if msg["role"] == "user":
            chat.append(user(msg["content"]))
        elif msg["role"] == "assistant":
            chat.append(assistant(msg["content"]))
    logger.info(f"Seeded Grok chat {chat_id} with {len(history)} stored messages")
    return chat


async def _stream_grok_reply(chat, websocket: WebSocket, chat_id_str: str) -> str:
    """Stream the Grok reply token by token as "append" frames; return the full text."""
    full_response = ""
    response = None
    is_thinking = True
    stream = iter(chat.stream())
    while True:
        # chat.stream() is a synchronous iterator; pull each item in a worker
        # thread so other requests are not blocked while waiting for tokens.
        item = await asyncio.to_thread(next, stream, _STREAM_END)
        if item is _STREAM_END:
            break
        response, chunk = item
        for tool_call in chunk.tool_calls:
            logger.info(
                f"\nCalling tool: {tool_call.function.name} "
                f"with arguments: {tool_call.function.arguments}"
            )
        if chunk.content and is_thinking:
            logger.info("\n\nAnalysis Results:")
            is_thinking = False
        token = chunk.content
        if not token:
            continue
        # response.content is the reply accumulated so far.
        full_response = response.content
        await websocket.send_json(
            {"chatId": chat_id_str, "content": token, "action": "append"}
        )
        await asyncio.sleep(0.015)

    if response is not None:
        logger.info(f"\n\nCitations: {response.citations}")
        logger.info(f"\n\nUsage: {response.usage} {response.server_side_tool_usage}")
        logger.info(f"\n\nServer Side Tool Calls: {response.tool_calls}")
    return full_response


async def _stream_mock_reply(
    websocket: WebSocket, chat_id_str: str, user_content: str
) -> str:
    """Echo ``user_content`` back word by word (mock mode); return the echoed text."""
    full_response = ""
    for word in f"[Mock] Echo: {user_content}".split():
        token = word + " "
        full_response += token
        await websocket.send_json(
            {"chatId": chat_id_str, "content": token, "action": "append"}
        )
        await asyncio.sleep(0.2)
    return full_response


async def _send_generated_image(
    chat, websocket: WebSocket, chat_id_str: str, user_content: str
) -> str:
    """Generate an image and send it as an "image" frame.

    Returns the image URL. On failure, an error frame is sent instead and ""
    is returned, so no asset gets recorded.
    """
    image_url = await generate_image(user_content)
    if image_url == IMAGE_FAILURE_URL:
        await websocket.send_json({"error": "Image generation failed"})
        return ""
    chat.append(assistant(f"Your image is here. Please view {image_url}!"))
    await websocket.send_json(
        {"chatId": chat_id_str, "url": image_url, "action": "image"}
    )
    return image_url


async def _save_assistant_message(
    session, chat_id: UUID, content: str, image_url: str
) -> None:
    """Persist the assistant reply and, if ``image_url`` is set, its image asset in one commit."""
    assistant_msg = Message(chat_id=chat_id, role="assistant", content=content)
    session.add(assistant_msg)
    if image_url:
        # Flush so assistant_msg.id is populated before it is referenced by the asset.
        await session.flush()
        session.add(
            MessageAsset(
                message_id=assistant_msg.id,
                asset_type="image",
                url=image_url,
            )
        )
    await session.commit()


async def _maybe_update_title(
    db_chat: Chat,
    session,
    websocket: WebSocket,
    chat_id_str: str,
    user_content: str,
    use_grok: bool,
) -> None:
    """Title the chat once, right after its first exchange, and notify the client.

    In Grok mode the title is generated from ``user_content``. In mock mode it is
    MOCK_CHAT_TITLE.
    """
    if not await should_update_title(db_chat, session):
        return
    new_title = (
        await generate_chat_title(user_content) if use_grok else MOCK_CHAT_TITLE
    )
    db_chat.title = new_title
    session.add(db_chat)
    await session.commit()
    logger.info(f"Updated chat title to: {new_title}")
    await websocket.send_json(
        {"chatId": chat_id_str, "title": new_title, "action": "title_updated"}
    )


@router.websocket("/chats/{chat_id_str}/ws")
async def chat_stream(
    websocket: WebSocket,
    chat_id_str: str,
    session=Depends(get_session),
    use_grok: bool = Query(True),  # TODO: use url for this, but for now this is fine.
):
    """Serve one chat over a websocket.

    Each inbound text frame must be JSON of the form
    ``{"role": "user", "content": ...}``. For every such frame the handler:

    1. persists the user message;
    2. replies with one of: an image (trigger phrase), a streamed Grok answer,
       or a mock echo when ``use_grok`` is false;
    3. persists the reply and titles the chat after the first exchange;
    4. sends an ``{"action": "done"}`` frame.

    The socket is closed with code 1008 when the chat does not exist or a frame
    is not a user message.
    """
    await websocket.accept()

    chat_id = _parse_uuid(chat_id_str)
    db_chat = await session.get(Chat, chat_id) if chat_id else None
    if not db_chat:
        await websocket.close(code=1008)
        return

    # Live Grok conversation; None in mock mode.
    chat = await _build_grok_chat(chat_id, session) if use_grok else None

    try:
        while True:
            try:
                data = await websocket.receive_text()
                # TODO: Create seralizer for this.
                user_msg = json.loads(data)
                if not isinstance(user_msg, dict) or user_msg.get("role") != "user":
                    await websocket.close(code=1008)
                    return
                user_content = user_msg["content"]

                session.add(Message(chat_id=chat_id, role="user", content=user_content))
                await session.commit()

                # Bound up front so mock mode (no image routing) never hits an unbound name.
                full_response = ""
                image_url = ""

                if use_grok:
                    # Keep the live context in the same shape _build_grok_chat
                    # rebuilds from the DB: every user turn is included.
                    chat.append(user(user_content))
                    if _is_image_request(user_content):
                        image_url = await _send_generated_image(
                            chat, websocket, chat_id_str, user_content
                        )
                    else:
                        logger.info(f"User content sent: {user_content}")
                        full_response = await _stream_grok_reply(
                            chat, websocket, chat_id_str
                        )
                        if full_response:
                            chat.append(assistant(full_response))
                else:
                    full_response = await _stream_mock_reply(
                        websocket, chat_id_str, user_content
                    )

                await _save_assistant_message(session, chat_id, full_response, image_url)
                await _maybe_update_title(
                    db_chat, session, websocket, chat_id_str, user_content, use_grok
                )

                await websocket.send_json(
                    {"chatId": chat_id_str, "content": "", "action": "done"}
                )
            except WebSocketDisconnect:
                break
            except json.JSONDecodeError:
                await websocket.send_json({"error": "Invalid JSON"})
            except Exception as e:
                logger.error(
                    f"Error while handling chat {chat_id_str}: "
                    f"{type(e).__name__} - {e}\n{traceback.format_exc()}"
                )
                # Leave the session usable for the next message after a failed DB op.
                await session.rollback()
                await websocket.send_json({"error": str(e)})
    except Exception as e:
        # Reached when the error path itself fails (e.g. the socket is already gone).
        logger.error(f"Error: {e}")