Mercurial
comparison grok_interview/inference_my_version.py @ 60:d64a8c189a77
Merged
| author | June Park <me@mrjunejune.com> |
|---|---|
| date | Sat, 20 Dec 2025 13:56:01 -0500 |
| parents | 68fa88ac73fe |
| children |
comparison
equal
deleted
inserted
replaced
| 50:983769fba767 | 60:d64a8c189a77 |
|---|---|
| 1 """ | |
| 2 Inference Questions | |
| 3 | |
| 4 Context | |
| 5 | |
| 6 You are tasked with building a simplified inference engine component responsible for handling incoming user requests for a large language model (LLM). To optimize throughput and GPU utilization, the engine must batch multiple requests together, run the inference call once per batch, and then deconstruct the results to return token-level output to the individual users. | |
| 7 | |
| 8 Objective | |
| 9 | |
| 10 Complete the provided Python class, BatchInferenceEngine, by implementing the methods necessary to: | |
| 11 | |
| 12 Queue incoming user requests. | |
| 13 | |
| 14 Process a batch when the queue reaches a defined batch size. | |
| 15 Simulate the token-level output from an LLM and correctly associate each generated token with its original request. | |
| 16 """ | |
| 17 | |
| 18 import asyncio | |
| 19 from dataclasses import dataclass, field | |
| 20 from time import time | |
| 21 from typing import Dict, List | |
| 22 import uuid | |
| 23 | |
@dataclass
class UserRequest:
    """A single prompt submitted for inference.

    Only ``prompt`` is required. ``id`` and ``created_at`` are filled in
    per instance when they are not supplied by the caller.
    """

    prompt: str
    # A fresh random UUID4, rendered as a string, for every new request.
    id: str = field(default_factory=lambda: str(uuid.uuid4()))
    # Result of time() at construction (seconds since the epoch, as a float).
    created_at: float = field(default_factory=time)
| 29 | |
| 30 | |
@dataclass
class TokenOutput:
    """One generated token tagged with the id of the request it answers."""

    # Identifier of the originating request (a string id).
    request_id: str
    # The token itself, carried as raw bytes.
    token: bytes
| 35 | |
class BatchInferenceEngine:
    """Queue user requests and hand them to the model in fixed-size batches.

    Requests accumulate in ``queue``. As soon as ``batch_sizes`` requests are
    waiting, ``_batch_event`` is set; a waiting ``_batch`` task then removes
    one batch from the queue and passes it to ``_handle_prompt_to_token``.
    """

    def __init__(self, batch_sizes: int = 8):
        """Create an engine with an empty queue.

        Args:
            batch_sizes: Number of queued requests that make up one batch.
        """
        self.queue: "List[UserRequest]" = []
        self.request_token_map: Dict[str, str] = {}
        self.batch_sizes = batch_sizes

        # Guards every read and write of ``queue``.
        self._lock = asyncio.Lock()
        # Set while at least one full batch is waiting in ``queue``.
        self._batch_event = asyncio.Event()

    async def add_to_queue(self, request: "UserRequest"):
        """Enqueue ``request`` and start a task that waits for a full batch.

        Args:
            request: The request to append to the queue.

        Returns:
            The ``asyncio.Task`` running ``_batch``. Awaiting it yields that
            call's result, or raises ``TimeoutError`` if the task did not
            obtain a full batch in time.
        """
        async with self._lock:
            self.queue.append(request)

            # ">=" (not ">"): the batch must fire when the queue *reaches*
            # the batch size, not one request later.
            if len(self.queue) >= self.batch_sizes:
                self._batch_event.set()

        return asyncio.create_task(self._batch())

    async def _batch(self):
        """Wait for a full batch, remove it from the queue and run inference.

        Returns:
            Whatever ``_handle_prompt_to_token`` returns for the batch.

        Raises:
            TimeoutError: No full batch was signalled within 50 ms.
        """
        while True:
            try:
                await asyncio.wait_for(self._batch_event.wait(), timeout=0.05)
            except asyncio.TimeoutError as err:
                # Translate only the timeout. A bare ``except`` here would
                # also swallow task cancellation and unrelated errors.
                raise TimeoutError("Timed out") from err

            async with self._lock:
                if len(self.queue) < self.batch_sizes:
                    # Another task already took the batch this wake-up was
                    # for; go back to waiting.
                    self._batch_event.clear()
                    continue

                batch = self.queue[:self.batch_sizes]
                # Dequeue what was taken so no request is processed twice.
                del self.queue[:self.batch_sizes]
                if len(self.queue) < self.batch_sizes:
                    self._batch_event.clear()

            # Run inference outside the lock so new requests can still be
            # queued while the batch is being processed.
            return await self._handle_prompt_to_token(batch)

    async def _handle_prompt_to_token(self, batch: "List[UserRequest]"):
        """Turn one batch of requests into token-level output.

        TODO: not implemented yet; this placeholder returns ``None``.
        """
        return None