Mercurial
comparison asyncio_threads/inference/main.py @ 48:46daba6e3cf4
A few Python scripts to show how to use asyncio.
| author | MrJuneJune <me@mrjunejune.com> |
|---|---|
| date | Sat, 13 Dec 2025 14:23:02 -0800 |
| parents | |
| children | |
| 47:829623189a57 | 48:46daba6e3cf4 |
|---|---|
from typing import List, Dict, Any, Optional

class UserRequest:
    """Represents an incoming user request."""
    def __init__(self, request_id: int, prompt: str):
        self.request_id = request_id
        self.prompt = prompt
        self.output_tokens: List[str] = []  # Stores tokens as they are generated

    def __repr__(self):
        return f"Request(ID={self.request_id}, Prompt='{self.prompt[:20]}...', Tokens={len(self.output_tokens)})"

# --- Mock LLM Interface ---
def mock_inference_call(prompts: List[str], batch_id: int) -> Dict[str, Any]:
    """
    Simulates the call to the underlying LLM/GPU.

    In a real scenario, this returns generated tokens and associated metadata.
    For this mock, we return a flat list of tokens, one for each request
    in the batch, and the batch ID for verification.

    The length of the tokens list MUST equal the length of the prompts list.
    """
    print(f" [INFERENCE] Running Batch {batch_id} with {len(prompts)} requests...")
    results = {
        'batch_id': batch_id,
        # Simulate generating a single new token for each request in the batch
        'generated_tokens': [
            f"token_{i+1}_of_{batch_id}" for i, _ in enumerate(prompts)
        ]
    }
    return results
# -------------------------


class BatchInferenceEngine:
    """
    A simplified inference engine that handles request batching and
    token-level result distribution.
    """
    def __init__(self, batch_size: int = 4):
        self.batch_size = batch_size
        self.request_queue: List[UserRequest] = []
        self.next_batch_id = 1

    def enqueue_request(self, request: UserRequest) -> None:
        """
        Adds a new request to the queue and triggers batch processing if
        the batch size is reached.
        """
        # --- YOUR CODE HERE ---
        # 1. Add the request to the queue.
        # 2. Check if the queue size meets or exceeds self.batch_size.
        # 3. If so, call self._process_batch().
        # ----------------------
        self.request_queue.append(request)
        # Use >= so a batch is processed as soon as batch_size requests are queued.
        while len(self.request_queue) >= self.batch_size:
            self._process_batch()


    def _process_batch(self) -> None:
        """
        Executes the inference call for the current batch and distributes
        the results back to the individual requests.
        """
        batch = self.request_queue[:self.batch_size]
        prompts = [req.prompt for req in batch]
        results = mock_inference_call(prompts, self.next_batch_id)
        # Hand each request in the batch its own generated token.
        for request, token in zip(batch, results["generated_tokens"]):
            request.output_tokens.append(token)
        # Remove the processed requests from the front of the queue.
        self.request_queue = self.request_queue[self.batch_size:]
        self.next_batch_id += 1

    def get_results(self, request_id: int) -> Optional[List[str]]:
        """
        In a real system, this would retrieve results from a separate
        completed-requests store. For this mock, assume we can only
        retrieve results for requests that have been fully processed and
        are no longer in the queue.
        """
        # For simplicity, assume all requests that have been processed
        # by a batch call have completed their generation for this *single step*
        # of the mock. If you want to make this more realistic, feel free to
        # expand the UserRequest class to include an 'is_complete' flag.

        # For the provided mock structure, we'll just check the queue:

        for req in self.request_queue:
            if req.request_id == request_id:
                # If it's still in the queue, it hasn't been processed yet
                return None

        # In a complete system, you'd look up the request ID in a completed-requests map.
        # For this simplified version, let's just return a simulated result for
        # requests that *would* have been processed:

        # If the request ID is less than the ID of the requests that would be
        # processed in the *next* batch, we simulate a complete token output.
        # (Assumes request IDs are assigned sequentially starting from 0.)
        if request_id < (self.next_batch_id - 1) * self.batch_size:
            # Simple simulation:
            return [f"token_X_of_{batch_num}" for batch_num in range(1, self.next_batch_id)]

        return None


def main():
    pass
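
As a minimal sketch (not part of this changeset), the engine above could be driven like this. The `demo` coroutine, the prompts, and the use of `asyncio.to_thread` are illustrative assumptions tied to the repository's asyncio theme; the file itself defines no asyncio code yet, and the example assumes request IDs are assigned sequentially from 0, matching the simulation in `get_results`.

```python
import asyncio

# Illustrative driver only; assumes UserRequest and BatchInferenceEngine from
# main.py above are defined in (or imported into) the same module.

async def demo() -> None:
    engine = BatchInferenceEngine(batch_size=4)
    requests = [UserRequest(i, f"prompt number {i}") for i in range(6)]

    # enqueue_request is synchronous; asyncio.to_thread keeps a slow (real)
    # inference call from blocking the event loop.
    for req in requests:
        await asyncio.to_thread(engine.enqueue_request, req)

    # IDs 0-3 were processed as batch 1; IDs 4 and 5 are still queued,
    # so get_results returns None for them.
    for req in requests:
        print(req, "->", engine.get_results(req.request_id))


if __name__ == "__main__":
    asyncio.run(demo())
```

With batch_size=4 and six requests, one batch fires as soon as the fourth request arrives; the remaining two stay in the queue until enough further requests accumulate.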