Mercurial
view grok_interview/inference.py @ 55:0dcfbf5ba997
Removing unneeded bzl rules.
| author | June Park <parkjune1995@gmail.com> |
|---|---|
| date | Fri, 19 Dec 2025 13:59:11 -0800 |
| parents | 68fa88ac73fe |
| children | |
line source
import asyncio
import time
import uuid
from dataclasses import dataclass, field
from typing import Dict, List, Optional


@dataclass
class InferenceRequest:
    prompt: str
    request_id: str = field(default_factory=lambda: str(uuid.uuid4()))
    max_tokens: int = 128
    created_at: float = field(default_factory=time.time)


@dataclass
class TokenOutput:
    token: str
    request_id: str
    token_index: int  # position in this request's generation
    is_complete: bool = False


class BatchInferenceEngine:
    def __init__(self, batch_size: int = 8, timeout: float = 0.1):
        """
        Args:
            batch_size: Maximum number of requests to process in one inference call
            timeout: Max time to wait for the batch to fill before processing (in seconds)
        """
        self.batch_size = batch_size
        self.timeout = timeout
        self.queue: List[InferenceRequest] = []
        self.pending_requests: Dict[str, List[str]] = {}  # request_id -> list of generated tokens
        self.completion_futures: Dict[str, asyncio.Future] = {}
        self._lock = asyncio.Lock()
        self._batch_ready_event = asyncio.Event()
        self._worker_task: Optional[asyncio.Task] = None

    async def submit_request(self, prompt: str, max_tokens: int = 128) -> asyncio.Future:
        """
        Submit a new inference request and return a Future that resolves to the generated text.
        """
        request = InferenceRequest(prompt=prompt, max_tokens=max_tokens)
        future = asyncio.get_running_loop().create_future()

        async with self._lock:
            self.queue.append(request)
            self.pending_requests[request.request_id] = []
            self.completion_futures[request.request_id] = future

            # Trigger batch processing if we've reached batch size
            if len(self.queue) >= self.batch_size:
                self._batch_ready_event.set()

        # Start the background worker if it is not already running
        if self._worker_task is None or self._worker_task.done():
            self._worker_task = asyncio.create_task(self._worker())

        return future

    async def _worker(self):
        """Background task that processes batches when ready"""
        while True:
            # Wait until a full batch is ready or the timeout triggers
            try:
                await asyncio.wait_for(self._batch_ready_event.wait(), timeout=self.timeout)
            except asyncio.TimeoutError:
                pass

            async with self._lock:
                if not self.queue:
                    self._batch_ready_event.clear()
                    continue

                # Take up to batch_size requests
                current_batch = self.queue[:self.batch_size]
                self.queue = self.queue[self.batch_size:]

                # Reset the event if the queue is now empty
                if not self.queue:
                    self._batch_ready_event.clear()

            # Simulate batched LLM inference
            token_outputs: List[TokenOutput] = self._simulate_batched_inference(current_batch)

            # Distribute tokens back to individual requests
            async with self._lock:
                for token_out in token_outputs:
                    req_id = token_out.request_id
                    self.pending_requests[req_id].append(token_out.token)

                    # Check if this request is complete
                    if token_out.is_complete:
                        full_text = ''.join(self.pending_requests[req_id])
                        future = self.completion_futures.pop(req_id)
                        future.set_result(full_text)
                        del self.pending_requests[req_id]

    def _simulate_batched_inference(self, batch: List[InferenceRequest]) -> List[TokenOutput]:
        """
        Simulates a real LLM forward pass on a batch.
        In a real implementation, this would call model.generate() with a padded batch.
        """
        outputs: List[TokenOutput] = []
        for req in batch:
            num_tokens_to_generate = min(req.max_tokens, 50 + hash(req.request_id) % 30)

            # Simulate token-by-token generation: echo the prompt as the first
            # "token", then emit placeholder tokens until the request completes.
            for i in range(num_tokens_to_generate):
                token_str = req.prompt if i == 0 else f" <tok{i}>"
                is_complete = (i == num_tokens_to_generate - 1)
                outputs.append(TokenOutput(
                    token=token_str,
                    request_id=req.request_id,
                    token_index=i,
                    is_complete=is_complete
                ))
                # In a real streaming setup, tokens could be yielded early here
        return outputs


# Example usage
async def main():
    engine = BatchInferenceEngine(batch_size=4, timeout=0.5)

    # Submit several requests; submit_request returns a Future per request
    futures = await asyncio.gather(
        engine.submit_request("Tell me a joke about cats"),
        engine.submit_request("Explain quantum computing in simple terms"),
        engine.submit_request("Write a haiku about AI"),
        engine.submit_request("What is the meaning of life?"),
    )

    print("Requests submitted, waiting for results...")

    # Wait for all generations to complete
    results = await asyncio.gather(*futures)

    for i, text in enumerate(results):
        print(f"Request {i}: {text}\n")


if __name__ == "__main__":
    asyncio.run(main())