# kopa/process_kge.py
# 2023-10-11 11:51:08 +08:00
#
# 40 lines
# 1.4 KiB
# Python

import torch
def load_pretrain_kge(path):
    """Load frozen entity and relation embeddings from a pretrained KGE checkpoint.

    Paths containing ``"complex"`` are delegated to :func:`load_complex_model`.
    Any other checkpoint must map ``"ent_embeddings.weight"`` and
    ``"rel_embeddings.weight"`` to 2-D embedding tables (presumably an
    OpenKE-style ``state_dict`` -- confirm against the training code).

    When the relation width differs from the entity width, the relation table
    is concatenated with itself along the last axis so both tables end up the
    same width. That is only meaningful when the entity width is exactly twice
    the relation width; any other mismatch is rejected.

    Args:
        path: Checkpoint location (``str`` or path-like) passed to ``torch.load``.

    Returns:
        ``(ent_embs, rel_embs)``: CPU tensors with ``requires_grad=False`` and
        equal second dimension.

    Raises:
        ValueError: If the relation width cannot be matched to the entity
            width by duplication.
    """
    # str() so pathlib.Path inputs are accepted as well as plain strings.
    if "complex" in str(path):
        return load_complex_model(path)
    # NOTE: torch.load unpickles the file -- only load trusted checkpoints.
    # map_location="cpu" lets GPU-saved checkpoints load on CPU-only machines.
    kge_model = torch.load(path, map_location="cpu")
    # as_tensor(...).detach().clone() yields an independent, grad-free copy
    # (as torch.tensor(...) did) without the copy-construct UserWarning.
    ent_embs = torch.as_tensor(kge_model["ent_embeddings.weight"]).detach().clone().cpu()
    rel_embs = torch.as_tensor(kge_model["rel_embeddings.weight"]).detach().clone().cpu()
    ent_embs.requires_grad = False
    rel_embs.requires_grad = False
    ent_dim = ent_embs.shape[1]
    rel_dim = rel_embs.shape[1]
    print(ent_dim, rel_dim)
    if ent_dim != rel_dim:
        # Duplicating the relation table only reconciles a 2:1 width ratio;
        # anything else would silently leave the tables mismatched.
        if ent_dim != 2 * rel_dim:
            raise ValueError(
                f"cannot align relation dim {rel_dim} with entity dim {ent_dim}: "
                "expected equal dims or entity dim == 2 * relation dim"
            )
        rel_embs = torch.cat((rel_embs, rel_embs), dim=-1)
    return ent_embs, rel_embs
def load_complex_model(path):
    """Load frozen ComplEx-style embeddings, joining real and imaginary parts.

    The checkpoint must contain ``ent_re_embeddings.weight``,
    ``ent_im_embeddings.weight``, ``rel_re_embeddings.weight`` and
    ``rel_im_embeddings.weight``. For entities and for relations, the real
    and imaginary tables are concatenated along the last axis as
    ``[re | im]``.

    Args:
        path: Checkpoint location (``str`` or path-like) passed to ``torch.load``.

    Returns:
        ``(ent_embs, rel_embs)``: CPU tensors with ``requires_grad=False``.
    """
    # NOTE: torch.load unpickles the file -- only load trusted checkpoints.
    # map_location="cpu" lets GPU-saved checkpoints load on CPU-only machines.
    kge_model = torch.load(path, map_location="cpu")

    def _part(key):
        # detach() so that torch.cat below yields a grad-free leaf tensor even
        # if the stored value tracks gradients (setting requires_grad=False on
        # a non-leaf would raise). torch.cat itself makes the copy.
        return torch.as_tensor(kge_model[key]).detach().cpu()

    ent_embs = torch.cat(
        (_part("ent_re_embeddings.weight"), _part("ent_im_embeddings.weight")), dim=-1
    )
    rel_embs = torch.cat(
        (_part("rel_re_embeddings.weight"), _part("rel_im_embeddings.weight")), dim=-1
    )
    ent_embs.requires_grad = False
    rel_embs.requires_grad = False
    ent_dim = ent_embs.shape[1]
    rel_dim = rel_embs.shape[1]
    print(ent_dim, rel_dim)
    return ent_embs, rel_embs
if __name__ == "__main__":
    # Script entry: load the demo checkpoint. Its path contains "complex", so
    # load_pretrain_kge routes it to load_complex_model, which prints the
    # entity and relation widths.
    demo_checkpoint = "data/CoDeX-S-complex.pth"
    load_pretrain_kge(demo_checkpoint)