Implementação de um Modelo de Mistura de Especialistas Esparsos (Sparse MoE) com PyTorch

Este guia técnico apresenta a implementação de um modelo de Mistura de Especialistas Esparsos (Sparse MoE) usando PyTorch. O Sparse MoE é uma arquitetura neural que ativa apenas os k especialistas mais releventes para cada entrada, equilibrando capacidade expressiva e eficiência computacional.

Conceito de Sparse MoE

Em contraste com modelos dense MoE, onde todos os especialistas processam cada entrada, o Sparse MoE seleciona um subconjunto dinâmico de especialistas. Isso é alcançado por meio de um mecanismo de roteamento que identifica os top-k especialistas com base em scores de relevância, resultando em computação esparsa e redução de custos.

Estrutura do Modelo

A implementação começa com a definição de parâmetros via classe de dados, seguida pela classe do modelo que encapsula os especialistas e a rede de roteamento.

import torch
import torch.nn as nn
import torch.nn.functional as F
from dataclasses import dataclass

torch.manual_seed(0)

@dataclass
class ConfigSparseMoE:
    embedding_dim: int = 128
    num_specialists: int = 10
    top_k: int = 2

A classe do modelo herda de nn.Module e inicializa os componentes necessários.

class SparseMoE(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.embedding_dim = config.embedding_dim
        self.num_specialists = config.num_specialists
        self.top_k = config.top_k

        self.specialists = nn.ModuleList(
            [nn.Linear(config.embedding_dim, config.embedding_dim) for _ in range(config.num_specialists)]
        )
        self.router_network = nn.Linear(config.embedding_dim, config.num_specialists)

No método forward, o roteamento é aplicado para selecionar especialistas e calcular a saída ponderada de forma esparsa.

    def forward(self, x):
        batch_size = x.shape[0]
        routing_logits = self.router_network(x)
        topk_scores, topk_indices = torch.topk(routing_logits, self.top_k, dim=-1)
        normalized_weights = F.softmax(topk_scores, dim=-1)

        output_tensor = torch.zeros_like(x)

        for specialist_idx in range(self.num_specialists):
            specialist_mask = (topk_indices == specialist_idx)
            selected_positions = specialist_mask.any(dim=-1).nonzero(as_tuple=False).squeeze(dim=1)

            if selected_positions.numel() == 0:
                continue

            input_subset = x[selected_positions]
            specialist_output = self.specialists[specialist_idx](input_subset)

            weight_mask = specialist_mask[selected_positions].float()
            weight_values = (normalized_weights[selected_positions] * weight_mask).sum(dim=-1, keepdim=True)

            weighted_contribution = specialist_output * weight_values
            output_tensor[selected_positions] += weighted_contribution

        return output_tensor

Validação do Modelo

Para testar a funcionalidade, instanciamos o modelo com uma configuração padrão e realizamos uma passagem de dados simulados.

config = ConfigSparseMoE()
model_instance = SparseMoE(config)

sample_input = torch.randn(10, 10, config.embedding_dim).reshape(-1, config.embedding_dim)
result = model_instance(sample_input).reshape(10, 10, config.embedding_dim)
print(f"Shape da saída: {result.shape}")

Tags: Pytorch Sparse MoE Mixture of Experts sparse activation

Publicado em 7-5 07:40