Este guia técnico apresenta a implementação de um modelo de Mistura de Especialistas Esparsos (Sparse MoE) usando PyTorch. O Sparse MoE é uma arquitetura neural que ativa apenas os k especialistas mais releventes para cada entrada, equilibrando capacidade expressiva e eficiência computacional.
Conceito de Sparse MoE
Em contraste com modelos dense MoE, onde todos os especialistas processam cada entrada, o Sparse MoE seleciona um subconjunto dinâmico de especialistas. Isso é alcançado por meio de um mecanismo de roteamento que identifica os top-k especialistas com base em scores de relevância, resultando em computação esparsa e redução de custos.
Estrutura do Modelo
A implementação começa com a definição de parâmetros via classe de dados, seguida pela classe do modelo que encapsula os especialistas e a rede de roteamento.
import torch
import torch.nn as nn
import torch.nn.functional as F
from dataclasses import dataclass
torch.manual_seed(0)
@dataclass
class ConfigSparseMoE:
embedding_dim: int = 128
num_specialists: int = 10
top_k: int = 2
A classe do modelo herda de nn.Module e inicializa os componentes necessários.
class SparseMoE(nn.Module):
def __init__(self, config):
super().__init__()
self.embedding_dim = config.embedding_dim
self.num_specialists = config.num_specialists
self.top_k = config.top_k
self.specialists = nn.ModuleList(
[nn.Linear(config.embedding_dim, config.embedding_dim) for _ in range(config.num_specialists)]
)
self.router_network = nn.Linear(config.embedding_dim, config.num_specialists)
No método forward, o roteamento é aplicado para selecionar especialistas e calcular a saída ponderada de forma esparsa.
def forward(self, x):
batch_size = x.shape[0]
routing_logits = self.router_network(x)
topk_scores, topk_indices = torch.topk(routing_logits, self.top_k, dim=-1)
normalized_weights = F.softmax(topk_scores, dim=-1)
output_tensor = torch.zeros_like(x)
for specialist_idx in range(self.num_specialists):
specialist_mask = (topk_indices == specialist_idx)
selected_positions = specialist_mask.any(dim=-1).nonzero(as_tuple=False).squeeze(dim=1)
if selected_positions.numel() == 0:
continue
input_subset = x[selected_positions]
specialist_output = self.specialists[specialist_idx](input_subset)
weight_mask = specialist_mask[selected_positions].float()
weight_values = (normalized_weights[selected_positions] * weight_mask).sum(dim=-1, keepdim=True)
weighted_contribution = specialist_output * weight_values
output_tensor[selected_positions] += weighted_contribution
return output_tensor
Validação do Modelo
Para testar a funcionalidade, instanciamos o modelo com uma configuração padrão e realizamos uma passagem de dados simulados.
config = ConfigSparseMoE()
model_instance = SparseMoE(config)
sample_input = torch.randn(10, 10, config.embedding_dim).reshape(-1, config.embedding_dim)
result = model_instance(sample_input).reshape(10, 10, config.embedding_dim)
print(f"Shape da saída: {result.shape}")