O AmpliGraph é uma biblioteca poderosa para aprendizado de representações em Grafos de Conhecimento (Knowledge Graphs). Embora a versão 2.0 traga melhorias significativas, a transição entre versões e a compatibilidade com o ecossistema TensorFlow exigem atenção a detalhes de configuração e ambiente.
Configuração do Ambiente e Versão 1.4
Para cenários que exigem a versão 1.4, a instalação deve ser feita via Conda, garantindo a compatibilidade com versões específicas do TensorFlow.
conda create --name env_ampligraph python=3.7
conda activate env_ampligraph
conda install tensorflow-gpu==1.5
pip install ampligraph
Um problema comum durante a execução em ambientes com versões desencontradas é o erro de importação do profiler (ImportError: cannot import name 'trace' from 'tensorflow.python.profiler'). Isso geralmente indica uma incompatibilidade entre o tensorflow e o tensorflow-estimator. A correção pode ser aplicada ajustando o estimator:
conda install tensorflow-estimator=1.15
Treinamento e Avaliação com Modelo ComplEx
Abaixo, um exemplo de implementação utilizando o modelo ComplEx no dataset Wordnet18 (WN18), incluindo mecanismos de early stopping e filtragem de triplas para avaliação precisa.
import numpy as np
from ampligraph.datasets import load_wn18
from ampligraph.latent_features import ComplEx
from ampligraph.evaluation import evaluate_performance, mrr_score, hits_at_n_score
def executar_treinamento_complex():
# Carregamento do dataset WN18
dados = load_wn18()
# Instanciação do modelo ComplEx
# Configurado para 20 épocas iniciais e otimizador Adam
modelo_complex = ComplEx(batches_count=10,
seed=42,
epochs=20,
k=150,
eta=10,
optimizer='adam',
optimizer_params={'lr': 1e-3},
loss='pairwise',
loss_params={'margin': 0.5},
regularizer='LP',
regularizer_params={'p': 2, 'lambda': 1e-5},
verbose=True)
# Criação do filtro para evitar falsos negativos na avaliação
filtro_triplas = np.concatenate((dados['train'], dados['valid'], dados['test']))
# Treinamento com interrupção antecipada (Early Stopping)
modelo_complex.fit(dados['train'],
early_stopping=True,
early_stopping_params={
'x_valid': dados['valid'],
'criteria': 'hits10',
'burn_in': 100,
'check_interval': 20,
'stop_interval': 5,
'x_filter': filtro_triplas,
'corruption_entities': 'all',
'corrupt_side': 's+o'
})
# Avaliação de performance no conjunto de teste
ranks_avaliacao = evaluate_performance(dados['test'],
model=modelo_complex,
filter_triples=filtro_triplas,
use_default_protocol=True,
verbose=True)
# Cálculo de métricas
valor_mrr = mrr_score(ranks_avaliacao)
valor_hits10 = hits_at_n_score(ranks_avaliacao, n=10)
print(f"Resultado - MRR: {valor_mrr:.6f}, Hits@10: {valor_hits10:.6f}")
if __name__ == "__main__":
executar_treinamento_complex()
Visualização de Métricas com TensorBoard
O AmpliGraph possui integração nativa com o TensorBoard para monitoramento da função de perda (loss) e métricas de validação. Durante o fit, os logs são gerados automaticamente se um caminho for especificaod. Para visualizar os resultados, utiliza-se o comando:
tensorboard --logdir=./caminho_dos_logs
Ao acessar http://localhost:6006/, é possível observar a curva de aprendizado e a convergência do modelo em tempo real.
Transição para o AmpliGraph 2.0 e Módulo de Compatibilidade
A versão 2.0 reestruturou a API para separar de forma mais clara as etapas de inicialização, compilação e ajuste do modelo. No entanto, algumas funcionalidades da versão 1.x, como o parâmetro direto tensorboard_logs_path no método fit, foram modificadas.
Para manter fluxos de trabalho antigos ou utilizar funcionalidades específicas da API anterior dentro da versão 2.0, utiliza-se o módulo ampligraph.compat.
import numpy as np
from ampligraph.datasets import load_wn18
from ampligraph.compat import TransE, evaluate_performance
from ampligraph.evaluation import mrr_score, hits_at_n_score
from ampligraph.utils import save_model, restore_model
# Carregamento do dataset
dataset = load_wn18()
# Definição do modelo TransE via camada de compatibilidade
modelo_transe = TransE(k=350,
eta=30,
epochs=400,
batches_count=150,
seed=0,
embedding_model_params={'norm': 1, 'normalize_ent_emb': False},
optimizer='adam',
optimizer_params={'lr': 0.0001},
loss='multiclass_nll',
regularizer="LP",
regularizer_params={'p': 3, 'lambda': 0.0001},
initializer='xavier',
verbose=True)
filtro_global = {'test': np.concatenate((dataset['train'], dataset['valid'], dataset['test']))}
# Execução do treinamento compatível com 1.x
modelo_transe.fit(dataset['train'],
tensorboard_logs_path="logs_transe_wn18",
early_stopping=True,
early_stopping_params={
'x_valid': dataset['valid'],
'criteria': 'hits10',
'check_interval': 40,
'stop_interval': 100,
'x_filter': filtro_global,
'corrupt_side': 's+o'
})
# Persistência do modelo
path_modelo = "modelo_transe_final.pkl"
save_model(modelo_transe, model_name_path=path_modelo)
# Restauração
modelo_carregado = restore_model(model_name_path=path_modelo)
Nota importante ao utilizar restore_model: em alguns casos, após a restauração, o objeto do modelo pode perder metadados sobre o formato dos dados (data_shape). Caso ocorram erros na fase de avaliação pós-restauração, certifique-se de validar se as dimensões de entrada foram preservadas corretamente no objeto recuperado.