
Multi-Model Routing

Route queries to the optimal LLM based on complexity, cost, and performance requirements.

Overview

Multi-model routing intelligently distributes queries across multiple LLMs to optimize for:

  • Cost - Route simple queries to cheaper models
  • Performance - Route complex queries to capable models
  • Latency - Route to fastest model when speed matters
  • Specialization - Route to domain-specific models

Quick Start

from stratarouter import Router, Route  # NOTE(review): Router/Route unused in this snippet — confirm if needed
from stratarouter.multi_model import ModelRouter, ModelConfig

# Define models with capabilities
models = [
    ModelConfig(
        id="gpt-4",
        cost_per_1k_tokens=0.03,   # USD per 1k tokens
        avg_latency_ms=2000,
        max_complexity=10,         # highest tier on the 0-10 complexity scale
        capabilities=["reasoning", "code", "math"]
    ),
    ModelConfig(
        id="gpt-3.5-turbo",
        cost_per_1k_tokens=0.002,
        avg_latency_ms=800,
        max_complexity=7,
        capabilities=["general", "qa"]
    ),
    ModelConfig(
        id="llama-2-70b",
        cost_per_1k_tokens=0.001,
        avg_latency_ms=500,
        max_complexity=6,
        capabilities=["general", "privacy"]  # self-hosted: keeps data private
    )
]

# Create model router
model_router = ModelRouter(models)

# Route based on complexity: estimate a 0-10 score, then pick the
# cheapest/fastest model whose max_complexity covers it.
query = "Explain quantum entanglement"
complexity = model_router.estimate_complexity(query)
selected_model = model_router.select_model(complexity=complexity)

print(f"Selected: {selected_model.id} (complexity: {complexity})")

Complexity Estimation

Automatic Complexity Scoring

class ComplexityEstimator:
    """Estimate query complexity on a 0-10 scale with keyword heuristics."""

    # Score assigned when any pattern from the corresponding tier matches.
    _LEVEL_SCORES = {"simple": 3.0, "medium": 6.0, "complex": 9.0}

    def __init__(self):
        self.patterns = {
            "simple": [
                r"what is",
                r"define",
                r"who is",
                r"when did"
            ],
            "medium": [
                r"how to",
                r"explain",
                r"compare",
                r"analyze"
            ],
            "complex": [
                r"derive",
                r"prove",
                r"design",
                r"implement"
            ]
        }

    def estimate(self, query: str) -> float:
        """Return complexity score 0-10.

        Tiers are checked strongest-first so a query containing both a
        "complex" and a "simple" cue (e.g. "Design X: what is Y?") is
        scored by its hardest requirement. The original scanned the dict
        in insertion order and would have returned the "simple" score.
        """
        import re  # local import: this doc snippet has no shared import block

        lowered = query.lower()
        for level in ("complex", "medium", "simple"):
            if any(re.search(p, lowered) for p in self.patterns[level]):
                return self._LEVEL_SCORES[level]

        # Fallback: combine query length with structural features.
        length_score = min(len(query) / 100, 5.0)

        features = {
            "code": 2.0 if "```" in query else 0,
            "math": 2.0 if any(c in query for c in "∫∑∏") else 0,
            "multipart": 1.0 if len(query.split("?")) > 2 else 0
        }

        return min(length_score + sum(features.values()), 10.0)

estimator = ComplexityEstimator()
# Demonstrate one sample per tier: simple -> 3.0, medium -> 6.0, complex -> 9.0.
for sample in (
    "What is Python?",
    "Explain how neural networks learn",
    "Derive the backpropagation algorithm",
):
    print(estimator.estimate(sample))

ML-Based Complexity

from transformers import AutoTokenizer, AutoModel
import torch

class MLComplexityEstimator:
    """ML-based complexity estimation.

    Runs the query through a 3-class classifier and collapses the class
    probabilities into a single score via anchor weights [1, 5, 10].
    NOTE(review): "complexity-classifier" is a placeholder checkpoint id —
    from_pretrained will fail unless a real model name/path is substituted.
    """

    def __init__(self):
        # Model and tokenizer must come from the same checkpoint so the
        # vocabulary and special tokens match.
        self.model = AutoModel.from_pretrained("complexity-classifier")
        self.tokenizer = AutoTokenizer.from_pretrained("complexity-classifier")

    def estimate(self, query: str) -> float:
        """Estimate using trained model.

        Returns the probability-weighted average of the per-class anchor
        scores (low=1, medium=5, high=10).
        """
        inputs = self.tokenizer(query, return_tensors="pt")

        # Inference only — no gradient bookkeeping needed.
        with torch.no_grad():
            outputs = self.model(**inputs)
            # assumes outputs.logits has shape (1, 3) — TODO confirm; plain
            # AutoModel (vs AutoModelForSequenceClassification) typically
            # exposes no .logits attribute.
            complexity = outputs.logits.softmax(dim=-1)

        # Return weighted score
        return float((complexity * torch.tensor([1, 5, 10])).sum())

Model Selection Strategies

1. Cost-Optimized

class CostOptimizedRouter:
    """Route to minimize cost while meeting quality requirements."""

    def select_model(self, complexity: float, min_quality: float = 0.8):
        """Select the cheapest model able to handle *complexity*.

        Args:
            complexity: Estimated query complexity (0-10 scale).
            min_quality: Reserved quality floor; currently unused.
                NOTE(review): confirm intended semantics before wiring in.

        Returns:
            The cheapest suitable model. If no model can cover the
            requested complexity, falls back to the most capable model —
            the original returned ``self.models[-1]``, which is only
            correct when the list happens to be ordered by capability.
        """
        suitable = [m for m in self.models if m.max_complexity >= complexity]
        if suitable:
            # Cheapest among the capable ones.
            return min(suitable, key=lambda m: m.cost_per_1k_tokens)
        # Nothing qualifies: degrade gracefully to the strongest model.
        return max(self.models, key=lambda m: m.max_complexity)

2. Latency-Optimized

class LatencyOptimizedRouter:
    """Route to minimize latency."""

    def select_model(self, complexity: float):
        """Select the fastest model that can handle *complexity*.

        Returns:
            The suitable model with the lowest average latency. When no
            model covers the complexity, falls back to the most capable
            one — the original indexed ``suitable_models[0]`` and raised
            IndexError on an empty candidate list.
        """
        suitable = [m for m in self.models if m.max_complexity >= complexity]
        if not suitable:
            return max(self.models, key=lambda m: m.max_complexity)
        return min(suitable, key=lambda m: m.avg_latency_ms)

3. Quality-Optimized

class QualityOptimizedRouter:
    """Route to maximize quality."""

    def select_model(self, complexity: float, required_capabilities: list = None):
        """Select the highest-quality model (max_complexity as proxy).

        Args:
            complexity: Kept for interface parity with the other routers;
                not used in the ranking itself.
            required_capabilities: Optional capability tags the model must
                all advertise.

        Returns:
            The most capable matching model. Fixes vs original: the
            no-filter path sorted ``self.models`` *in place*, silently
            reordering the shared fleet list; and an empty filter result
            raised IndexError — it now falls back to the full fleet.
        """
        if required_capabilities:
            candidates = [
                m for m in self.models
                if all(cap in m.capabilities for cap in required_capabilities)
            ]
        else:
            candidates = list(self.models)  # copy: never mutate the shared list

        if not candidates:
            candidates = list(self.models)

        return max(candidates, key=lambda m: m.max_complexity)

4. Balanced

class BalancedRouter:
    """Balance cost, latency, and quality."""

    def select_model(self, complexity: float, weights: dict = None):
        """Select the model with the best weighted multi-objective score.

        Args:
            complexity: Kept for interface parity; not part of the score.
            weights: Optional dict with "cost", "latency" and "quality"
                keys (should sum to ~1.0). Defaults to 0.4/0.3/0.3.

        Returns:
            The highest-scoring model. Fix vs original: normalization used
            undefined globals ``max_cost`` / ``max_latency`` (NameError);
            they are now derived from the fleet itself.
        """
        weights = weights or {
            "cost": 0.4,
            "latency": 0.3,
            "quality": 0.3
        }

        # Fleet-wide maxima for 0-1 normalization; `or 1.0` guards against
        # an all-zero column causing division by zero.
        max_cost = max(m.cost_per_1k_tokens for m in self.models) or 1.0
        max_latency = max(m.avg_latency_ms for m in self.models) or 1.0

        def score(model):
            cost_score = 1 - (model.cost_per_1k_tokens / max_cost)        # cheaper is better
            latency_score = 1 - (model.avg_latency_ms / max_latency)      # faster is better
            quality_score = model.max_complexity / 10.0                   # more capable is better
            return (
                weights["cost"] * cost_score +
                weights["latency"] * latency_score +
                weights["quality"] * quality_score
            )

        return max(self.models, key=score)

Advanced Routing Patterns

Cascade Routing

Try cheap model first, escalate if quality insufficient:

class CascadeRouter:
    """Cascade through models until quality threshold met."""

    async def route_cascade(self, query: str):
        """Try models in order of increasing capability; escalate on low quality.

        Returns:
            Dict with "model", "response", "quality" and "cost" keys.

        Fixes vs original: the fallback re-executed the most capable model
        even though the final loop iteration had already tried it (paying
        twice), and it returned a raw response instead of the dict shape
        used on the success path.
        """
        models_ordered = sorted(self.models, key=lambda m: m.max_complexity)

        last_attempt = None  # most capable attempt so far, kept for fallback
        for model in models_ordered:
            # Try model
            response = await self.execute(model, query)

            # Check quality
            quality = self.assess_quality(response)

            last_attempt = {
                "model": model.id,
                "response": response,
                "quality": quality,
                "cost": model.cost_per_1k_tokens
            }
            if quality >= self.quality_threshold:
                return last_attempt

        # Nothing met the threshold: return the most capable attempt
        # without re-running it.
        return last_attempt

Ensemble Routing

Query multiple models and aggregate responses:

class EnsembleRouter:
    """Query multiple models and aggregate."""

    async def route_ensemble(self, query: str, n_models: int = 3):
        """Fan one query out to several diverse models and merge the answers."""
        # Pick a diverse subset of the fleet.
        chosen = self.select_diverse_models(n_models)

        # Launch every model call concurrently and wait for all of them.
        tasks = [self.execute(model, query) for model in chosen]
        responses = await asyncio.gather(*tasks)

        # Combine the individual answers (majority vote, averaging, etc.).
        merged = self.aggregate_responses(responses)

        return {
            "response": merged,
            "models_used": [m.id for m in chosen],
            "confidence": self.compute_confidence(responses)
        }

Specialized Routing

Route to domain-specific models:

class SpecializedRouter:
    """Route to domain-specific models."""

    def __init__(self):
        # General-purpose fallback pool used when a query's domain has no
        # dedicated entry. The original referenced self.models without
        # ever defining it, so any unknown domain raised AttributeError.
        self.models = []
        self.domain_models = {
            "code": ["codex", "code-llama"],
            "math": ["gpt-4", "claude-2"],
            "medical": ["med-palm", "gpt-4-medical"],
            "legal": ["legal-bert", "gpt-4-legal"]
        }

    def select_model(self, query: str, domain: str = None):
        """Select model based on domain.

        Args:
            query: The user query; used for domain detection when *domain*
                is not supplied.
            domain: Optional explicit domain key.

        NOTE(review): domain_models holds bare id strings while
        self.models presumably holds model objects — confirm that
        select_best accepts both shapes.
        """
        if not domain:
            domain = self.detect_domain(query)

        # Get domain-specific models, falling back to the general pool.
        candidates = self.domain_models.get(domain, self.models)

        # Select best from candidates
        return self.select_best(candidates, query)

Cost Tracking

Track Usage

class CostTracker:
    """Track token usage and dollar cost per model."""

    def __init__(self):
        self.usage = {}     # model_id -> total tokens used
        self.costs = {}     # model_id -> total cost (USD)
        self.requests = {}  # model_id -> number of recorded requests

    def record_usage(self, model_id: str, tokens: int, cost: float):
        """Record one request's usage for *model_id*."""
        self.usage[model_id] = self.usage.get(model_id, 0) + tokens
        self.costs[model_id] = self.costs.get(model_id, 0) + cost
        self.requests[model_id] = self.requests.get(model_id, 0) + 1

    def get_report(self) -> dict:
        """Summarize total and per-model spend.

        Fix vs original: cost_per_request divided by an undefined global
        ``requests`` (NameError); request counts are now tracked per
        model in record_usage.
        """
        return {
            "total_cost": sum(self.costs.values()),
            "by_model": {
                model: {
                    "tokens": self.usage[model],
                    "cost": self.costs[model],
                    "cost_per_request": self.costs[model] / self.requests[model]
                }
                for model in self.costs
            }
        }

# Usage
tracker = CostTracker()

# NOTE(review): `await` implies this snippet runs inside an async
# function; `model_router`, `model` and `query` come from earlier setup.
result = await model_router.execute(model, query)
tracker.record_usage(
    model_id=model.id,
    tokens=result.tokens_used,
    cost=result.cost
)

print(tracker.get_report())

Cost Optimization

def optimize_model_selection(historical_data):
    """Pick the best model per query type from historical cost/quality data.

    For each clustered query type, computes every model's average cost and
    quality over its own samples and keeps the model with the highest
    quality-per-dollar score.

    Fixes vs original: models with no samples for a query type are skipped
    (np.mean of an empty list returns NaN and poisons the ranking), and a
    zero average cost no longer raises ZeroDivisionError.
    """
    # Analyze query patterns
    query_types = cluster_queries(historical_data)

    optimal_models = {}
    for query_type, queries in query_types.items():
        model_performance = {}
        for model in models:  # NOTE(review): `models` is read from module scope
            sample = [q for q in queries if q.model == model.id]
            if not sample:
                continue  # no evidence for this model on this query type

            avg_cost = np.mean([q.cost for q in sample])
            avg_quality = np.mean([q.quality for q in sample])

            model_performance[model.id] = {
                "cost": avg_cost,
                "quality": avg_quality,
                # quality per dollar; free samples rank above everything
                "score": avg_quality / avg_cost if avg_cost else float("inf")
            }

        if model_performance:
            best_model = max(model_performance.items(), key=lambda x: x[1]["score"])
            optimal_models[query_type] = best_model[0]

    return optimal_models

Integration with Runtime

from stratarouter_runtime import CoreRuntimeBridge

class MultiModelBridge(CoreRuntimeBridge):
    """Runtime bridge with multi-model routing.

    Wraps CoreRuntimeBridge.execute with complexity estimation, model
    selection, and per-model cost accounting.
    """

    def __init__(self, config):
        super().__init__(config)
        # config.models: fleet definition handed straight to the router.
        self.model_router = ModelRouter(config.models)
        self.cost_tracker = CostTracker()

    async def execute(self, decision, context):
        """Execute with multi-model routing.

        Args:
            decision: Routing decision, forwarded unchanged to the base bridge.
            context: Mapping that must contain "query"; may contain
                "strategy" (defaults to "balanced").
        """
        # Estimate complexity
        complexity = self.model_router.estimate_complexity(context["query"])

        # Select model
        model = self.model_router.select_model(
            complexity=complexity,
            strategy=context.get("strategy", "balanced")
        )

        # Execute. NOTE(review): assumes the base class accepts a `model`
        # keyword — confirm against CoreRuntimeBridge.execute's signature.
        result = await super().execute(decision, context, model=model)

        # Track cost
        self.cost_tracker.record_usage(
            model_id=model.id,
            tokens=result.tokens_used,
            cost=result.cost
        )

        return result

# Usage
bridge = MultiModelBridge(config)
result = await bridge.execute(decision, context)  # NOTE(review): runs inside an async context

Monitoring

Model Performance Metrics

from prometheus_client import Counter, Histogram, Gauge

# One time series per model id for each metric below.
model_requests = Counter(
    "model_requests_total",
    "Total requests by model",
    ["model_id"]
)

model_latency = Histogram(
    "model_latency_seconds",
    "Model latency",
    ["model_id"]
)

model_cost = Counter(
    "model_cost_usd_total",
    "Total cost by model",
    ["model_id"]
)

# Record metrics
model_requests.labels(model_id=model.id).inc()
model_latency.labels(model_id=model.id).observe(latency)  # latency must be in seconds to match the metric name
model_cost.labels(model_id=model.id).inc(cost)  # counter incremented by the dollar amount

Best Practices

  1. Start with two models - GPT-4 for complex, GPT-3.5 for simple
  2. Measure everything - Track cost, latency, quality per model
  3. Optimize thresholds - Tune complexity thresholds based on data
  4. Use caching - Cache expensive model responses
  5. Monitor drift - Watch for changes in model performance
  6. Test fallbacks - Ensure graceful degradation

Next Steps


Route intelligently across models.