diff grok_interview/inference.py @ 51:68fa88ac73fe

Interview prep for xAI
author June Park <parkjune1995@gmail.com>
date Mon, 15 Dec 2025 19:55:17 -0800
parents
children
--- /dev/null	Thu Jan 01 00:00:00 1970 +0000
+++ b/grok_interview/inference.py	Mon Dec 15 19:55:17 2025 -0800
@@ -0,0 +1,146 @@
+import asyncio
+import time
+from typing import List, Dict, Optional
+import uuid
+from dataclasses import dataclass, field
+
+@dataclass
+class InferenceRequest:
+    prompt: str
+    request_id: str = field(default_factory=lambda: str(uuid.uuid4()))
+    max_tokens: int = 128
+    created_at: float = field(default_factory=time.time)
+
+@dataclass
+class TokenOutput:
+    token: str
+    request_id: str
+    token_index: int  # position in this request's generation
+    is_complete: bool = False
+
+class BatchInferenceEngine:
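+    """
+    Minimal dynamic-batching front end for LLM inference: requests are queued and
+    flushed together once the batch fills or the timeout expires, mirroring how
+    serving stacks batch prompts to keep the accelerator busy.
+    """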
+
+    def __init__(self, batch_size: int = 8, timeout: float = 0.1):
+        """
+        Args:
+            batch_size: Maximum number of requests to process in one inference call
+            timeout: Max time to wait for batch to fill before processing (in seconds)
+        """
+        self.batch_size = batch_size
+        self.timeout = timeout
+        
+        self.queue: List[InferenceRequest] = []
+        self.pending_requests: Dict[str, List[str]] = {}  # request_id -> list of generated tokens
+        self.completion_futures: Dict[str, asyncio.Future] = {}
+        
+        self._lock = asyncio.Lock()
+        self._batch_ready_event = asyncio.Event()
+        self._worker_task: Optional[asyncio.Task] = None  # lazily-started background batcher
+        
+    async def submit_request(self, prompt: str, max_tokens: int = 128) -> asyncio.Future[str]:
+        """
+        Submit a new inference request and return a Future that resolves to the generated text.
+        """
+        request = InferenceRequest(prompt=prompt, max_tokens=max_tokens)
+        
+        future = asyncio.get_running_loop().create_future()
+        
+        async with self._lock:
+            self.queue.append(request)
+            self.pending_requests[request.request_id] = []
+            self.completion_futures[request.request_id] = future
+            
+            # Trigger batch processing if we've reached batch size
+            if len(self.queue) >= self.batch_size:
+                self._batch_ready_event.set()
+        
+        # Start the background worker on first use
+        if self._worker_task is None:
+            self._worker_task = asyncio.create_task(self._worker())
+        
+        return future
+    
+    async def _worker(self):
+        """Background task that processes batches when ready"""
+        while True:
+            # Wait until we have requests or timeout triggers
+            try:
+                await asyncio.wait_for(self._batch_ready_event.wait(), timeout=self.timeout)
+            except asyncio.TimeoutError:
+                pass
+            
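+            # Either the batch filled or the timeout expired: flush whatever is queued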
+            async with self._lock:
+                if not self.queue:
+                    self._batch_ready_event.clear()
+                    continue
+                
+                # Take up to batch_size requests
+                current_batch = self.queue[:self.batch_size]
+                self.queue = self.queue[self.batch_size:]
+                
+                # Reset event if queue is now empty
+                if not self.queue:
+                    self._batch_ready_event.clear()
+            
+            # Simulate batched LLM inference
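+            # (a production engine would run this forward pass off the event loop,
+            #  e.g. in an executor, so token generation does not block request submission)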
+            token_outputs: List[TokenOutput] = self._simulate_batched_inference(current_batch)
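+            # (the simulation finishes every request in a single pass; a real engine would
+            #  re-queue unfinished sequences for the next decode step, i.e. continuous batching)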
+            
+            # Distribute tokens back to individual requests
+            async with self._lock:
+                for token_out in token_outputs:
+                    req_id = token_out.request_id
+                    self.pending_requests[req_id].append(token_out.token)
+                    
+                    # Check if this request is complete
+                    if token_out.is_complete:
+                        full_text = ''.join(self.pending_requests[req_id])
+                        future = self.completion_futures.pop(req_id)
+                        future.set_result(full_text)
+                        del self.pending_requests[req_id]
+    
+    def _simulate_batched_inference(self, batch: List[InferenceRequest]) -> List[TokenOutput]:
+        """
+        Simulates a real LLM forward pass on a batch.
+        In a real implementation, this would call model.generate() on a padded batch.
+        """
+        outputs: List[TokenOutput] = []
+
+        for req in batch:
+            prompt_words = req.prompt.split()
+            prompt_len = len(prompt_words)  # rough token-count estimate
+            num_tokens_to_generate = min(req.max_tokens, 50 + hash(req.request_id) % 30)
+
+            # Simulate token-by-token generation by echoing prompt words as placeholder tokens
+            for i in range(num_tokens_to_generate):
+                token_str = (prompt_words[i % prompt_len] + " ") if prompt_len else "<tok> "
+                is_complete = (i == num_tokens_to_generate - 1)
+
+                outputs.append(TokenOutput(
+                    token=token_str,
+                    request_id=req.request_id,
+                    token_index=i,
+                    is_complete=is_complete
+                ))
+
+                # In a real streaming setup, each token could be yielded here as it is produced
+
+        return outputs
+    
+# Example usage
+async def main():
+    engine = BatchInferenceEngine(batch_size=4, timeout=0.5)
+    
+    # Submit several requests; each submit_request call resolves to a Future for that request
+    futures = await asyncio.gather(
+        engine.submit_request("Tell me a joke about cats"),
+        engine.submit_request("Explain quantum computing in simple terms"),
+        engine.submit_request("Write a haiku about AI"),
+        engine.submit_request("What is the meaning of life?"),
+    )
+
+    print("Requests submitted, waiting for results...")
+
+    # Wait for all generations to complete
+    results = await asyncio.gather(*futures)
+
+    for i, text in enumerate(results):
+        print(f"Request {i}: {text}\n")
+
+if __name__ == "__main__":
+    asyncio.run(main())