Mercurial
comparison love/poppy/apis/router.py @ 38:cf9caa4abc3e
[Love] FE and BE. Can chat and render images. Also created MCP for powerpoint generations.
| author | MrJuneJune <me@mrjunejune.com> |
|---|---|
| date | Mon, 01 Dec 2025 20:35:56 -0800 |
| parents | |
| children | |
comparison
equal
deleted
inserted
replaced
| 37:fb9bcd3145cb | 38:cf9caa4abc3e |
|---|---|
| 1 from typing import Optional | |
| 2 from uuid import UUID | |
| 3 from fastapi import ( | |
| 4 APIRouter, | |
| 5 Depends, | |
| 6 HTTPException, | |
| 7 status, | |
| 8 WebSocket, | |
| 9 WebSocketDisconnect, | |
| 10 Query, | |
| 11 ) | |
| 12 from sqlalchemy.orm import selectinload | |
| 13 from sqlmodel import select, desc, exists | |
| 14 import json | |
| 15 import asyncio | |
| 16 import os | |
| 17 | |
| 18 from xai_sdk import Client | |
| 19 from xai_sdk.chat import user, system, assistant | |
| 20 from xai_sdk.tools import mcp, web_search, code_execution | |
| 21 | |
| 22 from .schemas import MessageOut, ChatOut, ChatWithMessages, Chats | |
| 23 from db.models import Chat, Message, MessageAsset | |
| 24 from utils.database import get_session | |
| 25 # from utils.redis import append_message, get_all_messages_from_redis | |
| 26 from utils.logger import logger | |
| 27 | |
| 28 import sys | |
| 29 | |
| 30 | |
# API key for xAI; falls back to the placeholder "NO-KEY" when XAI_API_KEY is unset.
GROK_API_KEY = os.getenv("XAI_API_KEY", "NO-KEY")

# Server-side tools handed to the chat model when chat_stream creates a chat.
# A PowerPoint-generator MCP server is prepared but currently disabled:
#     mcp(
#         server_url="https://mcp.babocoder.com/mcp",
#         server_label="powerpoint-generator",
#         server_description="This will create powerpoint slides and files for you.",
#     )
TOOLS = [web_search(), code_execution()]

# Shared xAI SDK client used for streamed chat turns. The long timeout
# (3600 s) is a development setting: take it out when deploying to your
# local server.
client = Client(api_key=GROK_API_KEY, timeout=3600)

router = APIRouter(tags=["chat"])
| 48 | |
# TODO: Make this into something more useful
# Persona prepended as the system message of every streamed chat.
SYSTEM_PROMPT = (
    "\n"
    "You are a dog lover and everytime someone mentioned about their dog. "
    "You should be the most excited person alive.\n"
)

# Template for the one-shot title request; ``{user_message}`` is filled with
# the user's text via ``str.format``.
TITLE_GENERATION_PROMPT = (
    "Based on the following user message, generate a short, descriptive title "
    "(max 6 words) for this chat conversation.\n"
    "Only return the title text, nothing else.\n"
    "\n"
    "User message: {user_message}"
)
| 58 | |
| 59 | |
async def generate_chat_title(user_message: str) -> str:
    """Ask Grok for a short title describing ``user_message``.

    The xAI SDK client used here is synchronous, so the request runs in a
    worker thread (``asyncio.to_thread``). Calling it directly inside this
    coroutine would block the event loop, and with it every other request and
    websocket, for up to the 30 s client timeout.

    Args:
        user_message: The user's message the title should summarise.

    Returns:
        The generated title with surrounding whitespace stripped and capped at
        100 characters (97 characters plus ``"..."``). Returns ``"New Chat"``
        when the model returns an empty string or any error occurs.
    """

    def _sample_title() -> str:
        # Runs in a worker thread: everything here may block.
        title_client = Client(api_key=GROK_API_KEY, timeout=30)
        title_chat = title_client.chat.create(model="grok-4-fast")
        title_chat.append(system(TITLE_GENERATION_PROMPT.format(user_message=user_message)))
        return title_chat.sample().content

    try:
        title = (await asyncio.to_thread(_sample_title)).strip()

        logger.info("Title: \n")
        logger.info(title)

        # Ensure title isn't too long: 97 chars + "..." == 100.
        if len(title) > 100:
            title = title[:97] + "..."

        return title if title else "New Chat"
    except Exception as e:
        # Best effort: a missing title must never break the chat turn.
        logger.error(f"Failed to generate title: {e}")
        return "New Chat"
| 80 | |
async def generate_image(user_message: str) -> str:
    """Generate an image with ``grok-2-image`` and return its URL.

    The xAI SDK call is synchronous, so it runs in a worker thread
    (``asyncio.to_thread``). This keeps the event loop responsive during the
    request, which may take up to the 300 s client timeout.

    Args:
        user_message: Prompt text passed verbatim to the image model.

    Returns:
        The generated image URL. On any failure it returns the sentinel
        string ``"failed image url"``. That value is kept for backward
        compatibility, and callers must check for it before using the result
        as a URL.
    """

    def _sample_image():
        # Runs in a worker thread: everything here may block.
        image_client = Client(api_key=GROK_API_KEY, timeout=300)
        return image_client.image.sample(
            model="grok-2-image",
            prompt=user_message,
            image_format="url",
        )

    try:
        response = await asyncio.to_thread(_sample_image)
        logger.info("Image URL: \n")
        logger.info(response.url)

        return response.url
    except Exception as e:
        logger.error(f"Failed to generate image: {e}")
        return "failed image url"
| 96 | |
| 97 | |
async def should_update_title(chat: Chat, session) -> bool:
    """Return True when ``chat`` should receive its first generated title.

    That is the case when its title is still the default ``"New Chat"`` and
    exactly two messages are stored for it. This is the state right after the
    first user message and its assistant reply have been saved.
    """
    from sqlalchemy import func

    # Cheap check first: a chat that already has a title never needs the query.
    if chat.title != "New Chat":
        return False

    # Count in SQL instead of loading every message row just to len() them.
    stmt = select(func.count()).select_from(Message).where(Message.chat_id == chat.id)
    result = await session.execute(stmt)
    message_count = result.scalar_one()
    logger.info(f"message_count: {message_count}")

    return message_count == 2
| 106 | |
| 107 | |
# --- REST ---
@router.get("/chats", response_model=Chats)
async def get_chats(
    session=Depends(get_session),
    limit: int = Query(50, ge=1, le=100),
    cursor: Optional[str] = Query(None, description="Chat ID to paginate from")
):
    """List chats that contain at least one message, newest first.

    Paginate by passing the id of the last chat from the previous page as
    ``cursor``. Only chats created strictly before that chat are returned.
    A cursor that is not a valid UUID, or that matches no chat, is ignored
    and the first page is returned.
    """
    query = (
        select(Chat)
        .where(exists(Message).where(Message.chat_id == Chat.id))
        .order_by(desc(Chat.created_at))
    )

    if cursor:
        try:
            anchor = await session.get(Chat, UUID(cursor))
            if anchor:
                query = query.where(Chat.created_at < anchor.created_at)
        except (ValueError, TypeError):
            pass  # Malformed cursor: fall back to the first page.

    rows = await session.execute(query.limit(limit))
    return Chats(chats=rows.scalars().all())
| 135 | |
@router.post("/chats", response_model=ChatOut, status_code=status.HTTP_201_CREATED)
async def create_chat(session=Depends(get_session)):
    """Create and persist a new, empty chat and return it (HTTP 201)."""
    new_chat = Chat()
    session.add(new_chat)
    await session.commit()
    # Reload after commit so the returned object reflects the persisted row.
    await session.refresh(new_chat)
    return new_chat
| 143 | |
| 144 | |
@router.get("/chats/{chat_id_str}/messages", response_model=ChatWithMessages)
async def get_chat_messages(chat_id_str: str, session=Depends(get_session)):
    """Return a chat with all of its messages (and their assets), oldest first.

    Raises:
        HTTPException: 404 if ``chat_id_str`` is not a valid UUID or no chat
            with that id exists.
    """
    # A malformed id can never match a chat. Answer 404 instead of letting
    # UUID() raise ValueError and surface as a 500.
    try:
        chat_id = UUID(chat_id_str)
    except ValueError:
        raise HTTPException(404, "Chat not found") from None

    chat = await session.get(Chat, chat_id)
    if not chat:
        raise HTTPException(404, "Chat not found")

    # NOTE: a Redis read-through cache for messages was prototyped here and is
    # disabled; messages are always read from the database.
    stmt = (
        select(Message)
        .options(selectinload(Message.assets))  # eager-load assets in one extra query
        .where(Message.chat_id == chat_id)
        .order_by(Message.created_at)
    )
    result = await session.execute(stmt)
    messages = result.scalars().all()
    return ChatWithMessages(
        id=chat.id,
        title=chat.title,
        created_at=chat.created_at,
        messages=[MessageOut.from_orm_with_assets(m) for m in messages],
    )
| 189 | |
| 190 | |
# Substrings that route a user message to the image model. Matching is a
# case-insensitive substring test; a clarifying/router model could do this
# better, but that felt like overkill.
_IMAGE_TRIGGERS = (
    "/img ",
    "generate image ",
    "create image ",
    "draw ",
    "make an image of ",
    "show me a picture of ",
    "image: ",
    "img: ",
    "🖼️",
    "picture of ",
    "generate a photo of ",
)

# Value generate_image() returns when the image request fails.
_FAILED_IMAGE_URL = "failed image url"


def _is_image_request(text: str) -> bool:
    """Return True if ``text`` contains any image trigger phrase, ignoring case."""
    lowered = text.lower()
    return any(trigger in lowered for trigger in _IMAGE_TRIGGERS)


async def _load_history(session, chat_id: UUID) -> list:
    """Return the chat's stored messages as ``{"role", "content"}`` dicts, oldest first."""
    stmt = (
        select(Message)
        .where(Message.chat_id == chat_id)
        .order_by(Message.created_at)
    )
    result = await session.execute(stmt)
    return [{"role": m.role, "content": m.content} for m in result.scalars().all()]


async def _iterate_in_thread(make_iterator):
    """Asynchronously yield the items of a blocking iterator.

    ``make_iterator()`` and every ``next()`` call run in a worker thread via
    ``asyncio.to_thread``. Waiting on the xAI SDK's synchronous stream
    therefore does not freeze the event loop for other requests and
    websockets.
    """
    iterator = await asyncio.to_thread(lambda: iter(make_iterator()))
    exhausted = object()
    while True:
        item = await asyncio.to_thread(next, iterator, exhausted)
        if item is exhausted:
            return
        yield item


async def _stream_grok_reply(websocket: WebSocket, chat, chat_id_str: str) -> str:
    """Stream the model's next reply to the client as ``append`` events.

    Returns the accumulated reply text, or ``""`` if the model produced no
    content.
    """
    full_response = ""
    response = None
    is_thinking = True
    async for response, chunk in _iterate_in_thread(chat.stream):
        for tool_call in chunk.tool_calls:
            logger.info(f"\nCalling tool: {tool_call.function.name} with arguments: {tool_call.function.arguments}")
        if chunk.content and is_thinking:
            logger.info("\n\nAnalysis Results:")
            is_thinking = False
        token = chunk.content
        if not token:
            continue
        # ``chunk.content`` is the new token; ``response.content`` is used as
        # the reply accumulated so far.
        full_response = response.content
        await websocket.send_json(
            {
                "chatId": chat_id_str,
                "content": token,
                "action": "append",
            }
        )
        await asyncio.sleep(0.015)

    # An empty stream never binds ``response``; don't NameError on it.
    if response is not None:
        logger.info("\n\nCitations:")
        logger.info(response.citations)
        logger.info("\n\nUsage:")
        logger.info(response.usage)
        logger.info(response.server_side_tool_usage)
        logger.info("\n\nServer Side Tool Calls:")
        logger.info(response.tool_calls)
    return full_response


async def _stream_mock_reply(websocket: WebSocket, chat_id_str: str, user_content: str) -> str:
    """Stream an echo of ``user_content`` word by word (the ``use_grok=False`` path); return it."""
    mock_reply = f"[Mock] Echo: {user_content}"
    full_response = ""
    for token in mock_reply.split():
        token += " "
        full_response += token
        await websocket.send_json(
            {
                "chatId": chat_id_str,
                "content": token,
                "action": "append",
            }
        )
        await asyncio.sleep(0.2)
    return full_response


@router.websocket("/chats/{chat_id_str}/ws")
async def chat_stream(
    websocket: WebSocket,
    chat_id_str: str,
    session=Depends(get_session),
    use_grok: bool = Query(True),  # TODO: use url for this, but for now this is fine.
):
    """Run a chat conversation over a websocket.

    Each turn proceeds as follows:

    1. The client sends ``{"role": "user", "content": ...}``.
    2. The server persists the message.
    3. For image requests it generates an image (``action: "image"``, or an
       ``{"error": ...}`` message if generation fails).
    4. It streams reply tokens (``action: "append"``) and persists the reply.
    5. After the chat's first exchange it sends ``action: "title_updated"``.
    6. It finishes with ``action: "done"``.

    Errors inside a turn are reported as ``{"error": ...}`` and the socket
    stays open. With ``use_grok=False`` a mock echo reply is streamed instead
    of calling the model.

    The socket is closed with code 1008 when the chat id is malformed or
    unknown, or when a received message's role is not ``"user"``.
    """
    await websocket.accept()

    try:
        chat_id = UUID(chat_id_str)
    except ValueError:
        await websocket.close(code=1008)
        return

    db_chat = await session.get(Chat, chat_id)
    if not db_chat:
        await websocket.close(code=1008)
        return

    chat = None
    if use_grok:
        # Rebuild the model context from the persisted conversation so a
        # reconnecting client continues where it left off. This is best
        # effort: on failure, start with an empty context.
        try:
            history = await _load_history(session, chat_id)
        except Exception:
            logger.exception("History load failed")
            history = []

        chat = client.chat.create(
            model="grok-4-fast",
            tools=TOOLS,
        )
        chat.append(system(SYSTEM_PROMPT))
        # History is ordered by created_at; replay it in that order.
        for past in history:
            if past["role"] == "user":
                chat.append(user(past["content"]))
            elif past["role"] == "assistant":
                chat.append(assistant(past["content"]))
        logger.info(f"This should only be called once: {history}")

    try:
        while True:
            try:
                data = await websocket.receive_text()
                # TODO: Create a serializer for this.
                user_msg = json.loads(data)
                if user_msg.get("role") != "user":
                    await websocket.close(code=1008)
                    return
                user_content = user_msg["content"]

                msg = Message(chat_id=chat_id, role="user", content=user_content)
                session.add(msg)
                await session.commit()
                await session.refresh(msg)

                is_image_call = False
                image_url = ""
                if use_grok:
                    # Keep the model context in step with the DB, which always
                    # stores the user turn (image requests included).
                    logger.info(f"User content sent: {user_content}")
                    chat.append(user(user_content))

                    is_image_call = _is_image_request(user_content)
                    if is_image_call:
                        image_url = await generate_image(user_content)
                        if image_url and image_url != _FAILED_IMAGE_URL:
                            logger.info("Image got generated and we are sending assistant prompt?")
                            chat.append(assistant(f"Your image is here. Please view {image_url}!"))
                            await websocket.send_json(
                                {
                                    "chatId": chat_id_str,
                                    "url": f"{image_url}",
                                    "action": "image",
                                }
                            )
                        else:
                            # Never surface or store the failure sentinel as an image URL.
                            image_url = ""
                            await websocket.send_json(
                                {"chatId": chat_id_str, "error": "Image generation failed"}
                            )

                if use_grok:
                    full_response = await _stream_grok_reply(websocket, chat, chat_id_str)
                    if full_response:
                        chat.append(assistant(full_response))
                else:
                    full_response = await _stream_mock_reply(websocket, chat_id_str, user_content)

                assistant_msg = Message(
                    chat_id=chat_id, role="assistant", content=full_response
                )
                session.add(assistant_msg)
                await session.commit()
                await session.refresh(assistant_msg)

                # Link the generated image to the assistant message it answers.
                if is_image_call and image_url:
                    asset = MessageAsset(
                        message_id=assistant_msg.id,
                        asset_type="image",
                        url=image_url,
                    )
                    session.add(asset)
                    await session.commit()
                    await session.refresh(assistant_msg)

                # Title the chat once, right after its first exchange. Later
                # turns must not overwrite the title.
                if await should_update_title(db_chat, session):
                    if use_grok:
                        new_title = await generate_chat_title(user_content)
                    else:
                        new_title = "Automatic title"
                    db_chat.title = new_title
                    session.add(db_chat)
                    await session.commit()
                    logger.info(f"Updated chat title to: {new_title}")
                    await websocket.send_json(
                        {
                            "chatId": chat_id_str,
                            "title": new_title,
                            "action": "title_updated",
                        }
                    )

                await websocket.send_json(
                    {"chatId": chat_id_str, "content": "", "action": "done"}
                )
            except WebSocketDisconnect:
                break
            except json.JSONDecodeError:
                await websocket.send_json({"error": "Invalid JSON"})
            except Exception as e:
                # Report the failed turn but keep the socket open for the next one.
                logger.exception(f"Chat turn failed: {type(e).__name__} - {e}")
                await websocket.send_json({"error": str(e)})
    except Exception:
        # e.g. sending the error report itself failed because the socket is gone.
        logger.exception("Chat websocket failed")