CQRS Pattern for ML Systems: Separating Training from Inference Architecture
Introduction
Command Query Responsibility Segregation (CQRS) is an architectural pattern that separates read operations (queries) from write operations (commands) into distinct models. While traditionally applied to CRUD applications, CQRS proves remarkably effective for machine learning systems, where training (write-heavy) and inference (read-heavy) have fundamentally different requirements for performance, scalability, and data consistency.
The Core Concept
In ML systems, CQRS maps naturally:
- Commands (Writes): Model training, retraining, feature engineering, data ingestion
- Queries (Reads): Real-time inference, batch predictions, model evaluations
These operations differ so dramatically in characteristics that forcing them into a unified architecture creates significant trade-offs.
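To make the separation concrete, the two sides can be modeled as distinct contracts. The sketch below is illustrative only; the type and method names are assumptions rather than part of any specific framework:

// cqrs_contracts.go (illustrative sketch)
package mlcqrs

// TrainModelCommand mutates state on the write side: it kicks off training
// and returns nothing beyond an acknowledgement or error.
type TrainModelCommand struct {
    ModelType       string
    DataVersion     string
    Hyperparameters map[string]float64
}

// PredictQuery reads state on the read side: it returns a prediction and
// has no side effects.
type PredictQuery struct {
    ModelType string
    Features  map[string]float64
}

// CommandHandler enqueues or executes long-running training work.
type CommandHandler interface {
    Handle(cmd TrainModelCommand) error
}

// QueryHandler answers latency-sensitive prediction requests.
type QueryHandler interface {
    Answer(q PredictQuery) (prediction float64, err error)
}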
Why CQRS for ML Systems?
1. Different Performance Profiles
Training (Write Side):
- Compute-intensive, long-running operations (minutes to days)
- Optimized for throughput over latency
- Can tolerate eventual consistency
- Requires GPU/TPU acceleration
- Handles large batch operations
Inference (Read Side):
- Latency-sensitive (milliseconds to seconds)
- Optimized for low latency over throughput
- Must serve the latest approved model version, though brief staleness during rollout is typically acceptable
- Often CPU-optimized or specialized inference accelerators
- Handles individual or micro-batch requests
2. Scaling Characteristics
Training scales vertically (larger GPUs) or as distributed batch jobs that run episodically. Inference scales horizontally (more replicas) to follow continuous request traffic. CQRS lets each side use its own scaling strategy.
3. Technology Stack Flexibility
Training often uses Python (PyTorch, TensorFlow) for flexibility. Inference may benefit from Go or Rust for performance and operational simplicity. CQRS permits polyglot architectures.
Architecture Pattern
┌────────────────────────────────────────────────────────────────┐
│                    Command Side (Training)                     │
│  ┌──────────────┐      ┌──────────────┐      ┌──────────────┐  │
│  │     Data     │─────>│   Training   │─────>│    Model     │  │
│  │  Ingestion   │      │   Pipeline   │      │   Registry   │  │
│  │   (Python)   │      │   (Python)   │      │ (Versioned)  │  │
│  └──────────────┘      └──────┬───────┘      └──────┬───────┘  │
│                               │                     │          │
│                               └──────────┬──────────┘          │
│                                          │                     │
│                                          v                     │
│                                 ┌─────────────────┐            │
│                                 │   Event Store   │            │
│                                 │ (Model Updates) │            │
│                                 └────────┬────────┘            │
│                                          │                     │
└──────────────────────────────────────────┼─────────────────────┘
                                           │
                                           │  Model Update Events
                                           │
┌──────────────────────────────────────────▼─────────────────────┐
│                     Query Side (Inference)                     │
│  ┌──────────────┐      ┌──────────────┐      ┌──────────────┐  │
│  │ API Gateway  │─────>│  Inference   │─────>│   Response   │  │
│  │     (Go)     │      │ Service (Go) │      │    Cache     │  │
│  └──────────────┘      └──────┬───────┘      └──────────────┘  │
│                               │                                │
│                               v                                │
│                        ┌──────────────┐                        │
│                        │ Model Store  │                        │
│                        │ (Optimized)  │                        │
│                        └──────────────┘                        │
└────────────────────────────────────────────────────────────────┘
Implementation Example
Command Side: Training Pipeline (Python)
# training_pipeline.py
from dataclasses import dataclass
from datetime import datetime
from typing import Dict, Any

import mlflow
import torch


@dataclass
class TrainingCommand:
    """Command to trigger model training"""
    model_type: str
    hyperparameters: Dict[str, Any]
    training_data_version: str
    experiment_id: str


class TrainingService:
    def __init__(self, model_registry_client, event_publisher):
        self.model_registry = model_registry_client
        self.event_publisher = event_publisher

    async def handle_training_command(self, command: TrainingCommand):
        """Handle a model training command"""
        # 1. Load training data
        train_loader = self._load_data(command.training_data_version)

        # 2. Initialize model
        model = self._create_model(command.model_type, command.hyperparameters)

        # 3. Train model (long-running operation)
        with mlflow.start_run(experiment_id=command.experiment_id):
            metrics = self._train(model, train_loader, command.hyperparameters)

        # 4. Register trained model
        model_version = self._register_model(
            model,
            command.model_type,
            metrics
        )

        # 5. Publish model update event
        await self.event_publisher.publish({
            "event_type": "ModelTrained",
            "model_version": model_version,
            "model_type": command.model_type,
            "metrics": metrics,
            "timestamp": datetime.utcnow().isoformat()
        })

    def _train(self, model, train_loader, hyperparameters):
        """Training logic optimized for throughput"""
        optimizer = torch.optim.AdamW(
            model.parameters(),
            lr=hyperparameters['learning_rate']
        )
        # Training loop optimized for GPU throughput
        for epoch in range(hyperparameters['epochs']):
            for batch in train_loader:
                # In this sketch the model computes and returns its own loss
                loss = model(batch)
                loss.backward()
                optimizer.step()
                optimizer.zero_grad()
        return {"accuracy": 0.95, "f1_score": 0.93}  # Placeholder metrics

    def _register_model(self, model, model_type, metrics):
        """Register model to versioned model registry"""
        # Export to an optimized inference format (TorchScript here; ONNX also works)
        optimized_model = torch.jit.script(model)
        # Register with metadata
        version = self.model_registry.register(
            model=optimized_model,
            model_type=model_type,
            metrics=metrics,
            framework="pytorch"
        )
        return version
Query Side: Inference Service (Go)
// inference_service.go
package inference

import (
    "context"
    "time"
)

// InferenceQuery represents a prediction request
type InferenceQuery struct {
    ModelType string                 `json:"model_type"`
    Features  map[string]interface{} `json:"features"`
    RequestID string                 `json:"request_id"`
}

// InferenceResult represents a prediction response
type InferenceResult struct {
    Prediction   float64 `json:"prediction"`
    Confidence   float64 `json:"confidence"`
    ModelVersion string  `json:"model_version"`
    LatencyMs    int64   `json:"latency_ms"`
}

// InferenceService handles real-time predictions
type InferenceService struct {
    modelStore    ModelStore
    cache         PredictionCache
    metricsClient MetricsClient
}

func NewInferenceService(
    store ModelStore,
    cache PredictionCache,
    metrics MetricsClient,
) *InferenceService {
    return &InferenceService{
        modelStore:    store,
        cache:         cache,
        metricsClient: metrics,
    }
}

// HandleInferenceQuery processes prediction requests with low latency
func (s *InferenceService) HandleInferenceQuery(
    ctx context.Context,
    query InferenceQuery,
) (*InferenceResult, error) {
    start := time.Now()

    // 1. Check cache for duplicate requests (idempotency)
    if cached, found := s.cache.Get(query.RequestID); found {
        s.metricsClient.IncrementCounter("inference.cache_hit")
        return cached, nil
    }

    // 2. Load latest model for this type (optimized read path)
    model, err := s.modelStore.GetLatestModel(query.ModelType)
    if err != nil {
        return nil, err
    }

    // 3. Run inference (optimized for latency)
    prediction, confidence, err := s.runInference(ctx, model, query.Features)
    if err != nil {
        return nil, err
    }

    // 4. Build result
    result := &InferenceResult{
        Prediction:   prediction,
        Confidence:   confidence,
        ModelVersion: model.Version,
        LatencyMs:    time.Since(start).Milliseconds(),
    }

    // 5. Cache result
    s.cache.Set(query.RequestID, result, 5*time.Minute)

    // 6. Record metrics
    s.metricsClient.RecordHistogram("inference.latency_ms", result.LatencyMs)

    return result, nil
}

func (s *InferenceService) runInference(
    ctx context.Context,
    model *Model,
    features map[string]interface{},
) (float64, float64, error) {
    // Optimized inference path using ONNX Runtime or TorchScript.
    // Typically runs on CPU with optimized SIMD instructions
    // or on specialized inference accelerators.
    inputTensor := s.featuresToTensor(features)
    output := model.Predict(inputTensor) // Optimized C++ bindings
    return output.Prediction, output.Confidence, nil
}

// ListenForModelUpdates subscribes to model update events from the command side
func (s *InferenceService) ListenForModelUpdates(ctx context.Context) {
    subscriber := NewEventSubscriber("model-updates")
    for {
        select {
        case event := <-subscriber.Events():
            if event.Type == "ModelTrained" {
                // Reload model from registry
                s.modelStore.RefreshModel(event.ModelType, event.ModelVersion)
                s.metricsClient.IncrementCounter("model.updated")
            }
        case <-ctx.Done():
            return
        }
    }
}
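The service above depends on a few collaborators that the listing leaves undefined (ModelStore, PredictionCache, MetricsClient, Model, featuresToTensor). A minimal sketch of the shapes the code assumes might look like the following; the exact signatures are assumptions, not a prescribed API:

// inference_types.go (illustrative sketch)
package inference

import "time"

// Output is the raw result of a forward pass.
type Output struct {
    Prediction float64
    Confidence float64
}

// Model is a loaded, inference-ready artifact.
type Model struct {
    Version string
    // runtime handle (e.g., an ONNX Runtime session) omitted
}

// Predict runs a single forward pass; a real implementation would call into
// an optimized runtime via cgo or a serving sidecar.
func (m *Model) Predict(input []float32) Output {
    return Output{}
}

// ModelStore serves the latest registered model per model type.
type ModelStore interface {
    GetLatestModel(modelType string) (*Model, error)
    RefreshModel(modelType, version string) error
}

// PredictionCache provides short-lived, idempotent result caching.
type PredictionCache interface {
    Get(requestID string) (*InferenceResult, bool)
    Set(requestID string, result *InferenceResult, ttl time.Duration)
}

// MetricsClient abstracts the metrics backend (StatsD, Prometheus, ...).
type MetricsClient interface {
    IncrementCounter(name string)
    RecordHistogram(name string, value int64)
}

// featuresToTensor flattens the feature map into the model's input layout.
// A real implementation must apply a stable feature ordering.
func (s *InferenceService) featuresToTensor(features map[string]interface{}) []float32 {
    out := make([]float32, 0, len(features))
    for _, v := range features {
        if f, ok := v.(float64); ok {
            out = append(out, float32(f))
        }
    }
    return out
}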
Event Bridge (Connecting Command & Query)
# event_publisher.py (Command Side)
import json

from kafka import KafkaProducer


class ModelEventPublisher:
    def __init__(self, kafka_brokers):
        self.producer = KafkaProducer(
            bootstrap_servers=kafka_brokers,
            value_serializer=lambda v: json.dumps(v).encode('utf-8')
        )

    async def publish(self, event: dict):
        """Publish model update events to the query side"""
        self.producer.send('model-updates', value=event)
        self.producer.flush()
// event_subscriber.go (Query Side)
package inference

import (
    "context"
    "encoding/json"

    "github.com/segmentio/kafka-go"
)

type EventSubscriber struct {
    reader *kafka.Reader
    events chan ModelEvent
}

func NewEventSubscriber(topic string) *EventSubscriber {
    reader := kafka.NewReader(kafka.ReaderConfig{
        Brokers: []string{"localhost:9092"},
        Topic:   topic,
        GroupID: "inference-service",
    })
    sub := &EventSubscriber{
        reader: reader,
        events: make(chan ModelEvent, 100),
    }
    go sub.consume()
    return sub
}

func (s *EventSubscriber) consume() {
    for {
        msg, err := s.reader.ReadMessage(context.Background())
        if err != nil {
            continue
        }
        var event ModelEvent
        if err := json.Unmarshal(msg.Value, &event); err != nil {
            continue
        }
        s.events <- event
    }
}

func (s *EventSubscriber) Events() <-chan ModelEvent {
    return s.events
}
Trade-offs and Considerations
Advantages
- Independent Scaling: Scale training (vertical, GPU-heavy) and inference (horizontal, CPU-optimized) independently
- Technology Flexibility: Use Python for training flexibility, Go/Rust for inference performance
- Performance Optimization: Optimize each side for its workload characteristics
- Operational Simplicity: Deploy, monitor, and debug training and inference separately
- Cost Efficiency: Run expensive training infrastructure only when needed; keep cheap inference always available
Disadvantages
- Increased Complexity: Two systems to maintain instead of one
- Eventual Consistency: Inference may temporarily serve stale models during updates; the hot-swap sketch after this list shows one way to keep that window small
- Synchronization Overhead: Event-driven updates add latency between training completion and inference availability
- Operational Burden: Requires event streaming infrastructure (Kafka, NATS, etc.)
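The eventual-consistency window can be kept narrow, and model swaps made safe, by loading the new version off the request path and atomically swapping the pointer the hot path reads. The following is a minimal sketch of such a ModelStore; RegistryClient and ErrModelNotLoaded are assumed helpers, not part of the earlier listings:

// hot_swap_store.go (illustrative sketch)
package inference

import (
    "errors"
    "fmt"
    "sync/atomic"
)

// ErrModelNotLoaded is returned before the first model version has been loaded.
var ErrModelNotLoaded = errors.New("model not loaded yet")

// RegistryClient is an assumed client for the command-side model registry.
type RegistryClient interface {
    Download(modelType, version string) (*Model, error)
}

// hotSwapStore keeps one atomically swappable model per model type. The map is
// populated once at startup; afterwards only the pointers are swapped, so the
// request path never blocks on an update.
type hotSwapStore struct {
    models   map[string]*atomic.Pointer[Model]
    registry RegistryClient
}

func (s *hotSwapStore) GetLatestModel(modelType string) (*Model, error) {
    p, ok := s.models[modelType]
    if !ok || p.Load() == nil {
        return nil, ErrModelNotLoaded
    }
    return p.Load(), nil
}

func (s *hotSwapStore) RefreshModel(modelType, version string) error {
    p, ok := s.models[modelType]
    if !ok {
        return fmt.Errorf("unknown model type %q", modelType)
    }
    m, err := s.registry.Download(modelType, version) // load off the request path
    if err != nil {
        return err
    }
    p.Store(m) // atomic swap: in-flight requests keep the old version
    return nil
}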
When to Use CQRS for ML
Good Fit:
- Real-time inference with periodic retraining (daily/weekly)
- Significantly different scaling needs (e.g., 100 QPS inference, 1 training job/day)
- Performance-critical inference (sub-100ms latency requirements)
- Large model size requiring optimization (ONNX conversion, quantization)
Poor Fit:
- Continuous online learning (model updates per request)
- Low-traffic systems where operational complexity outweighs benefits
- Tight coupling between training and inference (e.g., personalization requiring user-specific models)
Monitoring and Observability
Both sides require distinct monitoring:
Command Side (Training):
- Training job duration
- Model quality metrics (accuracy, F1, AUC)
- Data drift detection
- Resource utilization (GPU usage, memory)
Query Side (Inference):
- Inference latency (p50, p95, p99)
- Prediction throughput (QPS)
- Model version served
- Prediction distribution drift
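As one concrete option for the query side (an assumption; the article does not prescribe a metrics backend), the MetricsClient interface sketched earlier could be backed by Prometheus:

// prometheus_metrics.go (illustrative sketch)
package inference

import (
    "github.com/prometheus/client_golang/prometheus"
    "github.com/prometheus/client_golang/prometheus/promauto"
)

// PrometheusMetrics implements MetricsClient using Prometheus collectors.
type PrometheusMetrics struct {
    counters   *prometheus.CounterVec
    histograms *prometheus.HistogramVec
}

func NewPrometheusMetrics() *PrometheusMetrics {
    return &PrometheusMetrics{
        counters: promauto.NewCounterVec(prometheus.CounterOpts{
            Name: "inference_events_total",
            Help: "Count of inference-side events (cache hits, model updates, ...).",
        }, []string{"event"}),
        histograms: promauto.NewHistogramVec(prometheus.HistogramOpts{
            Name:    "inference_latency_milliseconds",
            Help:    "End-to-end inference latency.",
            Buckets: prometheus.ExponentialBuckets(1, 2, 12), // ~1ms to ~2s
        }, []string{"metric"}),
    }
}

func (m *PrometheusMetrics) IncrementCounter(name string) {
    m.counters.WithLabelValues(name).Inc()
}

func (m *PrometheusMetrics) RecordHistogram(name string, value int64) {
    m.histograms.WithLabelValues(name).Observe(float64(value))
}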
Conclusion
CQRS provides a natural architectural fit for ML systems, aligning with the inherent separation between training and inference workloads. By decoupling these concerns, teams gain flexibility to optimize each side independently, choose appropriate technologies, and scale efficiently.
For principal engineers building production ML systems, CQRS should be a default consideration—especially when inference latency, scalability, or cost optimization are critical requirements. The pattern’s complexity is justified when the performance and operational benefits outweigh the added architectural overhead.
Start simple, but design for CQRS from the beginning. The cost of refactoring a monolithic ML system to CQRS later far exceeds the upfront architectural investment.