Mercurial
diff asyncio_threads/inference/main.py @ 48:46daba6e3cf4
A few Python scripts showing how to use asyncio.
| author | MrJuneJune <me@mrjunejune.com> |
|---|---|
| date | Sat, 13 Dec 2025 14:23:02 -0800 |
| parents | |
| children | |
--- /dev/null	Thu Jan 01 00:00:00 1970 +0000
+++ b/asyncio_threads/inference/main.py	Sat Dec 13 14:23:02 2025 -0800
@@ -0,0 +1,107 @@
+from typing import List, Dict, Any, Optional
+
+class UserRequest:
+    """Represents an incoming user request."""
+    def __init__(self, request_id: int, prompt: str):
+        self.request_id = request_id
+        self.prompt = prompt
+        self.output_tokens: List[str] = []  # Stores tokens as they are generated
+
+    def __repr__(self):
+        return f"Request(ID={self.request_id}, Prompt='{self.prompt[:20]}...', Tokens={len(self.output_tokens)})"
+
+# --- Mock LLM Interface ---
+def mock_inference_call(prompts: List[str], batch_id: int) -> Dict[str, Any]:
+    """
+    Simulates the call to the underlying LLM/GPU.
+
+    In a real scenario, this returns generated tokens and associated metadata.
+    For this mock, we return a flat list of tokens, one for each request
+    in the batch, and the batch ID for verification.
+
+    The length of the tokens list MUST equal the length of the prompts list.
+    """
+    print(f"  [INFERENCE] Running Batch {batch_id} with {len(prompts)} requests...")
+    results = {
+        'batch_id': batch_id,
+        # Simulate generating a single new token for each request in the batch
+        'generated_tokens': [
+            f"token_{i+1}_of_{batch_id}" for i, _ in enumerate(prompts)
+        ]
+    }
+    return results
+# -------------------------
+
+
+class BatchInferenceEngine:
+    """
+    A simplified inference engine that handles request batching and
+    token-level result distribution.
+    """
+    def __init__(self, batch_size: int = 4):
+        self.batch_size = batch_size
+        self.request_queue: List[UserRequest] = []
+        self.next_batch_id = 1
+
+    def enqueue_request(self, request: UserRequest) -> None:
+        """
+        Adds a new request to the queue and triggers batch processing if
+        the batch size is reached.
+        """
+        # 1. Add the request to the queue.
+        # 2. Check if the queue size meets or exceeds self.batch_size.
+        # 3. If so, process as many full batches as are available.
+        self.request_queue.append(request)
+        while len(self.request_queue) >= self.batch_size:
+            self._process_batch()
+
+    def _process_batch(self) -> None:
+        """
+        Executes the inference call for the current batch and distributes
+        the results back to the individual requests.
+        """
+        batch = self.request_queue[:self.batch_size]
+        prompts = [req.prompt for req in batch]
+        results = mock_inference_call(prompts, self.next_batch_id)
+        # Distribute one generated token to each request in the batch.
+        for request, token in zip(batch, results["generated_tokens"]):
+            request.output_tokens.append(token)
+        # Drop the processed requests from the queue.
+        self.request_queue = self.request_queue[self.batch_size:]
+        self.next_batch_id += 1
+
+    def get_results(self, request_id: int) -> Optional[List[str]]:
+        """
+        In a real system, this would retrieve results from a separate
+        completed-requests store. For this mock, assume we can only
+        retrieve results for requests that have been fully processed and
+        are no longer in the queue.
+        """
+        # For simplicity, assume all requests that have been processed by a
+        # batch call have completed their generation for this *single step*
+        # of the mock. To make this more realistic, expand the UserRequest
+        # class with an 'is_complete' flag.
+
+        # For the provided mock structure, just check the queue:
+        for req in self.request_queue:
+            if req.request_id == request_id:
+                # If it's still in the queue, it hasn't been processed yet.
+                return None
+
+        # In a complete system, you'd look up the request ID in a
+        # completed-requests map. For this simplified version, return a
+        # simulated result for requests that *would* have been processed:
+        # if the request ID is less than the ID of the requests that would
+        # be processed in the *next* batch, simulate a complete token output.
+        if request_id < self.next_batch_id * self.batch_size:
+            # Simple simulation:
+            return [f"token_X_of_{batch_num}" for batch_num in range(1, self.next_batch_id)]
+
+        return None
+
+
+def main():
+    pass
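
To see how the pieces above fit together, here is a minimal, hedged sketch of a driver. It is not part of the committed file (the committed `main()` is still a stub); the request IDs and prompts are invented for illustration, and it assumes `UserRequest` and `BatchInferenceEngine` are in scope, e.g. defined in the same module.

```python
# Illustrative driver only -- not part of the committed main.py above.
# Assumes UserRequest and BatchInferenceEngine are in scope.

engine = BatchInferenceEngine(batch_size=4)

# Enqueue five requests: the fourth fills the batch and triggers
# _process_batch(); the fifth stays queued for a future batch.
for i in range(1, 6):
    engine.enqueue_request(UserRequest(request_id=i, prompt=f"Prompt number {i}"))

# Request 1 was in the dispatched batch, so it is no longer queued and a
# simulated token list is returned.
print(engine.get_results(1))   # e.g. ['token_X_of_1']

# Request 5 is still waiting in the queue, so None is returned.
print(engine.get_results(5))   # None
```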