comparison love/poppy/apis/router.py @ 38:cf9caa4abc3e

[Love] FE and BE. Can chat and render images. Also created MCP for powerpoint generations.
author MrJuneJune <me@mrjunejune.com>
date Mon, 01 Dec 2025 20:35:56 -0800
parents
children
comparison
equal deleted inserted replaced
37:fb9bcd3145cb 38:cf9caa4abc3e
1 from typing import Optional
2 from uuid import UUID
3 from fastapi import (
4 APIRouter,
5 Depends,
6 HTTPException,
7 status,
8 WebSocket,
9 WebSocketDisconnect,
10 Query,
11 )
12 from sqlalchemy.orm import selectinload
13 from sqlmodel import select, desc, exists
14 import json
15 import asyncio
16 import os
17
18 from xai_sdk import Client
19 from xai_sdk.chat import user, system, assistant
20 from xai_sdk.tools import mcp, web_search, code_execution
21
22 from .schemas import MessageOut, ChatOut, ChatWithMessages, Chats
23 from db.models import Chat, Message, MessageAsset
24 from utils.database import get_session
25 # from utils.redis import append_message, get_all_messages_from_redis
26 from utils.logger import logger
27
28 import sys
29
30
# xAI API key, read from the XAI_API_KEY environment variable. When the
# variable is unset the placeholder "NO-KEY" is used instead.
GROK_API_KEY = os.environ.get("XAI_API_KEY", "NO-KEY")

# Tools handed to every chat created through the shared client below.
TOOLS = [
    web_search(),
    code_execution(),
    # Currently disabled: MCP server for generating PowerPoint slides/files.
    # mcp(
    #     server_url="https://mcp.babocoder.com/mcp",
    #     server_label="powerpoint-generator",
    #     server_description="This will create powerpoint slides and files for you."
    # )
]

# Shared xAI SDK client used by the websocket chat endpoint.
client = Client(
    api_key=GROK_API_KEY,
    # Take this out when you are deploying to your local server.
    timeout=3600,
)

# All chat REST and websocket routes in this module hang off this router.
router = APIRouter(tags=["chat"])
48
# TODO: Make this into something more useful
# System message placed at the start of every Grok chat.
SYSTEM_PROMPT = (
    "\n"
    "You are a dog lover and everytime someone mentioned about their dog. "
    "You should be the most excited person alive."
    "\n"
)

# Template used to ask the model for a chat title; `{user_message}` is
# filled in with str.format().
TITLE_GENERATION_PROMPT = (
    "Based on the following user message, generate a short, descriptive "
    "title (max 6 words) for this chat conversation.\n"
    "Only return the title text, nothing else.\n"
    "\n"
    "User message: {user_message}"
)
58
59
async def generate_chat_title(user_message: str) -> str:
    """Ask the model for a short title describing *user_message*.

    The xAI SDK calls used here are synchronous, so they are executed on a
    worker thread to keep the event loop free while the request is in
    flight.

    Args:
        user_message: The user's message the title should summarise.

    Returns:
        The generated title, truncated to at most 100 characters, or
        "New Chat" when the model returns nothing or the request fails.
    """

    def _sample_title() -> str:
        # Blocking network round-trip; must not run on the event loop.
        title_client = Client(api_key=GROK_API_KEY, timeout=30)
        title_chat = title_client.chat.create(model="grok-4-fast")
        title_chat.append(
            system(TITLE_GENERATION_PROMPT.format(user_message=user_message))
        )
        return title_chat.sample().content.strip()

    try:
        loop = asyncio.get_running_loop()
        title = await loop.run_in_executor(None, _sample_title)

        logger.info("Title: \n")
        logger.info(title)

        # Ensure title isn't too long
        if len(title) > 100:
            title = title[:97] + "..."

        return title if title else "New Chat"
    except Exception as e:
        # Best effort: a failed title request must never break the chat.
        logger.error(f"Failed to generate title: {e}")
        return "New Chat"
80
async def generate_image(user_message: str) -> str:
    """Generate an image for *user_message* and return its URL.

    The xAI SDK image call is synchronous, so it is executed on a worker
    thread to keep the event loop free while the image is produced.

    Args:
        user_message: Prompt passed to the image model.

    Returns:
        The URL of the generated image, or the literal string
        "failed image url" when the request fails.
    """

    def _sample_image() -> str:
        # Blocking network round-trip; must not run on the event loop.
        image_client = Client(api_key=GROK_API_KEY, timeout=300)
        response = image_client.image.sample(
            model="grok-2-image",
            prompt=user_message,
            image_format="url",
        )
        return response.url

    try:
        loop = asyncio.get_running_loop()
        image_url = await loop.run_in_executor(None, _sample_image)

        logger.info("Image URL: \n")
        logger.info(image_url)

        return image_url
    except Exception as e:
        # The message previously said "title" (copy-paste from the title
        # helper), which made image failures hard to find in the logs.
        logger.error(f"Failed to generate image: {e}")
        return "failed image url"
96
97
async def should_update_title(chat: Chat, session) -> bool:
    """Tell whether *chat* is due for its first title.

    A title is due when the chat still carries the default "New Chat"
    title and exactly two messages are stored for it.

    Args:
        chat: The chat row to inspect (its `title` and `id` are read).
        session: Async database session used to count the messages.

    Returns:
        True when the title should be generated now, otherwise False.
    """
    # Cheap check first: once a title is set no query is needed at all.
    if chat.title != "New Chat":
        return False

    # Local import so this block does not depend on the file's import list.
    from sqlalchemy import func

    # Count in SQL instead of loading every message row just to take len().
    stmt = (
        select(func.count())
        .select_from(Message)
        .where(Message.chat_id == chat.id)
    )
    result = await session.execute(stmt)
    message_count = result.scalar_one()
    logger.info(f"message_count: {message_count}")

    # Update title if we have exactly 2 messages and title is still default
    return message_count == 2
106
107
108 # --- REST ---
@router.get("/chats", response_model=Chats)
async def get_chats(
    session=Depends(get_session),
    limit: int = Query(50, ge=1, le=100),
    cursor: Optional[str] = Query(None, description="Chat ID to paginate from"),
):
    """List chats that contain at least one message, newest first.

    Args:
        session: Async database session.
        limit: Maximum number of chats to return (1-100).
        cursor: Optional chat id; when it names an existing chat, only
            chats created before that chat are returned. A malformed or
            unknown cursor is ignored.

    Returns:
        A `Chats` payload wrapping the selected chat rows.
    """
    query = (
        select(Chat)
        .where(exists(Message).where(Message.chat_id == Chat.id))
        .order_by(desc(Chat.created_at))
    )

    # If cursor provided, filter to chats older than cursor chat
    if cursor:
        try:
            anchor = await session.get(Chat, UUID(cursor))
            if anchor:
                query = query.where(Chat.created_at < anchor.created_at)
        except (ValueError, TypeError):
            # Invalid cursor, ignore it
            pass

    rows = await session.execute(query.limit(limit))
    return Chats(chats=rows.scalars().all())
135
@router.post("/chats", response_model=ChatOut, status_code=status.HTTP_201_CREATED)
async def create_chat(session=Depends(get_session)):
    """Create an empty chat and return it.

    Args:
        session: Async database session.

    Returns:
        The newly stored `Chat`, refreshed from the session after commit.
    """
    new_chat = Chat()
    session.add(new_chat)
    await session.commit()
    # Re-read the row after the commit before handing it to the serializer.
    await session.refresh(new_chat)
    return new_chat
143
144
@router.get("/chats/{chat_id_str}/messages", response_model=ChatWithMessages)
async def get_chat_messages(chat_id_str: str, session=Depends(get_session)):
    """Return a chat together with all of its messages, oldest first.

    Args:
        chat_id_str: Chat id taken from the URL path, expected to be a UUID.
        session: Async database session.

    Returns:
        A `ChatWithMessages` payload; each message is built with
        `MessageOut.from_orm_with_assets`.

    Raises:
        HTTPException: 404 when the id is malformed or no such chat exists.
    """
    try:
        chat_id = UUID(chat_id_str)
    except ValueError:
        # A malformed id cannot identify any chat; answer like a missing
        # chat instead of letting the ValueError surface as a 500.
        raise HTTPException(404, "Chat not found") from None

    chat = await session.get(Chat, chat_id)
    if not chat:
        raise HTTPException(404, "Chat not found")

    # NOTE: the Redis read path that used to sit here is disabled (the
    # utils.redis import is commented out); messages come from the database.
    stmt = (
        select(Message)
        # Load assets eagerly so they are available when building MessageOut.
        .options(selectinload(Message.assets))
        .where(Message.chat_id == chat_id)
        .order_by(Message.created_at)
    )
    result = await session.execute(stmt)
    messages = result.scalars().all()
    return ChatWithMessages(
        id=chat.id,
        title=chat.title,
        created_at=chat.created_at,
        messages=[MessageOut.from_orm_with_assets(m) for m in messages],
    )
189
190
# Substrings that mark a user message as an image-generation request.
# Can use clarifying agent model for this, but that would be overkill.
_IMAGE_TRIGGERS = (
    "/img ",
    "generate image ",
    "create image ",
    "draw ",
    "make an image of ",
    "show me a picture of ",
    "image: ",
    "img: ",
    "🖼️",
    "picture of ",
    "generate a photo of ",
)


def _is_image_request(text: str) -> bool:
    """Return True when *text* contains any of the image trigger phrases."""
    return any(trigger in text for trigger in _IMAGE_TRIGGERS)


async def _load_history(session, chat_id: UUID) -> list:
    """Return the chat's stored messages as role/content dicts, oldest first."""
    stmt = (
        select(Message)
        .where(Message.chat_id == chat_id)
        .order_by(Message.created_at)
    )
    result = await session.execute(stmt)
    return [{"role": m.role, "content": m.content} for m in result.scalars().all()]


@router.websocket("/chats/{chat_id_str}/ws")
async def chat_stream(
    websocket: WebSocket,
    chat_id_str: str,
    session=Depends(get_session),
    use_grok: bool = Query(True),  # TODO: use url for this, but for now this is fine.
):
    """Serve one chat over a websocket, streaming replies token by token.

    Incoming frames are JSON objects with `role` ("user") and `content`.
    Outgoing frames are JSON objects carrying `chatId` and an `action`:
    "append" (a reply token in `content`), "image" (a generated image in
    `url`), "title_updated" (new `title`) and "done" (end of one reply).
    Failures while handling a frame are reported as `{"error": ...}`.

    Every user message and every reply is stored as a `Message`; a
    generated image is stored as a `MessageAsset` on the reply.

    Args:
        websocket: The client connection.
        chat_id_str: Chat id taken from the URL path, expected to be a UUID.
        session: Async database session.
        use_grok: When True replies come from the xAI model; when False a
            mock echo reply is streamed instead.

    The socket is closed with code 1008 when the chat id is malformed, the
    chat does not exist, or a frame's role is not "user".
    """
    await websocket.accept()

    # A malformed id is answered with a policy-violation close instead of an
    # unhandled ValueError.
    try:
        chat_id = UUID(chat_id_str)
    except ValueError:
        await websocket.close(code=1008)
        return

    # Load chat from DB
    db_chat = await session.get(Chat, chat_id)
    if not db_chat:
        await websocket.close(code=1008)
        return

    chat = None
    history = []

    if use_grok:
        # Earlier messages come straight from the database. The previous
        # code called a Redis helper whose import is commented out, so the
        # resulting NameError was swallowed and history was always empty.
        try:
            history = await _load_history(session, chat_id)
        except Exception as e:
            logger.error(f"History load failed: {e}")
            history = []

        chat = client.chat.create(model="grok-4-fast", tools=TOOLS)
        chat.append(system(SYSTEM_PROMPT))

        for past in history:
            # Empty replies are skipped, mirroring the live path below which
            # only appends non-empty assistant text.
            if not past["content"]:
                continue
            if past["role"] == "user":
                chat.append(user(past["content"]))
            elif past["role"] == "assistant":
                chat.append(assistant(past["content"]))

    logger.info(f"This should only be called once: {history}")

    # Listen to websocket
    try:
        while True:
            full_response = ""
            # Reset per message so the mock path never reads stale or
            # unbound image state.
            is_image_call = False
            image_url = ""

            try:
                data = await websocket.receive_text()
                # TODO: Create serializer for this.
                user_msg = json.loads(data)
                if user_msg.get("role") != "user":
                    await websocket.close(code=1008)
                    return
                user_content = user_msg["content"]

                msg = Message(chat_id=chat_id, role="user", content=user_content)
                session.add(msg)
                await session.commit()
                await session.refresh(msg)

                if use_grok:
                    is_image_call = _is_image_request(user_content)

                    if is_image_call:
                        image_url = await generate_image(user_content)
                        logger.info(f"Image got generated and we are sending assistant prompt?")
                        # NOTE(review): the user's text is not appended to
                        # the model context on this path, only this note.
                        chat.append(assistant(f"Your image is here. Please view {image_url}!"))
                        await websocket.send_json(
                            {
                                "chatId": chat_id_str,
                                "url": f"{image_url}",
                                "action": "image",
                            }
                        )
                    else:
                        logger.info(f"User content sent: {user_content}")
                        chat.append(user(user_content))

                    # AI response. NOTE(review): chat.stream() is a
                    # synchronous iterator, so waiting for each chunk holds
                    # the event loop; the sleep below is the only yield.
                    response = None
                    is_thinking = True
                    for response, chunk in chat.stream():
                        for tool_call in chunk.tool_calls:
                            logger.info(f"\nCalling tool: {tool_call.function.name} with arguments: {tool_call.function.arguments}")
                        token = chunk.content
                        if not token:
                            continue
                        if is_thinking:
                            logger.info("\n\nAnalysis Results:")
                            is_thinking = False
                        full_response = response.content
                        await websocket.send_json(
                            {
                                "chatId": chat_id_str,
                                "content": token,
                                "action": "append",
                            }
                        )
                        await asyncio.sleep(0.015)

                    # `response` stays None when the stream yielded nothing.
                    if response is not None:
                        logger.info("\n\nCitations:")
                        logger.info(response.citations)
                        logger.info("\n\nUsage:")
                        logger.info(response.usage)
                        logger.info(response.server_side_tool_usage)
                        logger.info("\n\nServer Side Tool Calls:")
                        logger.info(response.tool_calls)

                    if full_response:
                        chat.append(assistant(full_response))
                else:
                    mock_reply = f"[Mock] Echo: {user_content}"
                    for token in mock_reply.split():
                        token += " "
                        full_response += token
                        await websocket.send_json(
                            {
                                "chatId": chat_id_str,
                                "content": token,
                                "action": "append",
                            }
                        )
                        await asyncio.sleep(0.2)

                assistant_msg = Message(
                    chat_id=chat_id, role="assistant", content=full_response
                )
                session.add(assistant_msg)
                await session.commit()
                await session.refresh(assistant_msg)

                # Now create assets linked to the message
                if is_image_call and image_url:
                    asset = MessageAsset(
                        message_id=assistant_msg.id,
                        asset_type="image",
                        url=image_url,
                    )
                    session.add(asset)
                    await session.commit()
                    await session.refresh(assistant_msg)

                # Generate title after first exchange. Previously every
                # later exchange fell into an `else` branch that overwrote
                # the generated title with "Automatic title".
                if await should_update_title(db_chat, session):
                    if use_grok:
                        new_title = await generate_chat_title(user_content)
                    else:
                        new_title = "Automatic title"
                    db_chat.title = new_title
                    session.add(db_chat)
                    await session.commit()
                    logger.info(f"Updated chat title to: {new_title}")
                    await websocket.send_json(
                        {
                            "chatId": chat_id_str,
                            "title": new_title,
                            "action": "title_updated",
                        }
                    )

                await websocket.send_json(
                    {"chatId": chat_id_str, "content": "", "action": "done"}
                )
            except WebSocketDisconnect:
                break
            except json.JSONDecodeError:
                await websocket.send_json({"error": "Invalid JSON"})
            except Exception as e:
                # Report the failing line of this handler, tell the client,
                # and keep the socket open for the next message.
                line_number = e.__traceback__.tb_lineno
                logger.error(
                    f"An error occurred on line {line_number}: {type(e).__name__} - {e}"
                )
                await websocket.send_json({"error": str(e)})
    except Exception as e:
        # Reached when reporting an error to the client itself fails, for
        # example because the socket is already gone.
        logger.error(f"Error: {e}")