annotate async_multi_threads/inference.py @ 70:4bc56e88e1f3

Remove unnecessary files.
author June Park <parkjune1995@gmail.com>
date Thu, 25 Dec 2025 20:10:46 -0800
parents 551d9fc0a2ba
children
Ignore whitespace changes - Everywhere: Within whitespace: At end of lines:
rev   line source
69
551d9fc0a2ba Updated wrong names.
June Park <parkjune1995@gmail.com>
parents:
diff changeset
1 import asyncio
551d9fc0a2ba Updated wrong names.
June Park <parkjune1995@gmail.com>
parents:
diff changeset
2 import time
551d9fc0a2ba Updated wrong names.
June Park <parkjune1995@gmail.com>
parents:
diff changeset
3 from typing import List, Dict, Any, Optional
551d9fc0a2ba Updated wrong names.
June Park <parkjune1995@gmail.com>
parents:
diff changeset
4 import uuid
551d9fc0a2ba Updated wrong names.
June Park <parkjune1995@gmail.com>
parents:
diff changeset
5 from dataclasses import dataclass, field
551d9fc0a2ba Updated wrong names.
June Park <parkjune1995@gmail.com>
parents:
diff changeset
6
551d9fc0a2ba Updated wrong names.
June Park <parkjune1995@gmail.com>
parents:
diff changeset
@dataclass
class InferenceRequest:
    """A single generation request waiting to be batched.

    ``request_id`` defaults to a fresh UUID4 string and ``created_at`` to the
    wall-clock time (``time.time()``) at construction; both are computed per
    instance.
    """

    prompt: str
    request_id: str = field(default_factory=lambda: f"{uuid.uuid4()}")
    max_tokens: int = field(default=128)
    created_at: float = field(default_factory=time.time)
551d9fc0a2ba Updated wrong names.
June Park <parkjune1995@gmail.com>
parents:
diff changeset
13
551d9fc0a2ba Updated wrong names.
June Park <parkjune1995@gmail.com>
parents:
diff changeset
@dataclass
class TokenOutput:
    """One generated token, tagged with the request it belongs to.

    ``token_index`` is the token's position within that request's generation,
    and ``is_complete`` marks the final token emitted for the request.
    """

    token: str
    request_id: str
    token_index: int  # position in this request's generation
    is_complete: bool = field(default=False)
551d9fc0a2ba Updated wrong names.
June Park <parkjune1995@gmail.com>
parents:
diff changeset
20
551d9fc0a2ba Updated wrong names.
June Park <parkjune1995@gmail.com>
parents:
diff changeset
class BatchInferenceEngine:
    """Collects inference requests into batches and resolves one Future per request.

    Requests are queued by ``submit_request``. A single background worker task,
    started lazily on the first submission, repeatedly takes up to
    ``batch_size`` queued requests, runs them through
    ``_simulate_batched_inference``, reassembles each request's tokens and
    resolves that request's Future with the concatenated text.
    """

    def __init__(self, batch_size: int = 8, timeout: float = 0.1):
        """
        Args:
            batch_size: Maximum number of requests to process in one inference call
            timeout: Max time to wait for batch to fill before processing (in seconds)

        Raises:
            ValueError: if ``batch_size`` is smaller than 1 (a zero-size batch
                would make the worker spin forever without making progress).
        """
        if batch_size < 1:
            raise ValueError(f"batch_size must be >= 1, got {batch_size}")
        self.batch_size = batch_size
        self.timeout = timeout

        self.queue: List[InferenceRequest] = []
        self.pending_requests: Dict[str, List[str]] = {}  # request_id -> list of generated tokens
        self.completion_futures: Dict[str, asyncio.Future] = {}

        self._lock = asyncio.Lock()
        # Set while at least one full batch is queued, so the worker can skip the timeout.
        self._batch_ready_event = asyncio.Event()
        # The single background worker; a strong reference keeps it from being garbage-collected.
        self._worker_task: Optional[asyncio.Task] = None

    async def submit_request(self, prompt: str, max_tokens: int = 128) -> asyncio.Future[str]:
        """
        Submit a new inference request and return a Future that resolves to the generated text.

        The Future carries an exception instead if inference for its batch
        raised. Must be awaited from inside a running event loop.
        """
        request = InferenceRequest(prompt=prompt, max_tokens=max_tokens)

        future = asyncio.get_running_loop().create_future()

        async with self._lock:
            self.queue.append(request)
            self.pending_requests[request.request_id] = []
            self.completion_futures[request.request_id] = future

            # Trigger batch processing if we've reached batch size
            if len(self.queue) >= self.batch_size:
                self._batch_ready_event.set()

            # Start the background worker only if it is not already running.
            if self._worker_task is None or self._worker_task.done():
                self._worker_task = asyncio.create_task(self._worker())

        return future

    async def _worker(self):
        """Background task that processes batches when ready.

        Each iteration waits until a full batch is signalled or ``timeout``
        elapses, then processes up to ``batch_size`` queued requests.
        Runs until cancelled.
        """
        while True:
            # Wait until a full batch is queued or the timeout triggers
            try:
                await asyncio.wait_for(self._batch_ready_event.wait(), timeout=self.timeout)
            except asyncio.TimeoutError:
                pass  # timeout elapsed: process whatever partial batch is queued

            async with self._lock:
                if not self.queue:
                    self._batch_ready_event.clear()
                    continue

                # Take up to batch_size requests
                current_batch = self.queue[:self.batch_size]
                self.queue = self.queue[self.batch_size:]

                # Keep the event set only if another full batch is already waiting;
                # otherwise the next iteration should give the batch time to fill.
                if len(self.queue) < self.batch_size:
                    self._batch_ready_event.clear()

            # Simulate batched LLM inference
            try:
                token_outputs: List[TokenOutput] = self._simulate_batched_inference(current_batch)
            except Exception as exc:
                # Propagate the failure to this batch's callers instead of letting the
                # worker die and leaving their Futures pending forever.
                async with self._lock:
                    for req in current_batch:
                        self._finish(req.request_id, exc=exc)
                continue

            # Distribute tokens back to individual requests
            async with self._lock:
                for token_out in token_outputs:
                    tokens = self.pending_requests.get(token_out.request_id)
                    if tokens is None:
                        continue  # request already finished (e.g. its Future was cancelled)
                    tokens.append(token_out.token)

                    # Check if this request is complete
                    if token_out.is_complete:
                        self._finish(token_out.request_id)

                # Resolve any request that produced no completion token (e.g. max_tokens <= 0)
                # so its caller is not left waiting; already-finished ones are no-ops.
                for req in current_batch:
                    self._finish(req.request_id)

    def _finish(self, request_id: str, exc: Optional[BaseException] = None) -> None:
        """Resolve one request's Future and drop its bookkeeping; no-op if already finished.

        Called only while ``self._lock`` is held. When ``exc`` is given, the Future
        receives it as an exception; otherwise it receives the joined tokens.
        """
        tokens = self.pending_requests.pop(request_id, None)
        future = self.completion_futures.pop(request_id, None)
        if future is None or future.done():
            return  # already resolved, or cancelled by the caller
        if exc is not None:
            future.set_exception(exc)
        else:
            future.set_result(''.join(tokens or []))

    def _simulate_batched_inference(self, batch: List[InferenceRequest]) -> List[TokenOutput]:
        """
        Simulates a real LLM forward pass on a batch.
        In real implementation, this would call model.generate() with padded batch.

        For each request, emits ``min(max_tokens, 50 + hash(request_id) % 30)``
        TokenOutput records with consecutive ``token_index`` values; the last
        record for a request has ``is_complete=True``. A request with
        ``max_tokens <= 0`` emits no records.
        """
        outputs: List[TokenOutput] = []

        for req in batch:
            # hash() of a str is salted per process (PYTHONHASHSEED), so the
            # simulated length varies between runs when max_tokens does not cap it.
            num_tokens_to_generate = min(req.max_tokens, 50 + hash(req.request_id) % 30)

            # Simulate token-by-token generation
            for i in range(num_tokens_to_generate):
                token_str = f"{req.prompt}"  # placeholder "token": echoes the prompt
                is_complete = (i == num_tokens_to_generate - 1)

                outputs.append(TokenOutput(
                    token=token_str,
                    request_id=req.request_id,
                    token_index=i,
                    is_complete=is_complete,
                ))

                # In streaming: could yield early here in real streaming setup

        return outputs
551d9fc0a2ba Updated wrong names.
June Park <parkjune1995@gmail.com>
parents:
diff changeset
124
551d9fc0a2ba Updated wrong names.
June Park <parkjune1995@gmail.com>
parents:
diff changeset
# Example usage
async def main():
    """Demo: submit four prompts as one batch and print each generated text."""
    engine = BatchInferenceEngine(batch_size=4, timeout=0.5)

    prompts = [
        "Tell me a joke about cats",
        "Explain quantum computing in simple terms",
        "Write a haiku about AI",
        "What is the meaning of life?",
    ]

    # submit_request is a coroutine that *returns* a Future: awaiting the
    # submissions yields the Futures, which must then be awaited for the text.
    futures = await asyncio.gather(*(engine.submit_request(p) for p in prompts))

    print("Requests submitted, waiting for results...")

    # Wait for all to complete
    results = await asyncio.gather(*futures)

    for i, text in enumerate(results):
        print(f"Request {i}: {text}\n\n\n")
551d9fc0a2ba Updated wrong names.
June Park <parkjune1995@gmail.com>
parents:
diff changeset
144
551d9fc0a2ba Updated wrong names.
June Park <parkjune1995@gmail.com>
parents:
diff changeset
# Script entry point: run the demo coroutine on a fresh event loop.
if "__main__" == __name__:
    asyncio.run(main())