view asyncio_threads/inference/main.py @ 78:e7bf9e002850
amend
| author | June Park <parkjune1995@gmail.com> |
|---|---|
| date | Wed, 31 Dec 2025 15:07:43 -0800 |
| parents | 46daba6e3cf4 |
| children | |
from typing import List, Dict, Any, Optional


class UserRequest:
    """Represents an incoming user request."""

    def __init__(self, request_id: int, prompt: str):
        self.request_id = request_id
        self.prompt = prompt
        self.output_tokens: List[str] = []  # Stores tokens as they are generated

    def __repr__(self):
        return (f"Request(ID={self.request_id}, Prompt='{self.prompt[:20]}...', "
                f"Tokens={len(self.output_tokens)})")


# --- Mock LLM Interface ---
def mock_inference_call(prompts: List[str], batch_id: int) -> Dict[str, Any]:
    """
    Simulates the call to the underlying LLM/GPU.

    In a real scenario, this returns generated tokens and associated metadata.
    For this mock, we return a flat list of tokens, one for each request in the
    batch, plus the batch ID for verification. The length of the tokens list
    MUST equal the length of the prompts list.
    """
    print(f"  [INFERENCE] Running Batch {batch_id} with {len(prompts)} requests...")
    results = {
        'batch_id': batch_id,
        # Simulate generating a single new token for each request in the batch
        'generated_tokens': [
            f"token_{i + 1}_of_{batch_id}" for i, _ in enumerate(prompts)
        ],
    }
    return results
# ---------------------------


class BatchInferenceEngine:
    """
    A simplified inference engine that handles request batching and
    token-level result distribution.
    """

    def __init__(self, batch_size: int = 4):
        self.batch_size = batch_size
        self.request_queue: List[UserRequest] = []
        self.next_batch_id = 1

    def enqueue_request(self, request: UserRequest) -> None:
        """
        Adds a new request to the queue and triggers batch processing
        if the batch size is reached.
        """
        # --- YOUR CODE HERE ---
        # 1. Add the request to the queue.
        # 2. Check if the queue size meets or exceeds self.batch_size.
        # 3. If so, call self._process_batch().
        # ----------------------
        self.request_queue.append(request)
        while len(self.request_queue) >= self.batch_size:
            self._process_batch()

    def _process_batch(self) -> None:
        """
        Executes the inference call for the current batch and distributes
        the results back to the individual requests.
        """
        batch = self.request_queue[:self.batch_size]
        prompts = [req.prompt for req in batch]
        results = mock_inference_call(prompts, self.next_batch_id)

        # Distribute exactly one generated token to each request in the batch.
        for request, token in zip(batch, results['generated_tokens']):
            request.output_tokens.append(token)

        # Drop the processed requests from the queue and advance the batch counter.
        self.request_queue = self.request_queue[self.batch_size:]
        self.next_batch_id += 1

    def get_results(self, request_id: int) -> Optional[List[str]]:
        """
        In a real system, this would retrieve results from a separate
        completed-requests store. For this mock, assume we can only retrieve
        results for requests that have been fully processed and are no longer
        in the queue.
        """
        # For simplicity, assume all requests that have been processed by a
        # batch call have completed their generation for this *single step* of
        # the mock. If you want to make this more realistic, feel free to
        # expand the UserRequest class to include an 'is_complete' flag.

        # For the provided mock structure, we'll just check the queue:
        for req in self.request_queue:
            if req.request_id == request_id:
                # If it's still in the queue, it hasn't been processed yet.
                return None

        # In a complete system, you'd look up the request ID in a
        # completed-requests map. For this simplified version, return a
        # simulated result for requests that *would* have been processed:
        # if the request ID is lower than the IDs that would land in the
        # *next* batch, simulate a complete token output.
        if request_id < self.next_batch_id * self.batch_size:
            # Simple simulation:
            return [f"token_X_of_{batch_num}" for batch_num in range(1, self.next_batch_id)]

        return None


def main():
    pass
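# --- Usage sketch (illustrative; not part of the committed file) ---
# A minimal driver showing how the engine batches requests. The prompts and
# the batch_size of 2 below are assumptions chosen for demonstration, not
# values taken from this changeset.
if __name__ == "__main__":
    engine = BatchInferenceEngine(batch_size=2)
    demo_requests = [
        UserRequest(1, "Explain request batching in one sentence."),
        UserRequest(2, "List three uses of asyncio."),
        UserRequest(3, "Summarize the GIL."),
    ]
    for req in demo_requests:
        engine.enqueue_request(req)

    # Requests 1 and 2 filled the first batch and were processed;
    # request 3 is still waiting in the queue for a full batch.
    print(engine.get_results(1))  # Simulated completed tokens for request 1
    print(engine.get_results(3))  # None: request 3 is still queued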