import asyncio
import time
from typing import List, Dict, Any, Optional
import uuid
from dataclasses import dataclass, field


@dataclass
class InferenceRequest:
    prompt: str
    request_id: str = field(default_factory=lambda: str(uuid.uuid4()))
    max_tokens: int = 128
    created_at: float = field(default_factory=time.time)


@dataclass
class TokenOutput:
    token: str
    request_id: str
    token_index: int  # position in this request's generation
    is_complete: bool = False


class BatchInferenceEngine:
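    """
    Toy dynamic-batching engine: requests are queued as they arrive, and a background
    worker runs one (simulated) batched inference pass whenever the queue reaches
    batch_size or the timeout expires, then resolves each request's Future with its text.
    """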

    def __init__(self, batch_size: int = 8, timeout: float = 0.1):
        """
        Args:
            batch_size: Maximum number of requests to process in one inference call
            timeout: Max time to wait for the batch to fill before processing (in seconds)
        """
        self.batch_size = batch_size
        self.timeout = timeout

        self.queue: List[InferenceRequest] = []
        self.pending_requests: Dict[str, List[str]] = {}  # request_id -> list of generated tokens
        self.completion_futures: Dict[str, asyncio.Future] = {}

        self._lock = asyncio.Lock()
        self._batch_ready_event = asyncio.Event()
        self._worker_task: Optional[asyncio.Task] = None  # lazily started batching loop

    async def submit_request(self, prompt: str, max_tokens: int = 128) -> asyncio.Future[str]:
        """
        Submit a new inference request and return a Future that resolves to the generated text.
        """
        request = InferenceRequest(prompt=prompt, max_tokens=max_tokens)

        future = asyncio.get_running_loop().create_future()

        async with self._lock:
            self.queue.append(request)
            self.pending_requests[request.request_id] = []
            self.completion_futures[request.request_id] = future

            # Trigger batch processing early if we've reached batch size
            if len(self.queue) >= self.batch_size:
                self._batch_ready_event.set()

        # Start the background worker if it is not already running
        if self._worker_task is None or self._worker_task.done():
            self._worker_task = asyncio.create_task(self._worker())

        return future

    async def _worker(self):
        """Background task that processes batches when ready"""
        while True:
            # Wait until a full batch is ready or the timeout expires
            try:
                await asyncio.wait_for(self._batch_ready_event.wait(), timeout=self.timeout)
            except asyncio.TimeoutError:
                pass

            async with self._lock:
                if not self.queue:
                    self._batch_ready_event.clear()
                    continue

                # Take up to batch_size requests
                current_batch = self.queue[:self.batch_size]
                self.queue = self.queue[self.batch_size:]

                # Reset event if queue is now empty
                if not self.queue:
                    self._batch_ready_event.clear()

            # Simulate batched LLM inference (run outside the lock so new requests can keep queueing)
            token_outputs: List[TokenOutput] = self._simulate_batched_inference(current_batch)

            # Distribute tokens back to individual requests
            async with self._lock:
                for token_out in token_outputs:
                    req_id = token_out.request_id
                    self.pending_requests[req_id].append(token_out.token)

                    # Check if this request is complete
                    if token_out.is_complete:
                        full_text = ''.join(self.pending_requests[req_id])
                        future = self.completion_futures.pop(req_id)
                        future.set_result(full_text)
                        del self.pending_requests[req_id]

    def _simulate_batched_inference(self, batch: List[InferenceRequest]) -> List[TokenOutput]:
        """
        Simulates a real LLM forward pass on a batch.
        In a real implementation, this would call model.generate() on the padded batch.
        """
        outputs: List[TokenOutput] = []

        for req in batch:
            prompt_words = req.prompt.split() or ["token"]  # rough stand-in for prompt tokens
            num_tokens_to_generate = min(req.max_tokens, 50 + hash(req.request_id) % 30)

            # Simulate token-by-token generation
            for i in range(num_tokens_to_generate):
                # Fake "token": echo a prompt word so each output is traceable to its request
                token_str = f"{prompt_words[i % len(prompt_words)]} "
                is_complete = (i == num_tokens_to_generate - 1)

                outputs.append(TokenOutput(
                    token=token_str,
                    request_id=req.request_id,
                    token_index=i,
                    is_complete=is_complete
                ))

                # In a real streaming setup, tokens could be yielded to the caller here

        return outputs
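
    # For reference, a minimal sketch of what the real batched forward pass could look like
    # with a Hugging Face causal LM. `self.model` and `self.tokenizer` are hypothetical
    # attributes not defined in this class; padding the prompts lets a single generate()
    # call serve the whole batch.
    #
    #     encoded = self.tokenizer(
    #         [req.prompt for req in batch],
    #         return_tensors="pt",
    #         padding=True,
    #     )
    #     generated = self.model.generate(
    #         **encoded,
    #         max_new_tokens=max(req.max_tokens for req in batch),
    #     )
    #     texts = self.tokenizer.batch_decode(generated, skip_special_tokens=True)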


# Example usage
async def main():
    engine = BatchInferenceEngine(batch_size=4, timeout=0.5)

    # Submit several requests; each call returns a Future that resolves to the generated text
    futures = await asyncio.gather(
        engine.submit_request("Tell me a joke about cats"),
        engine.submit_request("Explain quantum computing in simple terms"),
        engine.submit_request("Write a haiku about AI"),
        engine.submit_request("What is the meaning of life?"),
    )

    print("Requests submitted, waiting for results...")

    # Wait for all requests to complete
    results = await asyncio.gather(*futures)

    for i, text in enumerate(results):
        print(f"Request {i}: {text}\n")


if __name__ == "__main__":
    asyncio.run(main())