32 lines
1.2 KiB
Python
32 lines
1.2 KiB
Python
from sentence_transformers import SentenceTransformer
|
|
# Especializado en preguntas y respuestas: "multi-qa-mpnet-base-dot-v1"
|
|
# De uso general, el de mejor desempeño: "all-mpnet-base-v2"
|
|
# Los más rápidos: "paraphrase-MiniLM-L3-v2" y "all-MiniLM-L6-v2"
|
|
# Muy rápido y muy acertado: "all-MiniLM-L12-v2"
|
|
#models=["all-MiniLM-L12-v2","paraphrase-MiniLM-L3-v2" , "all-MiniLM-L6-v2",
|
|
from pathlib import Path
|
|
import json
|
|
#"paraphrase-multilingual-mpnet-base-v2",'hackathon-pln-es/paraphrase-spanish-distilroberta'
|
|
# Name of the experiment entry (top-level key of the JSON config file) that
# this script reads its settings from.
nameModel: str = "Modelo_embedding_CIDITEL"
|
|
def extractConfig(
    nameModel="Modelo_embedding_Mexico_Puebla",
    relPath="./conf/experiment_config.json",
    dataOut="train_dataset_pos",
):
    """Read one setting of an experiment from a JSON configuration file.

    The file at ``relPath`` is loaded as JSON and indexed by ``nameModel``
    to select the experiment entry; ``dataOut`` then selects the value.

    Args:
        nameModel: Top-level key of the experiment entry to read.
        relPath: Path to the JSON configuration file (relative paths are
            resolved against the current working directory).
        dataOut: Key to return from the selected entry. A list or tuple of
            keys is treated as a path into nested objects, e.g.
            ``["train", "pos"]`` returns ``entry["train"]["pos"]``.

    Returns:
        The value stored under ``dataOut`` (any JSON type).

    Raises:
        FileNotFoundError: If ``relPath`` does not exist.
        json.JSONDecodeError: If the file is not valid JSON.
        KeyError: If ``nameModel`` or a key of ``dataOut`` is missing.
    """
    configPath = Path(relPath)

    with open(configPath, "r", encoding="utf-8") as file:
        config = json.load(file)[nameModel]

    # A list/tuple means "follow this key path". This must be an isinstance
    # check: ``dataOut is list`` compares against the type object itself and
    # is never true for an actual list value.
    if isinstance(dataOut, (list, tuple)):
        output = config
        for key in dataOut:
            output = output[key]
        return output

    return config[dataOut]
|
|
# Read the "base_model" entry of the selected experiment's configuration.
baseModel = extractConfig(nameModel=nameModel, dataOut="base_model")

# Kept as a list so more model names can be appended and exported by the
# same loop below.
models = [baseModel]

for model in models:
    # Instantiate the sentence-transformers model identified by `model`.
    modelST = SentenceTransformer(model)

    # Destination directory for the exported copy: ./embeddings/<model>/model
    save_path = f"./embeddings/{model}/model"
    print(save_path)

    # Persist the model under save_path.
    modelST.save(save_path)
|
|
|
|
|