comparison grok_interview/inference.py @ 60:d64a8c189a77
Merged
| author | June Park <me@mrjunejune.com> |
|---|---|
| date | Sat, 20 Dec 2025 13:56:01 -0500 |
| parents | 68fa88ac73fe |
| children | |
comparison of 50:983769fba767 with 60:d64a8c189a77

import asyncio
import time
from typing import List, Dict, Any, Optional
import uuid
from dataclasses import dataclass, field

@dataclass
class InferenceRequest:
    prompt: str
    request_id: str = field(default_factory=lambda: str(uuid.uuid4()))
    max_tokens: int = 128
    created_at: float = field(default_factory=time.time)

@dataclass
class TokenOutput:
    token: str
    request_id: str
    token_index: int  # position in this request's generation
    is_complete: bool = False

class BatchInferenceEngine:

    def __init__(self, batch_size: int = 8, timeout: float = 0.1):
        """
        Args:
            batch_size: Maximum number of requests to process in one inference call
            timeout: Max time to wait for the batch to fill before processing (in seconds)
        """
        self.batch_size = batch_size
        self.timeout = timeout

        self.queue: List[InferenceRequest] = []
        self.pending_requests: Dict[str, List[str]] = {}  # request_id -> list of generated tokens
        self.completion_futures: Dict[str, asyncio.Future] = {}

        self._lock = asyncio.Lock()
        self._batch_ready_event = asyncio.Event()
        self._worker_task: Optional[asyncio.Task] = None  # single background worker

    async def submit_request(self, prompt: str, max_tokens: int = 128) -> asyncio.Future:
        """
        Submit a new inference request and return a Future that resolves to the generated text.
        """
        request = InferenceRequest(prompt=prompt, max_tokens=max_tokens)

        future = asyncio.get_running_loop().create_future()

        async with self._lock:
            self.queue.append(request)
            self.pending_requests[request.request_id] = []
            self.completion_futures[request.request_id] = future

            # Trigger batch processing if we've reached batch size
            if len(self.queue) >= self.batch_size:
                self._batch_ready_event.set()

        # Start the background worker if it is not already running
        if self._worker_task is None or self._worker_task.done():
            self._worker_task = asyncio.create_task(self._worker())

        return future

    async def _worker(self):
        """Background task that processes batches when ready"""
        while True:
            # Wait until we have requests or the timeout triggers
            try:
                await asyncio.wait_for(self._batch_ready_event.wait(), timeout=self.timeout)
            except asyncio.TimeoutError:
                pass

            async with self._lock:
                if not self.queue:
                    self._batch_ready_event.clear()
                    continue

                # Take up to batch_size requests
                current_batch = self.queue[:self.batch_size]
                self.queue = self.queue[self.batch_size:]

                # Reset event if queue is now empty
                if not self.queue:
                    self._batch_ready_event.clear()

            # Simulate batched LLM inference
            token_outputs: List[TokenOutput] = self._simulate_batched_inference(current_batch)

            # Distribute tokens back to individual requests
            async with self._lock:
                for token_out in token_outputs:
                    req_id = token_out.request_id
                    self.pending_requests[req_id].append(token_out.token)

                    # Check if this request is complete
                    if token_out.is_complete:
                        full_text = ''.join(self.pending_requests[req_id])
                        future = self.completion_futures.pop(req_id)
                        future.set_result(full_text)
                        del self.pending_requests[req_id]

    def _simulate_batched_inference(self, batch: List[InferenceRequest]) -> List[TokenOutput]:
        """
        Simulates a real LLM forward pass on a batch.
        In a real implementation this would call model.generate() with a padded batch
        (a sketch of such a call follows this method).
        """
        outputs: List[TokenOutput] = []

        for req in batch:
            prompt_len = len(req.prompt.split())  # rough estimate, unused by the simulation
            num_tokens_to_generate = min(req.max_tokens, 50 + hash(req.request_id) % 30)

            # Simulate token-by-token generation: echo the prompt, then dummy tokens
            for i in range(num_tokens_to_generate):
                token_str = req.prompt if i == 0 else f" token_{i}"
                is_complete = (i == num_tokens_to_generate - 1)

                outputs.append(TokenOutput(
                    token=token_str,
                    request_id=req.request_id,
                    token_index=i,
                    is_complete=is_complete
                ))

                # In a real streaming setup we could yield each token here instead of
                # collecting the whole batch (see the streaming sketch after the example usage).

        return outputs

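# The docstring above notes that a real implementation would call model.generate()
# on a padded batch. A minimal sketch of that call, assuming a Hugging Face
# transformers causal LM; the model name ("gpt2") and decoding settings are
# illustrative assumptions, not taken from this repository.
def real_batched_generate(prompts: List[str], max_new_tokens: int = 50) -> List[str]:
    # Imported lazily so the simulation demo below still runs without these packages.
    import torch
    from transformers import AutoModelForCausalLM, AutoTokenizer

    tokenizer = AutoTokenizer.from_pretrained("gpt2")
    model = AutoModelForCausalLM.from_pretrained("gpt2")

    # GPT-2 has no pad token; reuse EOS and left-pad so the generated tokens stay
    # contiguous with each prompt in the batch.
    tokenizer.pad_token = tokenizer.eos_token
    tokenizer.padding_side = "left"

    inputs = tokenizer(prompts, return_tensors="pt", padding=True)
    with torch.no_grad():
        output_ids = model.generate(
            **inputs,
            max_new_tokens=max_new_tokens,
            pad_token_id=tokenizer.eos_token_id,
        )

    # Drop the prompt tokens so only the continuations are returned.
    new_tokens = output_ids[:, inputs["input_ids"].shape[1]:]
    return tokenizer.batch_decode(new_tokens, skip_special_tokens=True)
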
# Example usage
async def main():
    engine = BatchInferenceEngine(batch_size=4, timeout=0.5)

    # Submit several requests; each submit_request call returns a Future for that request
    futures = await asyncio.gather(
        engine.submit_request("Tell me a joke about cats"),
        engine.submit_request("Explain quantum computing in simple terms"),
        engine.submit_request("Write a haiku about AI"),
        engine.submit_request("What is the meaning of life?"),
    )

    print("Requests submitted, waiting for results...")

    # Wait for all generations to complete
    results = await asyncio.gather(*futures)

    for i, text in enumerate(results):
        print(f"Request {i}: {text}\n")

if __name__ == "__main__":
    asyncio.run(main())
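
# The generation loop in _simulate_batched_inference notes that tokens could be
# yielded as they are produced. A minimal sketch of that streaming idea, assuming
# the worker pushed each TokenOutput into a per-request asyncio.Queue instead of
# resolving a single Future; that queue wiring is an assumption, not something the
# engine above implements.
async def stream_tokens(token_queue: "asyncio.Queue[TokenOutput]"):
    # Yield token strings as the batching worker produces them, stopping after
    # the final token of the request.
    while True:
        token_out = await token_queue.get()
        yield token_out.token
        if token_out.is_complete:
            break

# Example consumer (run inside an event loop):
#     async for token in stream_tokens(queue):
#         print(token, end="", flush=True)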