comparison grok_interview/inference_my_version.py @ 51:68fa88ac73fe

Interview prep for xAI
author June Park <parkjune1995@gmail.com>
date Mon, 15 Dec 2025 19:55:17 -0800
parents
children
comparison
equal deleted inserted replaced
46:b9a40c633c93 51:68fa88ac73fe
1 """
2 Inference Questions
3
4 Context
5
6 You are tasked with building a simplified inference engine component responsible for handling incoming user requests for a large language model (LLM). To optimize throughput and GPU utilization, the engine must batch multiple requests together, run the inference call once per batch, and then deconstruct the results to return token-level output to the individual users.
7
8 Objective
9
10 Complete the provided Python class, BatchInferenceEngine, by implementing the methods necessary to:
11
12 - Queue incoming user requests.
13
14 - Process a batch when the queue reaches a defined batch size.
15 - Simulate the token-level output from an LLM and correctly associate each generated token with its original request.
16 """
17
18 import asyncio
19 from dataclasses import dataclass, field
20 from time import time
21 from typing import Dict, List
22 import uuid
23
@dataclass
class UserRequest:
    """A single prompt submitted by a user.

    ``id`` defaults to a fresh random UUID4 string and ``created_at`` to the
    value of ``time()`` at construction; both may be supplied explicitly.
    """

    prompt: str
    # Canonical hyphenated UUID4 string, unique per instance by default.
    id: str = field(default_factory=lambda: f"{uuid.uuid4()}")
    # Seconds since the epoch, taken when the instance is created.
    created_at: float = field(default_factory=lambda: time())
29
30
@dataclass
class TokenOutput:
    """One generated token, tagged with the id of the request it belongs to."""

    # Id of the originating request.
    request_id: str = field()
    # Raw token payload.
    token: bytes = field()
35
class BatchInferenceEngine:
    """Batch incoming prompts, run one model call per batch, fan tokens back out.

    Requests accumulate in ``queue``. As soon as ``batch_sizes`` requests are
    waiting, exactly that many are taken (FIFO) and sent through
    ``_handle_prompt_to_token`` in a single call. A request that has waited
    ``max_wait`` seconds without its batch filling up flushes whatever is
    queued as a partial batch, so a lone request is never stuck.

    The model output is regrouped per request: each caller of
    ``add_to_queue`` gets a task resolving to the ``TokenOutput`` objects of
    its own request, and the same per-request lists are recorded in
    ``request_token_map`` keyed by request id.
    """

    def __init__(self, batch_sizes: int = 8, max_wait: float = 0.05):
        """
        Args:
            batch_sizes: Queue length that triggers an immediate batch; also
                the maximum number of requests per model call. Must be >= 1.
            max_wait: Seconds a request waits for its batch to fill before a
                partial batch is flushed (previously a hard-coded 0.05).

        Raises:
            ValueError: If ``batch_sizes`` is less than 1.
        """
        if batch_sizes < 1:
            raise ValueError(f"batch_sizes must be >= 1, got {batch_sizes}")
        self.queue: List[UserRequest] = []
        # Request id -> tokens generated for it, in generation order. Grows
        # with every completed request; callers should pop entries they no
        # longer need.
        self.request_token_map: Dict[str, List[TokenOutput]] = {}
        self.batch_sizes = batch_sizes
        self.max_wait = max_wait

        self._lock = asyncio.Lock()
        # Request id -> result future, for requests still sitting in ``queue``.
        self._pending: Dict[str, asyncio.Future] = {}
        # Strong references to in-flight batch tasks; the event loop itself
        # only keeps weak references to tasks.
        self._inflight: set = set()

    async def add_to_queue(self, request: UserRequest) -> asyncio.Task:
        """Enqueue ``request`` and return a task that resolves to its tokens.

        If this request brings the queue up to ``batch_sizes``, that batch is
        dispatched immediately in the background.

        Returns:
            An ``asyncio.Task`` whose result is the list of ``TokenOutput``
            for ``request``, or which raises whatever the model call raised.

        Raises:
            ValueError: If a request with the same id is already queued.
        """
        future = asyncio.get_running_loop().create_future()
        async with self._lock:
            # Futures are keyed by id; a duplicate would orphan the first
            # caller's future and leave it waiting forever.
            if request.id in self._pending:
                raise ValueError(f"request {request.id!r} is already queued")
            self._pending[request.id] = future
            self.queue.append(request)
            # ``>=``: dispatch as soon as the queue *reaches* the batch size.
            if len(self.queue) >= self.batch_sizes:
                self._dispatch(self._take_batch())
        return asyncio.create_task(self._batch(request.id, future))

    async def _batch(self, request_id: str, future: asyncio.Future) -> List[TokenOutput]:
        """Wait for one request's tokens, flushing a partial batch if needed.

        Waits up to ``max_wait`` for the request's result. On timeout, if the
        request is still queued, everything queued is dispatched as a partial
        batch; if it is already in flight nothing more is needed. Either way
        the method then waits for and returns the result.
        """
        try:
            # shield(): a timeout must cancel only this wait, not the future
            # the batch task will resolve.
            return await asyncio.wait_for(asyncio.shield(future), timeout=self.max_wait)
        except asyncio.TimeoutError:
            pass
        async with self._lock:
            still_queued = any(req.id == request_id for req in self.queue)
            # Full batches are dispatched on enqueue, so the queue always
            # holds fewer than ``batch_sizes`` requests here; taking one
            # batch therefore takes this request as well.
            batch = self._take_batch() if still_queued else []
            if batch:
                self._dispatch(batch)
        return await future

    def _take_batch(self) -> List[UserRequest]:
        """Remove and return up to ``batch_sizes`` requests from the queue front."""
        batch = self.queue[: self.batch_sizes]
        del self.queue[: self.batch_sizes]
        return batch

    def _dispatch(self, batch: List[UserRequest]) -> None:
        """Start a background task running ``batch``; callers hold ``_lock``."""
        # Detach the futures now so a request is in exactly one place:
        # ``_pending`` (queued) or this batch (in flight).
        futures = {req.id: self._pending.pop(req.id) for req in batch}
        task = asyncio.create_task(self._run_batch(batch, futures))
        self._inflight.add(task)
        task.add_done_callback(self._inflight.discard)

    async def _run_batch(
        self, batch: List[UserRequest], futures: Dict[str, asyncio.Future]
    ) -> None:
        """Run one model call for ``batch`` and resolve each request's future."""
        try:
            outputs = await self._handle_prompt_to_token(batch)
            per_request: Dict[str, List[TokenOutput]] = {req.id: [] for req in batch}
            for out in outputs:
                per_request[out.request_id].append(out)
        except Exception as exc:
            # Deliver the failure to every caller instead of leaving them
            # waiting on futures that will never resolve.
            for fut in futures.values():
                if not fut.done():
                    fut.set_exception(exc)
            return
        self.request_token_map.update(per_request)
        for request_id, fut in futures.items():
            # A waiter that was cancelled has already cancelled its future.
            if not fut.done():
                fut.set_result(per_request[request_id])

    async def _handle_prompt_to_token(self, batch: List[UserRequest]) -> List[TokenOutput]:
        """Simulate one batched LLM call over ``batch``.

        Placeholder model: a request's generated tokens are the
        whitespace-separated words of its prompt, UTF-8 encoded. Tokens are
        emitted step by step across the whole batch (step 0 of every request,
        then step 1, ...), so the output interleaves requests and has to be
        regrouped by ``request_id``. Replace with the real model call.
        """
        await asyncio.sleep(0)  # stand-in for awaiting the real model call
        streams = [[word.encode("utf-8") for word in req.prompt.split()] for req in batch]
        outputs: List[TokenOutput] = []
        for step in range(max((len(s) for s in streams), default=0)):
            for req, stream in zip(batch, streams):
                if step < len(stream):
                    outputs.append(TokenOutput(request_id=req.id, token=stream[step]))
        return outputs