import torch
import torch.nn as nn
from typing import Optional, List, Union, Tuple

from transformers import LlamaForCausalLM
from process_kge import load_pretrain_kge


class KoPA(nn.Module):
    def __init__(
        self,
        model: LlamaForCausalLM
    ) -> None:
        super(KoPA, self).__init__()
        self.llama_model = model
        # self.embeddings = nn.Embedding(100, 4096)
        self.embeddings = PrefixKGEmbedding(
            num_ent=2034,
            num_rel=42,
            dim_llm=4096,
            num_prefix=1
        )

    def forward(
        self,
        input_ids: torch.LongTensor = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[List[torch.FloatTensor]] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        embedding_ids: torch.LongTensor = None
    ):
        # Look up the knowledge prefix for the (head, relation, tail) ids
        kg_embeds = self.embeddings(embedding_ids)
        batch_size, seq_len, _ = kg_embeds.shape
        # Token embeddings of the wrapped LLaMA backbone (.model.model reaches
        # through the outer wrapper to the LlamaModel embedding table)
        token_embeds = self.llama_model.model.model.embed_tokens(input_ids)
        # Prepend the knowledge prefix to the token embeddings
        input_embeds = torch.cat((kg_embeds, token_embeds), dim=1)
        # Prefix positions are attended to but excluded from the loss (-100)
        prefix_mask = torch.ones((batch_size, seq_len), dtype=attention_mask.dtype, device=attention_mask.device)
        prefix_labels = torch.full((batch_size, seq_len), fill_value=-100, dtype=torch.long, device=labels.device)
        new_attention_mask = torch.cat((prefix_mask, attention_mask), dim=-1)
        new_labels = torch.cat((prefix_labels, labels), dim=-1)
        return self.llama_model(
            input_ids=None,
            attention_mask=new_attention_mask,
            position_ids=position_ids,
            past_key_values=past_key_values,
            inputs_embeds=input_embeds,
            labels=new_labels,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )


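# Note (added for clarity, not in the original file): both KoPA variants build
# the same input layout. The knowledge prefix is prepended in embedding space,
# its attention-mask entries are 1, and its labels are -100, so the LM loss is
# computed over the text tokens only:
#
#   inputs_embeds : [ kg_0 ... kg_{k-1} | tok_0 ... tok_{n-1} ]   (k prefix tokens)
#   labels        : [ -100 ...    -100  | y_0   ...   y_{n-1} ]

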
class KoPAWithAdapter(nn.Module):
    def __init__(
        self,
        model: LlamaForCausalLM,
        num_prefix: int,
        kge_model: str = "data/UMLS-rotate.pth",
        pretrain_emb_path=None
    ) -> None:
        super(KoPAWithAdapter, self).__init__()
        self.llama_model = model
        # Load pretrained structural (KGE) embeddings for entities and relations
        ent_embs, rel_embs = load_pretrain_kge(kge_model)
        if pretrain_emb_path is None:
            print("Adapter Trained From Scratch")
            self.embeddings = PretrainKGEmbedding(
                pretrain_ent_embs=ent_embs,
                pretrain_rel_embs=rel_embs,
                dim_llm=4096,
                num_prefix=num_prefix
            )
        else:
            print("Adapter Load From {}".format(pretrain_emb_path))
            self.embeddings = torch.load(pretrain_emb_path)

    def forward(
        self,
        input_ids: torch.LongTensor = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[List[torch.FloatTensor]] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        embedding_ids: torch.LongTensor = None
    ):
        # Map the (head, relation, tail) ids to virtual prefix tokens via the adapter
        kg_embeds = self.embeddings(embedding_ids)
        batch_size, seq_len, _ = kg_embeds.shape
        # Token embeddings of the wrapped LLaMA backbone (.model.model reaches
        # through the outer wrapper to the LlamaModel embedding table)
        token_embeds = self.llama_model.model.model.embed_tokens(input_ids)
        # Prepend the knowledge prefix to the token embeddings
        input_embeds = torch.cat((kg_embeds, token_embeds), dim=1)
        # Prefix positions are attended to but excluded from the loss (-100)
        prefix_mask = torch.ones((batch_size, seq_len), dtype=attention_mask.dtype, device=attention_mask.device)
        prefix_labels = torch.full((batch_size, seq_len), fill_value=-100, dtype=torch.long, device=labels.device)
        new_attention_mask = torch.cat((prefix_mask, attention_mask), dim=-1)
        new_labels = torch.cat((prefix_labels, labels), dim=-1)
        return self.llama_model(
            input_ids=None,
            attention_mask=new_attention_mask,
            position_ids=position_ids,
            past_key_values=past_key_values,
            inputs_embeds=input_embeds,
            labels=new_labels,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )


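# Illustrative wiring (a sketch, not part of the original file; the paths and
# tensor names below are placeholders). The backbone is assumed to be
# PEFT/LoRA-wrapped so that `.model.model.embed_tokens` in forward() resolves:
#
#     llama = LlamaForCausalLM.from_pretrained("path/to/llama")   # hypothetical path
#     llama = get_peft_model(llama, lora_config)                  # peft wrapper (assumed)
#     model = KoPAWithAdapter(llama, num_prefix=1,
#                             kge_model="data/UMLS-rotate.pth")
#     out = model(input_ids=ids, attention_mask=mask, labels=lbls,
#                 embedding_ids=triples)      # triples: (batch, 3) LongTensor
#     out.loss.backward()

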
class PrefixKGEmbedding(nn.Module):
    def __init__(
        self,
        num_ent,
        num_rel,
        dim_llm,
        num_prefix
    ):
        super(PrefixKGEmbedding, self).__init__()
        self.emb_dim = num_prefix * dim_llm
        self.ent_embeddings = nn.Embedding(num_ent, self.emb_dim)
        self.rel_embeddings = nn.Embedding(num_rel, self.emb_dim)

    def forward(self, triple_ids):
        head, relation, tail = triple_ids[:, 0], triple_ids[:, 1], triple_ids[:, 2]
        h = self.ent_embeddings(head)
        r = self.rel_embeddings(relation)
        t = self.ent_embeddings(tail)
        prefix = torch.stack((h, r, t), dim=1)
        return prefix


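# Shape note (added, not in the original file): with num_prefix=1 and
# dim_llm=4096, PrefixKGEmbedding maps a (batch, 3) tensor of (h, r, t) ids to
# a (batch, 3, 4096) prefix, i.e. three virtual tokens per triple:
#
#     emb = PrefixKGEmbedding(num_ent=2034, num_rel=42, dim_llm=4096, num_prefix=1)
#     triples = torch.tensor([[0, 1, 2]])
#     emb(triples).shape    # torch.Size([1, 3, 4096])

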
class PretrainKGEmbedding(nn.Module):
    def __init__(
        self,
        pretrain_ent_embs,
        pretrain_rel_embs,
        dim_llm,
        num_prefix
    ):
        super(PretrainKGEmbedding, self).__init__()
        self.num_prefix = num_prefix
        self.llm_dim = dim_llm
        self.emb_dim = num_prefix * dim_llm
        self.ent_embeddings = nn.Embedding.from_pretrained(pretrain_ent_embs)
        self.rel_embeddings = nn.Embedding.from_pretrained(pretrain_rel_embs)
        self.pretrain_dim = self.ent_embeddings.weight.shape[1]
        # Freeze the pretrained KGE embeddings; only the adapter is trained
        self.ent_embeddings.requires_grad_(False)
        self.rel_embeddings.requires_grad_(False)
        # The adapter projects each KGE vector into num_prefix LLM-dim virtual tokens
        self.adapter = nn.Linear(self.pretrain_dim, self.emb_dim)

    def forward(self, triple_ids):
        # Main training stage: a batch of (head, relation, tail) triples
        if triple_ids.shape[1] == 3:
            head, relation, tail = triple_ids[:, 0], triple_ids[:, 1], triple_ids[:, 2]
            h = self.ent_embeddings(head)
            r = self.rel_embeddings(relation)
            t = self.ent_embeddings(tail)
            pretrain_embs = torch.stack((h, r, t), dim=1)
            prefix = self.adapter(pretrain_embs).reshape(-1, 3 * self.num_prefix, self.llm_dim)
            return prefix
        # Entity-aware pre-training stage: a batch of entity ids only
        else:
            ent = triple_ids.reshape(-1,)
            emb = self.ent_embeddings(ent)
            prefix = self.adapter(emb).reshape(-1, self.num_prefix, self.llm_dim)
            return prefix
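

# Minimal self-check (illustrative, not part of the original file): random
# tensors stand in for the pretrained RotatE entity/relation embeddings, so
# the table sizes are placeholders and only the output shapes are meaningful.
if __name__ == "__main__":
    fake_ent = torch.randn(135, 512)    # placeholder entity embedding table
    fake_rel = torch.randn(46, 512)     # placeholder relation embedding table
    kge = PretrainKGEmbedding(
        pretrain_ent_embs=fake_ent,
        pretrain_rel_embs=fake_rel,
        dim_llm=4096,
        num_prefix=1
    )
    triples = torch.randint(0, 46, (2, 3))      # (batch, 3) triple ids
    print(kge(triples).shape)                   # torch.Size([2, 3, 4096])
    entities = torch.randint(0, 135, (2, 4))    # entity-only ids (pre-training path)
    print(kge(entities).shape)                  # torch.Size([8, 1, 4096])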