diff --git a/.gitignore b/.gitignore index 7e7376e..33d8e7e 100644 --- a/.gitignore +++ b/.gitignore @@ -19,3 +19,4 @@ NewData* motor05102023.csv run.sh +Finetuning/embeddings/all-mpnet-base-v2/* diff --git a/main.py b/main.py index 1fbaff7..2687e77 100644 --- a/main.py +++ b/main.py @@ -21,6 +21,7 @@ from sentence_transformers import SentenceTransformer from fastapi import FastAPI from unidecode import unidecode from nltk.corpus import stopwords +from typing import Optional #from cleantext import clean import re model="embeddings/all-mpnet-base-v2" @@ -135,7 +136,6 @@ def FinderDbs(query,dbs,filtred=False,th=1.2): AllData={} for dbt in dbs: Sal = dbt.similarity_search_with_score(query,4) - print(Sal) for output in Sal: if output[0].metadata["id"] in AllData.keys(): AllData[output[0].metadata["id"]]["d"]=min([AllData[output[0].metadata["id"]]["d"]-0.1,output[1]-0.1]) @@ -165,17 +165,24 @@ def read_main(): class Response(BaseModel): query: str + filtred : Optional[int] = 0 filtred=False @app.post("/angela/") def calculate_api(response: Response): - print(response.query) query = response.query + try: + filtred = response.filtred + except: + filtred = 0 + if filtred==1: + filtred=True + else: + filtred=False AllData=FinderDbs(query,[db2],filtred) versionL="_".join([model,entrenamiento]) if AllData: - AllData = list(AllData) dis=[] id=[]