"""
Inference Questions

Context

You are tasked with building a simplified inference engine component responsible for handling incoming user requests for a large language model (LLM). To optimize throughput and GPU utilization, the engine must batch multiple requests together, run the inference call once per batch, and then deconstruct the results to return token-level output to the individual users.

Objective

Complete the provided Python class, BatchInferenceEngine, by implementing the methods necessary to:

1. Queue incoming user requests.
2. Process a batch when the queue reaches a defined batch size.
3. Simulate the token-level output from an LLM and correctly associate each generated token with its original request.
"""

import asyncio
from dataclasses import dataclass, field
from time import time
from typing import Dict, List
import uuid

@dataclass
class UserRequest:
    prompt: str
    id: str = field(default_factory=lambda: str(uuid.uuid4()))
    created_at: float = field(default_factory=time)


@dataclass
class TokenOutput:
    request_id: str
    token: bytes

class BatchInferenceEngine:

    def __init__(self, batch_size: int = 8):
        self.queue: List[UserRequest] = []
        self.request_token_map: Dict[str, str] = {}
        self.batch_size = batch_size

        self._lock = asyncio.Lock()
        self._batch_event = asyncio.Event()

    async def add_to_queue(self, request: UserRequest):
        async with self._lock:
            self.queue.append(request)

            # Signal the batch workers once a full batch is available.
            if len(self.queue) >= self.batch_size:
                self._batch_event.set()

        # One worker task per request; the caller can await it to get the batch's tokens.
        task = asyncio.create_task(self._batch())
        return task

    async def _batch(self) -> List[TokenOutput]:
        try:
            # Wake when a full batch has been signalled, or after 50 ms so a
            # partial batch is flushed instead of waiting forever.
            await asyncio.wait_for(self._batch_event.wait(), timeout=0.05)
        except asyncio.TimeoutError:
            pass  # a timeout is a flush signal, not an error

        async with self._lock:
            self._batch_event.clear()
            if not self.queue:
                # Another worker already drained the queue; nothing to do.
                return []

            # Slice one batch off the queue so requests are not processed twice.
            batch = self.queue[:self.batch_size]
            del self.queue[:self.batch_size]

        # Run the (simulated) model call outside the lock so new requests can keep queueing.
        return await self._handle_prompt_to_token(batch)


    async def _handle_prompt_to_token(self, batch: List[UserRequest]) -> List[TokenOutput]:
        # Simulated LLM call: emit one token per request, tagged with that request's id.
        await asyncio.sleep(0.01)  # stand-in for the real batched forward pass
        return [TokenOutput(request_id=r.id, token=f"tok({r.prompt[:8]})".encode()) for r in batch]
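

# A minimal usage sketch, not part of the original exercise: it assumes the
# simulated _handle_prompt_to_token above and a batch size of 4, and simply
# awaits the worker tasks returned by add_to_queue. Each task resolves with the
# TokenOutput list for the batch it drained, or an empty list if another worker
# got there first.
async def _demo() -> None:
    engine = BatchInferenceEngine(batch_size=4)

    # Submit exactly one batch worth of requests; the fourth add sets the event.
    tasks = [
        await engine.add_to_queue(UserRequest(prompt=f"prompt {i}"))
        for i in range(4)
    ]

    for outputs in await asyncio.gather(*tasks):
        for out in outputs:
            print(out.request_id, out.token)


if __name__ == "__main__":
    asyncio.run(_demo())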