Mercurial
view love/poppy/apis/router.py @ 149:f41ac17926d2
[Config] Added ctags scripts and actual tags.
| author | June Park <parkjune1995@gmail.com> |
|---|---|
| date | Sat, 10 Jan 2026 07:07:10 -0800 |
| parents | cf9caa4abc3e |
| children |
line wrap: on
line source
"""Chat REST and WebSocket endpoints backed by the xAI (Grok) SDK.

REST routes list/create chats and return a chat's messages; the WebSocket
route streams assistant replies token by token and persists both sides of
the conversation.
"""

import asyncio
import json
import os
import sys
from typing import Optional
from uuid import UUID

from fastapi import (
    APIRouter,
    Depends,
    HTTPException,
    status,
    WebSocket,
    WebSocketDisconnect,
    Query,
)
from sqlalchemy.orm import selectinload
from sqlmodel import select, desc, exists
from xai_sdk import Client
from xai_sdk.chat import user, system, assistant
from xai_sdk.tools import mcp, web_search, code_execution

from .schemas import MessageOut, ChatOut, ChatWithMessages, Chats
from db.models import Chat, Message, MessageAsset
from utils.database import get_session
# from utils.redis import append_message, get_all_messages_from_redis
from utils.logger import logger

GROK_API_KEY = os.getenv("XAI_API_KEY", "NO-KEY")

TOOLS = [
    web_search(),
    code_execution(),
    # mcp(
    #     server_url="https://mcp.babocoder.com/mcp",
    #     server_label="powerpoint-generator",
    #     server_description="This will create powerpoint slides and files for you."
    # )
]

# xAI SDK Client
client = Client(
    api_key=GROK_API_KEY,
    timeout=3600,  # Take this out when you are deploying to your local server.
)

router = APIRouter(tags=["chat"])

# TODO: Make this into something more useful
SYSTEM_PROMPT = """
You are a dog lover and everytime someone mentioned about their dog.
You should be the most excited person alive.
"""

TITLE_GENERATION_PROMPT = """Based on the following user message, generate a short, descriptive title (max 6 words) for this chat conversation.
Only return the title text, nothing else.

User message: {user_message}"""

# Title a chat carries until one is generated; compared against in
# should_update_title and returned by generate_chat_title on failure.
DEFAULT_CHAT_TITLE = "New Chat"

# Value generate_image returns when image generation fails. Callers must not
# treat it as a real URL.
FAILED_IMAGE_URL = "failed image url"

# Substrings (case-sensitive) that route a user message to image generation.
IMAGE_TRIGGERS = (
    "/img ",
    "generate image ",
    "create image ",
    "draw ",
    "make an image of ",
    "show me a picture of ",
    "image: ",
    "img: ",
    "🖼️",
    "picture of ",
    "generate a photo of ",
)


def _sample_chat_title(user_message: str) -> str:
    """Blocking SDK call that asks the model for a title; run it off the event loop."""
    title_client = Client(api_key=GROK_API_KEY, timeout=30)
    title_chat = title_client.chat.create(model="grok-4-fast")
    title_chat.append(system(TITLE_GENERATION_PROMPT.format(user_message=user_message)))
    return title_chat.sample().content


def _sample_image_url(user_message: str) -> str:
    """Blocking SDK call that generates an image and returns its URL."""
    image_client = Client(api_key=GROK_API_KEY, timeout=300)
    response = image_client.image.sample(
        model="grok-2-image", prompt=user_message, image_format="url"
    )
    return response.url


def _is_image_request(user_content: str) -> bool:
    """Return True when the message contains any of the image trigger phrases."""
    return any(trigger in user_content for trigger in IMAGE_TRIGGERS)


async def _load_history(session, chat_id: UUID) -> list:
    """Load a chat's stored messages, oldest first, as role/content dicts.

    Best-effort: on any failure the error is logged and an empty history is
    returned so the conversation can still start.
    """
    try:
        stmt = (
            select(Message)
            .where(Message.chat_id == chat_id)
            .order_by(Message.created_at)
        )
        result = await session.execute(stmt)
        return [
            {"role": m.role, "content": m.content} for m in result.scalars().all()
        ]
    except Exception as e:
        logger.error(f"History load failed: {e}")
        return []


async def generate_chat_title(user_message: str) -> str:
    """Generate a short chat title from the first user message.

    Returns the model's title (truncated to 100 characters), or "New Chat"
    when the model returns nothing or the call fails.
    """
    try:
        # The SDK call is synchronous; run it in a worker thread so other
        # connections are not stalled while it waits on the network.
        raw_title = await asyncio.to_thread(_sample_chat_title, user_message)
        title = (raw_title or "").strip()
        logger.info(f"Title: {title}")

        # Ensure title isn't too long
        if len(title) > 100:
            title = title[:97] + "..."

        return title if title else DEFAULT_CHAT_TITLE
    except Exception as e:
        logger.error(f"Failed to generate title: {e}")
        return DEFAULT_CHAT_TITLE


async def generate_image(user_message: str) -> str:
    """Generate an image for the prompt and return its URL.

    Returns the sentinel "failed image url" when generation fails.
    """
    try:
        # Synchronous SDK call with a long timeout: keep it off the event loop.
        url = await asyncio.to_thread(_sample_image_url, user_message)
        logger.info(f"Image URL: {url}")
        return url
    except Exception as e:
        logger.error(f"Failed to generate image: {e}")
        return FAILED_IMAGE_URL


async def should_update_title(chat: Chat, session) -> bool:
    """Return True right after the first exchange of a still-untitled chat.

    That is: the chat has exactly 2 messages and its title is still "New Chat".
    """
    if chat.title != DEFAULT_CHAT_TITLE:
        return False

    # Only "exactly 2" matters, so fetching at most 3 ids is enough to decide.
    stmt = select(Message.id).where(Message.chat_id == chat.id).limit(3)
    result = await session.execute(stmt)
    message_count = len(result.scalars().all())
    logger.info(f"message_count: {message_count}")
    return message_count == 2


# --- REST ---
@router.get("/chats", response_model=Chats)
async def get_chats(
    session=Depends(get_session),
    limit: int = Query(50, ge=1, le=100),
    cursor: Optional[str] = Query(None, description="Chat ID to paginate from"),
):
    """List chats that have at least one message, newest first.

    When ``cursor`` is the id of an existing chat, only chats created before
    that chat are returned. An invalid or unknown cursor is ignored.
    """
    has_messages = exists(Message).where(Message.chat_id == Chat.id)
    stmt = select(Chat).where(has_messages).order_by(desc(Chat.created_at))

    # If cursor provided, filter to chats older than cursor chat
    if cursor:
        try:
            cursor_chat = await session.get(Chat, UUID(cursor))
        except (ValueError, TypeError):
            cursor_chat = None  # Invalid cursor, ignore it
        if cursor_chat:
            stmt = stmt.where(Chat.created_at < cursor_chat.created_at)

    result = await session.execute(stmt.limit(limit))
    return Chats(chats=result.scalars().all())


@router.post("/chats", response_model=ChatOut, status_code=status.HTTP_201_CREATED)
async def create_chat(session=Depends(get_session)):
    """Create an empty chat and return it."""
    chat = Chat()
    session.add(chat)
    await session.commit()
    await session.refresh(chat)
    return chat


@router.get("/chats/{chat_id_str}/messages", response_model=ChatWithMessages)
async def get_chat_messages(chat_id_str: str, session=Depends(get_session)):
    """Return a chat with all of its messages (and their assets), oldest first.

    Raises:
        HTTPException: 404 when the id is malformed or no such chat exists.
    """
    try:
        chat_id = UUID(chat_id_str)
    except ValueError:
        # A malformed id cannot match any chat; answer 404 instead of a 500.
        raise HTTPException(404, "Chat not found")

    chat = await session.get(Chat, chat_id)
    if not chat:
        raise HTTPException(404, "Chat not found")

    stmt = (
        select(Message)
        .options(selectinload(Message.assets))
        .where(Message.chat_id == chat_id)
        .order_by(Message.created_at)
    )
    result = await session.execute(stmt)
    messages = result.scalars().all()

    return ChatWithMessages(
        id=chat.id,
        title=chat.title,
        created_at=chat.created_at,
        messages=[MessageOut.from_orm_with_assets(m) for m in messages],
    )


@router.websocket("/chats/{chat_id_str}/ws")
async def chat_stream(
    websocket: WebSocket,
    chat_id_str: str,
    session=Depends(get_session),
    use_grok: bool = Query(True),  # TODO: use url for this, but for now this is fine.
):
    """Stream assistant replies for one chat over a WebSocket.

    Each incoming text frame must be JSON with ``role == "user"`` and a
    ``content`` field. For every user message the handler persists it, streams
    the reply as ``append`` events, persists the reply (plus an image asset
    when one was generated), may emit ``title_updated`` after the first
    exchange, and finishes the turn with a ``done`` event. With
    ``use_grok=False`` a mock echo reply is streamed instead of calling the
    model. The socket is closed with code 1008 when the chat id is malformed,
    the chat does not exist, or a frame's role is not "user".
    """
    await websocket.accept()

    try:
        chat_id = UUID(chat_id_str)
    except ValueError:
        await websocket.close(code=1008)
        return

    # Load chat from DB
    db_chat = await session.get(Chat, chat_id)
    if not db_chat:
        await websocket.close(code=1008)
        return

    chat = None
    if use_grok:
        # Seed the model conversation with what is already stored for this chat.
        history = await _load_history(session, chat_id)

        chat = client.chat.create(model="grok-4-fast", tools=TOOLS)
        chat.append(system(SYSTEM_PROMPT))
        for msg in history:
            if msg["role"] == "user":
                chat.append(user(msg["content"]))
            elif msg["role"] == "assistant":
                chat.append(assistant(msg["content"]))

        logger.info(f"This should only be called once: {history}")

    # Listen to websocket
    try:
        while True:
            full_response = ""
            # Bound on every turn so the asset step below is safe in mock mode too.
            is_image_call = False
            image_url = ""

            try:
                data = await websocket.receive_text()
                # TODO: Create serializer for this.
                user_msg = json.loads(data)

                if user_msg.get("role") != "user":
                    await websocket.close(code=1008)
                    return

                user_content = user_msg["content"]

                msg = Message(chat_id=chat_id, role="user", content=user_content)
                session.add(msg)
                await session.commit()
                await session.refresh(msg)

                # Actual API calls
                if use_grok:
                    # Can use clarifying agent model for this, but I feel like
                    # that would be overkill.
                    is_image_call = _is_image_request(user_content)

                    if is_image_call:
                        image_url = await generate_image(user_content)

                    if is_image_call and image_url != FAILED_IMAGE_URL:
                        logger.info("Image got generated and we are sending assistant prompt?")
                        # NOTE(review): the user's request itself is not added
                        # to the model conversation on this path - confirm
                        # that is intended.
                        chat.append(assistant(f"Your image is here. Please view {image_url}!"))
                        await websocket.send_json(
                            {
                                "chatId": chat_id_str,
                                "url": f"{image_url}",
                                "action": "image",
                            }
                        )
                    else:
                        if is_image_call:
                            # Generation failed: report it, store no asset, and
                            # fall back to answering the message as plain text.
                            is_image_call = False
                            image_url = ""
                            await websocket.send_json(
                                {"chatId": chat_id_str, "error": "Image generation failed"}
                            )
                        logger.info(f"User content sent: {user_content}")
                        chat.append(user(user_content))

                # AI response....
                is_thinking = True

                if use_grok:
                    response = None
                    # NOTE(review): chat.stream() is iterated synchronously
                    # here; if it performs blocking network I/O it stalls the
                    # event loop between chunks - consider the SDK's async client.
                    for response, chunk in chat.stream():
                        for tool_call in chunk.tool_calls:
                            logger.info(
                                f"\nCalling tool: {tool_call.function.name} "
                                f"with arguments: {tool_call.function.arguments}"
                            )
                        if chunk.content and is_thinking:
                            logger.info("\n\nAnalysis Results:")
                            is_thinking = False

                        token = chunk.content
                        if not token:
                            continue

                        # response carries the text accumulated so far.
                        full_response = response.content
                        await websocket.send_json(
                            {
                                "chatId": chat_id_str,
                                "content": token,
                                "action": "append",
                            }
                        )
                        await asyncio.sleep(0.015)

                    # The stream may yield nothing, leaving no response to report on.
                    if response is not None:
                        logger.info("\n\nCitations:")
                        logger.info(response.citations)
                        logger.info("\n\nUsage:")
                        logger.info(response.usage)
                        logger.info(response.server_side_tool_usage)
                        logger.info("\n\nServer Side Tool Calls:")
                        logger.info(response.tool_calls)
                else:
                    mock_reply = f"[Mock] Echo: {user_content}"
                    for token in mock_reply.split():
                        token += " "
                        full_response += token
                        await websocket.send_json(
                            {
                                "chatId": chat_id_str,
                                "content": token,
                                "action": "append",
                            }
                        )
                        await asyncio.sleep(0.2)

                if use_grok and full_response:
                    chat.append(assistant(full_response))

                assistant_msg = Message(
                    chat_id=chat_id, role="assistant", content=full_response
                )
                session.add(assistant_msg)
                await session.commit()
                await session.refresh(assistant_msg)

                # Now create assets linked to the message
                if is_image_call and image_url:
                    asset = MessageAsset(
                        message_id=assistant_msg.id,
                        asset_type="image",
                        url=image_url,
                    )
                    session.add(asset)
                    await session.commit()
                    await session.refresh(assistant_msg)

                # Generate title after first exchange. Only a still-default
                # title is replaced, so a generated title is never overwritten
                # on later turns.
                if await should_update_title(db_chat, session):
                    if use_grok:
                        new_title = await generate_chat_title(user_content)
                    else:
                        new_title = "Automatic title"

                    db_chat.title = new_title
                    session.add(db_chat)
                    await session.commit()

                    logger.info(f"Updated chat title to: {new_title}")

                    await websocket.send_json(
                        {
                            "chatId": chat_id_str,
                            "title": new_title,
                            "action": "title_updated",
                        }
                    )

                await websocket.send_json(
                    {"chatId": chat_id_str, "content": "", "action": "done"}
                )

            except WebSocketDisconnect:
                break
            except json.JSONDecodeError:
                await websocket.send_json({"error": "Invalid JSON"})
            except Exception as e:
                line_number = e.__traceback__.tb_lineno
                logger.error(
                    f"An error occurred on line {line_number}: {type(e).__name__} - {e}"
                )
                # A failed flush/commit leaves the session unusable until it is
                # rolled back; reset it so the next turn can still persist.
                await session.rollback()
                await websocket.send_json({"error": str(e)})

    except Exception as e:
        logger.error(f"Error: {e}")