Multi-Model Routing¶
Route queries to the optimal LLM based on complexity, cost, and performance requirements.
Overview¶
Multi-model routing intelligently distributes queries across multiple LLMs to optimize for:
- Cost - Route simple queries to cheaper models
- Performance - Route complex queries to capable models
- Latency - Route to fastest model when speed matters
- Specialization - Route to domain-specific models
Quick Start¶
from stratarouter.multi_model import ModelRouter, ModelConfig

# Define models with capabilities
models = [
    ModelConfig(
        id="gpt-4",
        cost_per_1k_tokens=0.03,
        avg_latency_ms=2000,
        max_complexity=10,
        capabilities=["reasoning", "code", "math"]
    ),
    ModelConfig(
        id="gpt-3.5-turbo",
        cost_per_1k_tokens=0.002,
        avg_latency_ms=800,
        max_complexity=7,
        capabilities=["general", "qa"]
    ),
    ModelConfig(
        id="llama-2-70b",
        cost_per_1k_tokens=0.001,
        avg_latency_ms=500,
        max_complexity=6,
        capabilities=["general", "privacy"]
    )
]

# Create model router
model_router = ModelRouter(models)

# Route based on complexity
query = "Explain quantum entanglement"
complexity = model_router.estimate_complexity(query)
selected_model = model_router.select_model(complexity=complexity)

print(f"Selected: {selected_model.id} (complexity: {complexity})")
Complexity Estimation¶
Automatic Complexity Scoring¶
import re

class ComplexityEstimator:
    """Estimate query complexity."""

    def __init__(self):
        self.patterns = {
            "simple": [
                r"what is",
                r"define",
                r"who is",
                r"when did"
            ],
            "medium": [
                r"how to",
                r"explain",
                r"compare",
                r"analyze"
            ],
            "complex": [
                r"derive",
                r"prove",
                r"design",
                r"implement"
            ]
        }

    def estimate(self, query: str) -> float:
        """Return complexity score 0-10."""
        # Check keyword patterns first
        for level, patterns in self.patterns.items():
            for pattern in patterns:
                if re.search(pattern, query.lower()):
                    if level == "simple":
                        return 3.0
                    elif level == "medium":
                        return 6.0
                    else:
                        return 9.0

        # Fallback: use query length and structural features
        length_score = min(len(query) / 100, 5.0)
        features = {
            "code": 2.0 if "```" in query else 0,
            "math": 2.0 if any(c in query for c in "∫∑∏") else 0,
            "multipart": 1.0 if len(query.split("?")) > 2 else 0
        }
        return min(length_score + sum(features.values()), 10.0)
estimator = ComplexityEstimator()
print(estimator.estimate("What is Python?")) # 3.0
print(estimator.estimate("Explain how neural networks learn")) # 6.0
print(estimator.estimate("Derive the backpropagation algorithm")) # 9.0
ML-Based Complexity¶
from transformers import AutoTokenizer, AutoModelForSequenceClassification
import torch

class MLComplexityEstimator:
    """ML-based complexity estimation."""

    def __init__(self):
        # "complexity-classifier" is a placeholder for a fine-tuned
        # 3-class (simple/medium/complex) sequence classifier
        self.model = AutoModelForSequenceClassification.from_pretrained("complexity-classifier")
        self.tokenizer = AutoTokenizer.from_pretrained("complexity-classifier")

    def estimate(self, query: str) -> float:
        """Estimate using trained model."""
        inputs = self.tokenizer(query, return_tensors="pt")
        with torch.no_grad():
            outputs = self.model(**inputs)
            probs = outputs.logits.softmax(dim=-1)

        # Weight class probabilities into a 0-10 score
        return float((probs * torch.tensor([1.0, 5.0, 10.0])).sum())
Model Selection Strategies¶
1. Cost-Optimized¶
class CostOptimizedRouter:
    """Route to minimize cost while meeting capability requirements."""

    def __init__(self, models: list):
        self.models = models

    def select_model(self, complexity: float):
        """Select the cheapest model that can handle the complexity."""
        suitable_models = [
            m for m in self.models
            if m.max_complexity >= complexity
        ]

        # Sort by cost, cheapest first
        suitable_models.sort(key=lambda m: m.cost_per_1k_tokens)

        # Fall back to the most capable model if none qualifies
        if not suitable_models:
            return max(self.models, key=lambda m: m.max_complexity)
        return suitable_models[0]
2. Latency-Optimized¶
class LatencyOptimizedRouter:
    """Route to minimize latency."""

    def __init__(self, models: list):
        self.models = models

    def select_model(self, complexity: float):
        """Select the fastest model that can handle the complexity."""
        suitable_models = [
            m for m in self.models
            if m.max_complexity >= complexity
        ]

        # Fall back to the most capable model if none qualifies
        if not suitable_models:
            return max(self.models, key=lambda m: m.max_complexity)

        # Sort by latency, fastest first
        suitable_models.sort(key=lambda m: m.avg_latency_ms)
        return suitable_models[0]
3. Quality-Optimized¶
class QualityOptimizedRouter:
    """Route to maximize quality."""

    def __init__(self, models: list):
        self.models = models

    def select_model(self, complexity: float, required_capabilities: list = None):
        """Select the highest-quality model."""
        if required_capabilities:
            suitable_models = [
                m for m in self.models
                if all(cap in m.capabilities for cap in required_capabilities)
            ]
        else:
            # Copy so sorting does not mutate self.models
            suitable_models = list(self.models)

        # Sort by max complexity (proxy for capability), best first
        suitable_models.sort(key=lambda m: m.max_complexity, reverse=True)
        return suitable_models[0]
4. Balanced¶
class BalancedRouter:
    """Balance cost, latency, and quality."""

    def __init__(self, models: list):
        self.models = models

    def select_model(self, complexity: float, weights: dict = None):
        """Select model optimizing multiple objectives."""
        weights = weights or {
            "cost": 0.4,
            "latency": 0.3,
            "quality": 0.3
        }

        # Normalization bounds taken from the configured models
        max_cost = max(m.cost_per_1k_tokens for m in self.models)
        max_latency = max(m.avg_latency_ms for m in self.models)

        scores = []
        for model in self.models:
            # Normalize metrics (0-1, higher is better)
            cost_score = 1 - (model.cost_per_1k_tokens / max_cost)
            latency_score = 1 - (model.avg_latency_ms / max_latency)
            quality_score = model.max_complexity / 10.0

            # Weighted sum
            total_score = (
                weights["cost"] * cost_score +
                weights["latency"] * latency_score +
                weights["quality"] * quality_score
            )
            scores.append((model, total_score))

        # Return the best-scoring model
        scores.sort(key=lambda x: x[1], reverse=True)
        return scores[0][0]
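For example, with the model list from the Quick Start, a cost-heavy weighting can be applied per request (the weights below are illustrative):

router = BalancedRouter(models)
model = router.select_model(
    complexity=6.0,
    weights={"cost": 0.6, "latency": 0.2, "quality": 0.2}
)
print(f"Balanced choice: {model.id}")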
Advanced Routing Patterns¶
Cascade Routing¶
Try a cheap model first and escalate only if response quality falls short:
class CascadeRouter:
    """Cascade through models until quality threshold met."""

    def __init__(self, models: list, quality_threshold: float = 0.8):
        # The default threshold is illustrative; tune it on real traffic
        self.models = models
        self.quality_threshold = quality_threshold

    async def route_cascade(self, query: str):
        """Try models in order of increasing capability."""
        models_ordered = sorted(self.models, key=lambda m: m.max_complexity)

        for model in models_ordered:
            # Try model
            response = await self.execute(model, query)

            # Check quality
            quality = self.assess_quality(response)
            if quality >= self.quality_threshold:
                return {
                    "model": model.id,
                    "response": response,
                    "quality": quality,
                    "cost": model.cost_per_1k_tokens
                }

        # Fallback to most capable
        return await self.execute(models_ordered[-1], query)
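assess_quality is left abstract above. A minimal heuristic that could back it (the length and refusal checks are illustrative assumptions, not a shipped scorer):

def assess_quality(response: str) -> float:
    """Heuristic quality score in [0, 1]; a placeholder for a real
    scorer such as an LLM judge or a trained reward model."""
    if not response or len(response) < 20:
        return 0.0
    # Penalize refusals and non-answers
    refusal_markers = ("i cannot", "i'm not sure", "as an ai")
    if any(marker in response.lower() for marker in refusal_markers):
        return 0.3
    # Longer answers score higher, capped at 1.0
    return min(0.5 + len(response) / 2000, 1.0)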
Ensemble Routing¶
Query multiple models and aggregate responses:
import asyncio

class EnsembleRouter:
    """Query multiple models and aggregate."""

    async def route_ensemble(self, query: str, n_models: int = 3):
        """Get responses from multiple models."""
        # Select diverse models
        models = self.select_diverse_models(n_models)

        # Query in parallel
        responses = await asyncio.gather(*[
            self.execute(model, query)
            for model in models
        ])

        # Aggregate (majority vote, averaging, etc.)
        final_response = self.aggregate_responses(responses)

        return {
            "response": final_response,
            "models_used": [m.id for m in models],
            "confidence": self.compute_confidence(responses)
        }
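aggregate_responses depends on the task. For short factual answers, a simple majority vote is a reasonable default; this sketch assumes string responses (the normalization step is an assumption):

from collections import Counter

def aggregate_responses(responses: list) -> str:
    """Majority vote over normalized responses; fall back to the
    first response when no answer repeats."""
    normalized = [r.strip().lower() for r in responses]
    winner, count = Counter(normalized).most_common(1)[0]
    if count > 1:
        # Return the original casing of the winning answer
        return responses[normalized.index(winner)]
    return responses[0]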
Specialized Routing¶
Route to domain-specific models:
class SpecializedRouter:
    """Route to domain-specific models."""

    def __init__(self, models: list = None):
        # General-purpose fallback pool
        self.models = models or []
        self.domain_models = {
            "code": ["codex", "code-llama"],
            "math": ["gpt-4", "claude-2"],
            "medical": ["med-palm", "gpt-4-medical"],
            "legal": ["legal-bert", "gpt-4-legal"]
        }

    def select_model(self, query: str, domain: str = None):
        """Select model based on domain."""
        if not domain:
            domain = self.detect_domain(query)

        # Get domain-specific models, falling back to the general pool
        candidates = self.domain_models.get(domain, self.models)

        # Select best from candidates
        return self.select_best(candidates, query)
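detect_domain can start as keyword matching before graduating to a trained classifier. A minimal sketch (the keyword lists are illustrative assumptions):

def detect_domain(query: str) -> str:
    """Keyword-based domain detection; swap in a classifier once
    labeled traffic is available."""
    domain_keywords = {
        "code": ["function", "bug", "compile", "python", "api"],
        "math": ["prove", "theorem", "integral", "equation"],
        "medical": ["symptom", "diagnosis", "dosage", "patient"],
        "legal": ["contract", "liability", "statute", "clause"]
    }
    query_lower = query.lower()
    for domain, keywords in domain_keywords.items():
        if any(kw in query_lower for kw in keywords):
            return domain
    return "general"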
Cost Tracking¶
Track Usage¶
class CostTracker:
    """Track costs across models."""

    def __init__(self):
        self.usage = {}     # model_id -> tokens used
        self.costs = {}     # model_id -> total cost
        self.requests = {}  # model_id -> request count

    def record_usage(self, model_id: str, tokens: int, cost: float):
        """Record model usage."""
        self.usage[model_id] = self.usage.get(model_id, 0) + tokens
        self.costs[model_id] = self.costs.get(model_id, 0) + cost
        self.requests[model_id] = self.requests.get(model_id, 0) + 1

    def get_report(self) -> dict:
        """Get cost report."""
        return {
            "total_cost": sum(self.costs.values()),
            "by_model": {
                model: {
                    "tokens": self.usage[model],
                    "cost": self.costs[model],
                    "cost_per_request": self.costs[model] / self.requests[model]
                }
                for model in self.costs
            }
        }
# Usage
tracker = CostTracker()
result = await model_router.execute(model, query)
tracker.record_usage(
    model_id=model.id,
    tokens=result.tokens_used,
    cost=result.cost
)
print(tracker.get_report())
Cost Optimization¶
import numpy as np

def optimize_model_selection(historical_data, models):
    """Optimize model selection based on historical performance."""
    # Analyze query patterns (see the cluster_queries sketch below)
    query_types = cluster_queries(historical_data)

    # For each type, find the optimal model
    optimal_models = {}
    for query_type, queries in query_types.items():
        # Calculate the cost-quality tradeoff per model
        model_performance = {}
        for model in models:
            model_queries = [q for q in queries if q.model == model.id]
            if not model_queries:
                continue  # no history for this model on this query type
            avg_cost = np.mean([q.cost for q in model_queries])
            avg_quality = np.mean([q.quality for q in model_queries])
            model_performance[model.id] = {
                "cost": avg_cost,
                "quality": avg_quality,
                "score": avg_quality / avg_cost  # quality per dollar
            }

        if not model_performance:
            continue

        # Select the best quality-per-dollar model
        best_model = max(model_performance.items(), key=lambda x: x[1]["score"])
        optimal_models[query_type] = best_model[0]

    return optimal_models
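cluster_queries is not defined above. A minimal sketch that buckets historical records into coarse complexity bands, reusing the ComplexityEstimator from earlier (the band edges and the record's query field are assumptions):

def cluster_queries(historical_data):
    """Group query records into coarse complexity bands."""
    estimator = ComplexityEstimator()
    clusters = {"simple": [], "medium": [], "complex": []}
    for record in historical_data:
        score = estimator.estimate(record.query)
        if score <= 4:
            clusters["simple"].append(record)
        elif score <= 7:
            clusters["medium"].append(record)
        else:
            clusters["complex"].append(record)
    return clusters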
Integration with Runtime¶
from stratarouter_runtime import CoreRuntimeBridge

class MultiModelBridge(CoreRuntimeBridge):
    """Runtime bridge with multi-model routing."""

    def __init__(self, config):
        super().__init__(config)
        self.model_router = ModelRouter(config.models)
        self.cost_tracker = CostTracker()

    async def execute(self, decision, context):
        """Execute with multi-model routing."""
        # Estimate complexity
        complexity = self.model_router.estimate_complexity(context["query"])

        # Select model
        model = self.model_router.select_model(
            complexity=complexity,
            strategy=context.get("strategy", "balanced")
        )

        # Execute
        result = await super().execute(decision, context, model=model)

        # Track cost
        self.cost_tracker.record_usage(
            model_id=model.id,
            tokens=result.tokens_used,
            cost=result.cost
        )

        return result
# Usage
bridge = MultiModelBridge(config)
result = await bridge.execute(decision, context)
Monitoring¶
Model Performance Metrics¶
from prometheus_client import Counter, Histogram

model_requests = Counter(
    "model_requests_total",
    "Total requests by model",
    ["model_id"]
)

model_latency = Histogram(
    "model_latency_seconds",
    "Model latency",
    ["model_id"]
)

model_cost = Counter(
    "model_cost_usd_total",
    "Total cost by model",
    ["model_id"]
)

# Record metrics after each call
model_requests.labels(model_id=model.id).inc()
model_latency.labels(model_id=model.id).observe(latency)
model_cost.labels(model_id=model.id).inc(cost)
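latency and cost in the snippet above come from the model call itself. A minimal wrapper that produces both (the execute signature and result.cost field mirror the earlier snippets and are assumptions about the runtime):

import time

async def execute_with_metrics(model_router, model, query):
    """Time a model call and record the Prometheus metrics above."""
    start = time.perf_counter()
    result = await model_router.execute(model, query)
    latency = time.perf_counter() - start  # seconds

    model_requests.labels(model_id=model.id).inc()
    model_latency.labels(model_id=model.id).observe(latency)
    model_cost.labels(model_id=model.id).inc(result.cost)
    return result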
Best Practices¶
- Start with two models - GPT-4 for complex, GPT-3.5 for simple
- Measure everything - Track cost, latency, quality per model
- Optimize thresholds - Tune complexity thresholds based on data
- Use caching - Cache expensive model responses (see the cache sketch after this list)
- Monitor drift - Watch for changes in model performance
- Test fallbacks - Ensure graceful degradation
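A minimal response cache to back the caching practice above, keyed on model and query (the TTL default is an illustrative assumption):

import hashlib
import time

class ResponseCache:
    """In-memory cache for model responses with a simple TTL."""

    def __init__(self, ttl_seconds: int = 3600):
        self.ttl = ttl_seconds
        self._store = {}  # cache key -> (timestamp, response)

    def _key(self, model_id: str, query: str) -> str:
        return hashlib.sha256(f"{model_id}:{query}".encode()).hexdigest()

    def get(self, model_id: str, query: str):
        entry = self._store.get(self._key(model_id, query))
        if entry and time.time() - entry[0] < self.ttl:
            return entry[1]
        return None

    def put(self, model_id: str, query: str, response):
        self._store[self._key(model_id, query)] = (time.time(), response)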
Next Steps¶
- Cost Optimization - Reduce LLM costs
- Performance - Optimize latency
- Metrics - Advanced monitoring
Route intelligently across models.