""" from sentence_transformers import SentenceTransformer, models
|
|
|
|
## Step 1: use an existing language model
|
|
word_embedding_model = models.Transformer('distilroberta-base')
|
|
|
|
## Step 2: use a pool function over the token embeddings
|
|
pooling_model = models.Pooling(word_embedding_model.get_word_embedding_dimension())
|
|
|
|
## Join steps 1 and 2 using the modules argument
|
|
model = SentenceTransformer(modules=[word_embedding_model, pooling_model])
|
|
|
|
from sentence_transformers import InputExample
|
|
|
|
from datasets import load_dataset
|
|
|
|
dataset_id = "embedding-data/QQP_triplets"
|
|
# dataset_id = "embedding-data/sentence-compression"
|
|
|
|
dataset = load_dataset(dataset_id)
|
|
|
|
|
|
train_examples = []
|
|
train_data = dataset['train']['set']
|
|
# For agility we only 1/2 of our available data
|
|
n_examples = dataset['train'].num_rows // 2
|
|
|
|
for i in range(10):
|
|
example = train_data[i]
|
|
train_examples.append(InputExample(texts=[example['query'], example['pos'][0]]))
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
from torch.utils.data import DataLoader
|
|
|
|
train_dataloader = DataLoader(train_examples, shuffle=True, batch_size=2)
|
|
|
|
|
|
from sentence_transformers import losses
|
|
|
|
train_loss = losses.MultipleNegativesRankingLoss(model=model)
|
|
|
|
num_epochs = 10
|
|
|
|
warmup_steps = int(len(train_dataloader) * num_epochs * 0.1) #10% of train data
|
|
|
|
model.fit(train_objectives=[(train_dataloader, train_loss)],epochs=num_epochs,warmup_steps=2)
|
|
|
|
|
|
"""
|
|
from sentence_transformers import SentenceTransformer, losses, InputExample
|
|
from torch.utils.data import DataLoader
|
|
from unidecode import unidecode
|
|
from pathlib import Path
|
|
import json
|
|
import os
|
|
from datetime import datetime
|
|
|
|
# Top-level key in ./conf/experiment_config.json selecting which experiment/model to train.
nameModel="Modelo_embedding_Mexico_Puebla_hiiamasid"
|
|
def extractConfig(nameModel="Modelo_embedding_Mexico_Puebla",relPath="./conf/experiment_config.json",dataOut="train_dataset_pos"):
    """Read one value from the experiment-configuration JSON file.

    Parameters
    ----------
    nameModel : str
        Top-level key in the config file identifying the experiment section.
    relPath : str
        Path to the JSON configuration file.
    dataOut : str or list
        Key to extract from the model's section; a two-element list is
        treated as a nested key path (e.g. ``["params", "num_epochs"]``).

    Returns
    -------
    The configuration value stored under ``dataOut``.

    Raises
    ------
    KeyError
        If ``nameModel`` or ``dataOut`` is not present in the file.
    """
    configPath = Path(relPath)
    with open(configPath, 'r', encoding='utf-8') as file:
        config = json.load(file)[nameModel]
    # isinstance is the idiomatic type check (was: `type(dataOut) is list`).
    if isinstance(dataOut, list) and len(dataOut) == 2:
        return config[dataOut[0]][dataOut[1]]
    return config[dataOut]
|
|
def saveConfig(dictionary, model_name=None, base_model=None):
    """Serialize the training parameters to ``./<model>/<base>/params/params.json``.

    Parameters
    ----------
    dictionary : dict
        JSON-serializable parameter dictionary to persist.
    model_name, base_model : str, optional
        Path components; default to the module-level ``nameModel`` and
        ``baseModel`` globals, preserving the original zero-extra-arg call.
    """
    model_name = nameModel if model_name is None else model_name
    base_model = baseModel if base_model is None else base_model
    pathOutfile = './%s/%s/params/' % (model_name, base_model)
    # exist_ok avoids the check-then-create race of the original exists()+makedirs.
    os.makedirs(pathOutfile, exist_ok=True)
    with open(pathOutfile + "params.json", "w", encoding='utf-8') as outfile:
        json.dump(dictionary, outfile)
|
|
|
|
def saveData(dictionary, model_name=None, base_model=None):
    """Persist the training pairs to ``./<model>/<base>/data/train.json``.

    Parameters
    ----------
    dictionary : dict
        JSON-serializable mapping of category -> list of phrases.
    model_name, base_model : str, optional
        Path components; default to the module-level ``nameModel`` and
        ``baseModel`` globals, preserving the original zero-extra-arg call.
    """
    # (removed the original's unused local `Sal = {}`)
    model_name = nameModel if model_name is None else model_name
    base_model = baseModel if base_model is None else base_model
    pathOutfile = './%s/%s/data/' % (model_name, base_model)
    # exist_ok avoids the check-then-create race of the original exists()+makedirs.
    os.makedirs(pathOutfile, exist_ok=True)
    with open(pathOutfile + "train.json", "w", encoding='utf-8') as outfile:
        json.dump(dictionary, outfile)
|
|
|
|
# Tag for this training run, e.g. "V_2024_5_17" (year/month/day, no zero padding).
now = datetime.now()
entrenamiento = f"V_{now.year}_{now.month}_{now.day}"

# Pull the base-model name and the positive-pairs dataset path from the config.
baseModel = extractConfig(nameModel=nameModel, dataOut="base_model")
trainDatasetPos = extractConfig(nameModel=nameModel, dataOut="train_dataset_pos")
|
|
|
|
# Load the SentenceTransformer checkpoint referenced by the config and
# set up a contrastive loss over (query, positive) pairs.
model = extractConfig(nameModel=nameModel, dataOut="path_model")
modelST = SentenceTransformer(model + "/model")
train_loss = losses.MultipleNegativesRankingLoss(model=modelST)

# Read the training pairs: {category: [phrase, ...]}.
train_path = Path(trainDatasetPos)
queries_Categoricos = json.loads(train_path.read_text(encoding='utf-8'))
|
|
|
|
|
|
# Build the InputExample pairs: each (category, phrase) pair is ASCII-folded
# with unidecode, whitespace-trimmed and lowercased on both sides.
# (Normalization is idempotent, so this matches the original loop exactly.)
train_examples = [
    InputExample(texts=[unidecode(category).strip().lower(),
                        unidecode(phrase).strip().lower()])
    for category, phrases in queries_Categoricos.items()
    for phrase in phrases
]
|
|
|
|
|
|
# Batch the pairs for MultipleNegativesRankingLoss training.
train_dataloader = DataLoader(train_examples, shuffle=True, batch_size=5)#16
print(len(train_dataloader))

# BUGFIX: the original called extractConfig(...) here WITHOUT nameModel, so the
# hyperparameters were read from extractConfig's default section
# ("Modelo_embedding_Mexico_Puebla") instead of the model actually being
# trained ("..._hiiamasid"). Pass nameModel and read the params block once.
trainParams = extractConfig(nameModel=nameModel, dataOut="params")
modelST.fit(
    train_objectives=[(train_dataloader, train_loss)],
    epochs=trainParams["num_epochs"],
    warmup_steps=trainParams["warmup_steps"],
)

# Persist the fine-tuned model next to its params and training data.
save_path = './%s/%s/model/' % (nameModel, baseModel)
modelST.save(save_path)

params = {"entrenamiento": entrenamiento, "baseModel": baseModel}
params.update(trainParams)
saveConfig(params)
saveData(queries_Categoricos)
|
|
|