comparison asyncio_threads/inference/main.py @ 48:46daba6e3cf4

A few Python scripts to show how to use asyncio.
author MrJuneJune <me@mrjunejune.com>
date Sat, 13 Dec 2025 14:23:02 -0800
47:829623189a57 48:46daba6e3cf4
from typing import List, Dict, Any, Optional

class UserRequest:
    """Represents an incoming user request."""
    def __init__(self, request_id: int, prompt: str):
        self.request_id = request_id
        self.prompt = prompt
        self.output_tokens: List[str] = []  # Stores tokens as they are generated

    def __repr__(self):
        return f"Request(ID={self.request_id}, Prompt='{self.prompt[:20]}...', Tokens={len(self.output_tokens)})"

# --- Mock LLM Interface ---
def mock_inference_call(prompts: List[str], batch_id: int) -> Dict[str, Any]:
    """
    Simulates the call to the underlying LLM/GPU.

    In a real scenario, this returns generated tokens and associated metadata.
    For this mock, we return a flat list of tokens, one for each request
    in the batch, and the batch ID for verification.

    The length of the tokens list MUST equal the length of the prompts list.
    """
    print(f" [INFERENCE] Running Batch {batch_id} with {len(prompts)} requests...")
    results = {
        'batch_id': batch_id,
        # Simulate generating a single new token for each request in the batch
        'generated_tokens': [
            f"token_{i+1}_of_{batch_id}" for i, _ in enumerate(prompts)
        ]
    }
    return results
# -------------------------
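# Example of the contract above (hypothetical prompts, return value implied by
# the mock's list comprehension):
#   mock_inference_call(["hello", "tell me a joke"], batch_id=1)
#   -> {'batch_id': 1, 'generated_tokens': ['token_1_of_1', 'token_2_of_1']}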


class BatchInferenceEngine:
    """
    A simplified inference engine that handles request batching and
    token-level result distribution.
    """
    def __init__(self, batch_size: int = 4):
        self.batch_size = batch_size
        self.request_queue: List[UserRequest] = []
        self.next_batch_id = 1

    def enqueue_request(self, request: UserRequest) -> None:
        """
        Adds a new request to the queue and triggers batch processing if
        the batch size is reached.
        """
        # --- YOUR CODE HERE ---
        # 1. Add the request to the queue.
        # 2. Check if the queue size meets or exceeds self.batch_size.
        # 3. If so, call self._process_batch().
        # ----------------------
        self.request_queue.append(request)
        # Fire as soon as the queue holds at least one full batch (>=, not >),
        # and loop in case more than one full batch is waiting.
        while len(self.request_queue) >= self.batch_size:
            self._process_batch()
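        # Example of the trigger (hypothetical numbers): with the default
        # batch_size of 4, the fourth enqueue_request() call fills the queue,
        # _process_batch() runs once, and the queue is empty again afterwards.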

    def _process_batch(self) -> None:
        """
        Executes the inference call for the current batch and distributes
        the results back to the individual requests.
        """
        batch = self.request_queue[:self.batch_size]
        prompts = [request.prompt for request in batch]
        results = mock_inference_call(prompts, self.next_batch_id)
        # Hand each request in the batch its own newly generated token.
        for request, token in zip(batch, results["generated_tokens"]):
            request.output_tokens.append(token)
        # Remove the processed requests from the queue.
        self.request_queue = self.request_queue[self.batch_size:]
        self.next_batch_id += 1

    def get_results(self, request_id: int) -> Optional[List[str]]:
        """
        In a real system, this would retrieve results from a separate
        completed-requests store. For this mock, assume we can only
        retrieve results for requests that have been fully processed and
        are no longer in the queue.
        """
        # For simplicity, assume all requests that have been processed
        # by a batch call have completed their generation for this *single step*
        # of the mock. If you want to make this more realistic, feel free to
        # expand the UserRequest class to include an 'is_complete' flag.

        # For the provided mock structure, we'll just check the queue:

        for req in self.request_queue:
            if req.request_id == request_id:
                # If it's still in the queue, it hasn't been processed yet
                return None

        # In a complete system, you'd look up the request ID in a completed-requests map.
        # For this simplified version, let's just return a simulated result for
        # requests that *would* have been processed:

        # If the request ID is lower than the IDs of the requests that would be
        # processed in the *next* batch, we simulate a complete token output.
        if request_id < (self.next_batch_id - 1) * self.batch_size:
            # Simple simulation:
            return [f"token_X_of_{batch_num}" for batch_num in range(1, self.next_batch_id)]

        return None

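
# A possible extension (an assumption, not in this changeset), following the
# comment inside get_results(): have _process_batch() move finished requests
# into a completed-requests map (e.g. a hypothetical `self.completed` dict
# keyed by request_id) so get_results() can return real token output instead
# of a simulated list.
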
def main():
    pass
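

# main() is an empty stub in this revision. The block below is a minimal usage
# sketch, not part of the original file: the demo() name, the batch_size of 2,
# and the prompts are all made up for illustration.
def demo():
    engine = BatchInferenceEngine(batch_size=2)
    requests = [
        UserRequest(request_id=0, prompt="What is asyncio?"),
        UserRequest(request_id=1, prompt="Explain request batching."),
    ]
    for req in requests:
        # The second enqueue fills the batch and triggers one inference call.
        engine.enqueue_request(req)
    for req in requests:
        # Each request now holds the single token generated for it in batch 1.
        print(req)


if __name__ == "__main__":
    demo()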