Exploração e Implementação de Knowledge Graph Embeddings com AmpliGraph

O AmpliGraph é uma biblioteca poderosa para aprendizado de representações em Grafos de Conhecimento (Knowledge Graphs). Embora a versão 2.0 traga melhorias significativas, a transição entre versões e a compatibilidade com o ecossistema TensorFlow exigem atenção a detalhes de configuração e ambiente.

Configuração do Ambiente e Versão 1.4

Para cenários que exigem a versão 1.4, a instalação deve ser feita via Conda, garantindo a compatibilidade com versões específicas do TensorFlow.

conda create --name env_ampligraph python=3.7
conda activate env_ampligraph
conda install tensorflow-gpu==1.5
pip install ampligraph

Um problema comum durante a execução em ambientes com versões desencontradas é o erro de importação do profiler (ImportError: cannot import name 'trace' from 'tensorflow.python.profiler'). Isso geralmente indica uma incompatibilidade entre o tensorflow e o tensorflow-estimator. A correção pode ser aplicada ajustando o estimator:

conda install tensorflow-estimator=1.15

Treinamento e Avaliação com Modelo ComplEx

Abaixo, um exemplo de implementação utilizando o modelo ComplEx no dataset Wordnet18 (WN18), incluindo mecanismos de early stopping e filtragem de triplas para avaliação precisa.

import numpy as np
from ampligraph.datasets import load_wn18
from ampligraph.latent_features import ComplEx
from ampligraph.evaluation import evaluate_performance, mrr_score, hits_at_n_score

def executar_treinamento_complex():
    # Carregamento do dataset WN18
    dados = load_wn18()

    # Instanciação do modelo ComplEx
    # Configurado para 20 épocas iniciais e otimizador Adam
    modelo_complex = ComplEx(batches_count=10, 
                             seed=42, 
                             epochs=20, 
                             k=150, 
                             eta=10,
                             optimizer='adam', 
                             optimizer_params={'lr': 1e-3},
                             loss='pairwise', 
                             loss_params={'margin': 0.5},
                             regularizer='LP', 
                             regularizer_params={'p': 2, 'lambda': 1e-5}, 
                             verbose=True)

    # Criação do filtro para evitar falsos negativos na avaliação
    filtro_triplas = np.concatenate((dados['train'], dados['valid'], dados['test']))
    
    # Treinamento com interrupção antecipada (Early Stopping)
    modelo_complex.fit(dados['train'], 
                       early_stopping=True,
                       early_stopping_params={
                           'x_valid': dados['valid'],
                           'criteria': 'hits10',
                           'burn_in': 100,
                           'check_interval': 20,
                           'stop_interval': 5,
                           'x_filter': filtro_triplas,
                           'corruption_entities': 'all',
                           'corrupt_side': 's+o'
                       })

    # Avaliação de performance no conjunto de teste
    ranks_avaliacao = evaluate_performance(dados['test'], 
                                           model=modelo_complex, 
                                           filter_triples=filtro_triplas,
                                           use_default_protocol=True,
                                           verbose=True)

    # Cálculo de métricas
    valor_mrr = mrr_score(ranks_avaliacao)
    valor_hits10 = hits_at_n_score(ranks_avaliacao, n=10)
    
    print(f"Resultado - MRR: {valor_mrr:.6f}, Hits@10: {valor_hits10:.6f}")

if __name__ == "__main__":
    executar_treinamento_complex()

Visualização de Métricas com TensorBoard

O AmpliGraph possui integração nativa com o TensorBoard para monitoramento da função de perda (loss) e métricas de validação. Durante o fit, os logs são gerados automaticamente se um caminho for especificaod. Para visualizar os resultados, utiliza-se o comando:

tensorboard --logdir=./caminho_dos_logs

Ao acessar http://localhost:6006/, é possível observar a curva de aprendizado e a convergência do modelo em tempo real.

Transição para o AmpliGraph 2.0 e Módulo de Compatibilidade

A versão 2.0 reestruturou a API para separar de forma mais clara as etapas de inicialização, compilação e ajuste do modelo. No entanto, algumas funcionalidades da versão 1.x, como o parâmetro direto tensorboard_logs_path no método fit, foram modificadas.

Para manter fluxos de trabalho antigos ou utilizar funcionalidades específicas da API anterior dentro da versão 2.0, utiliza-se o módulo ampligraph.compat.

import numpy as np
from ampligraph.datasets import load_wn18
from ampligraph.compat import TransE, evaluate_performance
from ampligraph.evaluation import mrr_score, hits_at_n_score
from ampligraph.utils import save_model, restore_model

# Carregamento do dataset
dataset = load_wn18()

# Definição do modelo TransE via camada de compatibilidade
modelo_transe = TransE(k=350, 
                       eta=30, 
                       epochs=400, 
                       batches_count=150, 
                       seed=0,
                       embedding_model_params={'norm': 1, 'normalize_ent_emb': False},
                       optimizer='adam', 
                       optimizer_params={'lr': 0.0001},
                       loss='multiclass_nll', 
                       regularizer="LP", 
                       regularizer_params={'p': 3, 'lambda': 0.0001},
                       initializer='xavier', 
                       verbose=True)

filtro_global = {'test': np.concatenate((dataset['train'], dataset['valid'], dataset['test']))}

# Execução do treinamento compatível com 1.x
modelo_transe.fit(dataset['train'], 
                  tensorboard_logs_path="logs_transe_wn18",
                  early_stopping=True,
                  early_stopping_params={
                      'x_valid': dataset['valid'],
                      'criteria': 'hits10',
                      'check_interval': 40,
                      'stop_interval': 100,
                      'x_filter': filtro_global,
                      'corrupt_side': 's+o'
                  })

# Persistência do modelo
path_modelo = "modelo_transe_final.pkl"
save_model(modelo_transe, model_name_path=path_modelo)

# Restauração
modelo_carregado = restore_model(model_name_path=path_modelo)

Nota importante ao utilizar restore_model: em alguns casos, após a restauração, o objeto do modelo pode perder metadados sobre o formato dos dados (data_shape). Caso ocorram erros na fase de avaliação pós-restauração, certifique-se de validar se as dimensões de entrada foram preservadas corretamente no objeto recuperado.

Tags: AmpliGraph Knowledge-Graph tensorflow Python machine-learning

Publicado em 6-12 00:18 por Thomas