diff grok_interview/inference.py @ 60:d64a8c189a77
Merged
| author | June Park <me@mrjunejune.com> |
|---|---|
| date | Sat, 20 Dec 2025 13:56:01 -0500 |
| parents | 68fa88ac73fe |
| children | |
--- /dev/null	Thu Jan 01 00:00:00 1970 +0000
+++ b/grok_interview/inference.py	Sat Dec 20 13:56:01 2025 -0500
@@ -0,0 +1,148 @@
+import asyncio
+import time
+from typing import List, Dict, Optional
+import uuid
+from dataclasses import dataclass, field
+
+@dataclass
+class InferenceRequest:
+    prompt: str
+    request_id: str = field(default_factory=lambda: str(uuid.uuid4()))
+    max_tokens: int = 128
+    created_at: float = field(default_factory=time.time)
+
+@dataclass
+class TokenOutput:
+    token: str
+    request_id: str
+    token_index: int  # position in this request's generation
+    is_complete: bool = False
+
+class BatchInferenceEngine:
+
+    def __init__(self, batch_size: int = 8, timeout: float = 0.1):
+        """
+        Args:
+            batch_size: Maximum number of requests to process in one inference call
+            timeout: Max time to wait for the batch to fill before processing (in seconds)
+        """
+        self.batch_size = batch_size
+        self.timeout = timeout
+
+        self.queue: List[InferenceRequest] = []
+        self.pending_requests: Dict[str, List[str]] = {}  # request_id -> generated tokens
+        self.completion_futures: Dict[str, asyncio.Future] = {}
+
+        self._lock = asyncio.Lock()
+        self._batch_ready_event = asyncio.Event()
+        self._worker_task: Optional[asyncio.Task] = None
+
+    async def submit_request(self, prompt: str, max_tokens: int = 128) -> "asyncio.Future[str]":
+        """
+        Submit a new inference request and return a Future that resolves to the generated text.
+        """
+        request = InferenceRequest(prompt=prompt, max_tokens=max_tokens)
+
+        future = asyncio.get_running_loop().create_future()
+
+        async with self._lock:
+            self.queue.append(request)
+            self.pending_requests[request.request_id] = []
+            self.completion_futures[request.request_id] = future
+
+            # Trigger batch processing if we've reached batch size
+            if len(self.queue) >= self.batch_size:
+                self._batch_ready_event.set()
+
+        # Start the background worker once, on first submission
+        if self._worker_task is None:
+            self._worker_task = asyncio.create_task(self._worker())
+
+        return future
+
+    async def _worker(self):
+        """Background task that processes batches when ready"""
+        while True:
+            # Wait until a full batch is ready or the timeout fires
+            try:
+                await asyncio.wait_for(self._batch_ready_event.wait(), timeout=self.timeout)
+            except asyncio.TimeoutError:
+                pass
+
+            async with self._lock:
+                if not self.queue:
+                    self._batch_ready_event.clear()
+                    continue
+
+                # Take up to batch_size requests
+                current_batch = self.queue[:self.batch_size]
+                self.queue = self.queue[self.batch_size:]
+
+                # Reset the event if the queue is now empty
+                if not self.queue:
+                    self._batch_ready_event.clear()
+
+            # Simulate batched LLM inference (outside the lock)
+            token_outputs: List[TokenOutput] = self._simulate_batched_inference(current_batch)
+
+            # Distribute tokens back to individual requests
+            async with self._lock:
+                for token_out in token_outputs:
+                    req_id = token_out.request_id
+                    self.pending_requests[req_id].append(token_out.token)
+
+                    # Resolve the future once this request is complete
+                    if token_out.is_complete:
+                        full_text = ''.join(self.pending_requests[req_id])
+                        future = self.completion_futures.pop(req_id)
+                        future.set_result(full_text)
+                        del self.pending_requests[req_id]
+
+    def _simulate_batched_inference(self, batch: List[InferenceRequest]) -> List[TokenOutput]:
+        """
+        Simulates a real LLM forward pass on a batch.
+        In a real implementation, this would call model.generate() on a padded batch.
+        """
+        outputs: List[TokenOutput] = []
+
+        for req in batch:
+            num_tokens_to_generate = min(req.max_tokens, 50 + hash(req.request_id) % 30)
+
+            # Simulate token-by-token generation
+            for i in range(num_tokens_to_generate):
+                token_str = f"tok{i} "  # placeholder token text for the simulation
+                is_complete = (i == num_tokens_to_generate - 1)
+
+                outputs.append(TokenOutput(
+                    token=token_str,
+                    request_id=req.request_id,
+                    token_index=i,
+                    is_complete=is_complete
+                ))
+
+                # In a real streaming setup, tokens could be yielded here
+
+        return outputs
+
+# Example usage
+async def main():
+    engine = BatchInferenceEngine(batch_size=4, timeout=0.5)
+
+    # Submit several requests; each call returns a Future for the generated text
+    futures = await asyncio.gather(
+        engine.submit_request("Tell me a joke about cats"),
+        engine.submit_request("Explain quantum computing in simple terms"),
+        engine.submit_request("Write a haiku about AI"),
+        engine.submit_request("What is the meaning of life?"),
+    )
+
+    print("Requests submitted, waiting for results...")
+
+    # Wait for all generations to complete
+    results = await asyncio.gather(*futures)
+
+    for i, text in enumerate(results):
+        print(f"Request {i}: {text}\n")
+
+if __name__ == "__main__":
+    asyncio.run(main())