view asyncio_threads/inference/main.py @ 108:f07abbcd2ec5

[HgWeb] Will probably hold off on using it since it is not urgent.
author June Park <parkjune1995@gmail.com>
date Sat, 03 Jan 2026 17:29:12 -0800
parents 46daba6e3cf4
children
line wrap: on
line source

from typing import List, Dict, Any, Optional

class UserRequest:
    """A single user-submitted generation request.

    Holds the original prompt and accumulates output tokens as
    successive inference steps produce them.
    """

    def __init__(self, request_id: int, prompt: str):
        self.request_id = request_id
        self.prompt = prompt
        # Tokens are appended here as each generation step completes.
        self.output_tokens: List[str] = []

    def __repr__(self):
        preview = self.prompt[:20]
        token_count = len(self.output_tokens)
        return f"Request(ID={self.request_id}, Prompt='{preview}...', Tokens={token_count})"

# --- Mock LLM Interface ---
def mock_inference_call(prompts: List[str], batch_id: int) -> Dict[str, Any]:
    """Stand-in for the real LLM/GPU call.

    Produces exactly one new token per prompt, so the returned
    ``generated_tokens`` list always has the same length as ``prompts``.
    The batch id is echoed back so the caller can verify which batch
    the tokens came from.
    """
    print(f"  [INFERENCE] Running Batch {batch_id} with {len(prompts)} requests...")
    # Token text encodes position-in-batch (1-based) and the batch id.
    tokens = [f"token_{n}_of_{batch_id}" for n in range(1, len(prompts) + 1)]
    return {'batch_id': batch_id, 'generated_tokens': tokens}
# -------------------------


class BatchInferenceEngine:
    """
    A simplified inference engine that handles request batching and
    token-level result distribution.

    Requests accumulate in a FIFO queue; whenever at least
    ``batch_size`` requests are queued, a batch is run and each request
    in that batch receives its own generated token. Processed requests
    move to a completed-requests map so their results remain
    retrievable after they leave the queue.
    """

    def __init__(self, batch_size: int = 4):
        self.batch_size = batch_size
        self.request_queue: List[UserRequest] = []
        self.next_batch_id = 1
        # Completed requests keyed by request_id, so get_results() can
        # find them after they have been drained from the queue.
        self.completed: Dict[int, UserRequest] = {}

    def enqueue_request(self, request: UserRequest) -> None:
        """
        Adds a new request to the queue and triggers batch processing if
        the batch size is reached.
        """
        self.request_queue.append(request)
        # '>=' (not '>'): a batch fires as soon as batch_size requests
        # are queued, per the contract "meets or exceeds". The original
        # '>' off-by-one, combined with the no-op queue drain below,
        # made this loop spin forever once it triggered.
        while len(self.request_queue) >= self.batch_size:
            self._process_batch()

    def _process_batch(self) -> None:
        """
        Executes the inference call for the current batch and distributes
        the results back to the individual requests.
        """
        batch = self.request_queue[:self.batch_size]
        prompts = [req.prompt for req in batch]
        results = mock_inference_call(prompts, self.next_batch_id)
        # Distribute positionally: token i belongs to batch request i.
        # (Only the batch gets results — not every queued request — and
        # each request receives its single token, appended rather than
        # replacing any previously generated tokens.)
        for request, token in zip(batch, results['generated_tokens']):
            request.output_tokens.append(token)
            self.completed[request.request_id] = request
        # Actually drain the processed requests. The original wrote a
        # bare slice expression whose result was discarded, leaving the
        # queue untouched.
        self.request_queue = self.request_queue[self.batch_size:]
        self.next_batch_id += 1

    def get_results(self, request_id: int) -> Optional[List[str]]:
        """
        Return the generated tokens for a completed request.

        Returns None while the request is still queued (unprocessed) or
        if the request_id is unknown. Replaces the original simulated
        lookup with a real completed-requests map, as its own comments
        suggested.
        """
        done = self.completed.get(request_id)
        if done is not None:
            return done.output_tokens
        return None


def main():
    """Entry-point placeholder; no demo wiring has been added yet."""
    pass