basic backend

parent 16e5004228
commit 89ec0476ca

29 changed files with 2125 additions and 13 deletions

.gitignore (vendored, 1 addition)
@@ -1 +1,2 @@
 .vscode
+*__pycache__*
@@ -28,3 +28,9 @@ podman-compose -f docker-compose.yaml up
 ```
 nodemon --ext '*' --exec "podman stop rag-chat-backend; podman rm rag-chat-backend; podman-compose -f docker-compose.yaml up --build"
 ```
+
+## TODO
+
+* Chunking parameters (size/overlap) into settings (see the sketch below)
+* Model selection into settings
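Neither TODO item is implemented in this commit. As a rough, hedged sketch only (none of these fields exist in `core/config.py` yet, and the names `CHUNK_SIZE`, `CHUNK_OVERLAP`, and `LLM_MODEL_NAME` are hypothetical), the values could be exposed through the settings object instead of being hard-coded:

```
# Hypothetical sketch only: proposed additions, not code from this commit.
import os

class Settings:
    # chunking parameters from the TODO list above
    CHUNK_SIZE: int = int(os.getenv("CHUNK_SIZE", "500"))        # characters per chunk
    CHUNK_OVERLAP: int = int(os.getenv("CHUNK_OVERLAP", "50"))   # overlapping characters between chunks
    # model selection, instead of hard-coding "llama3" in OllamaLLM
    LLM_MODEL_NAME: str = os.getenv("LLM_MODEL_NAME", "llama3")

settings = Settings()
```

Whether this belongs in the existing `core.config.settings` object (which may be pydantic-based) or in plain environment variables is an open design choice.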
@@ -11,3 +11,17 @@ python-dotenv==1.0.0
 python-multipart==0.0.7
 PyPDF2==3.0.1
 langchain==0.1.11
+
+transformers==4.36.0
+pycryptodome==3.20.0
+httpx==0.27.0
+
+nltk
+scikit-learn
+scipy
+sentencepiece
+sentence-transformers==2.2.2 # Has previous requirements (nltk, scikit-learn, scipy, sentencepiece, torchvision)
+
+# pre-commit==3.6.2
+#pytest==8.1.1
+#pylint==3.1.0
@@ -1,13 +1,11 @@
 """FastAPI Backend"""

 import uvicorn
-import os

 from fastapi import FastAPI
 from fastapi.middleware.cors import CORSMiddleware
 from dotenv import load_dotenv

-from endpoints import files
+from endpoints import files, llm, search, configurations

 from core.config import settings

@@ -25,12 +23,9 @@ app.add_middleware(
 )

 app.include_router(files.router, prefix=settings.API_V1_STR)  # , tags=["files"]
-print('OPENSEARCH_USE_SSL')
-print(os.getenv('OPENSEARCH_USE_SSL'))
-print('settings.API_V1_STR')
-print(settings.API_V1_STR)
+app.include_router(llm.router, prefix=settings.API_V1_STR, tags=["llm"])  # , tags=["llm"]
+app.include_router(search.router, prefix=settings.API_V1_STR, tags=["search"])  # , tags=["search"]
+app.include_router(configurations.router, prefix=settings.API_V1_STR, tags=["config"])

 if __name__ == "__main__":
rag-chat-backend/src/common_packages/__init__.py (new file, 0 additions)
rag-chat-backend/src/common_packages/dashboard_logging.py (new file, 102 additions)
|
|
@@ -0,0 +1,102 @@
|
||||||
|
""" This module defines the logging component for the LLM interactions.
|
||||||
|
The data is displayed in the OpenSearch Dashboard.
|
||||||
|
"""
|
||||||
|
|
||||||
|
import os
|
||||||
|
import uuid
|
||||||
|
import logging
|
||||||
|
|
||||||
|
from opensearch_logger import OpenSearchHandler
|
||||||
|
from opensearchpy import RequestsHttpConnection
|
||||||
|
|
||||||
|
|
||||||
|
class DashboardLogger:
|
||||||
|
"""Logger instance for OpenSearch dashboard"""
|
||||||
|
|
||||||
|
def __init__(self, logger_name=None):
|
||||||
|
if logger_name is None:
|
||||||
|
# Generate a unique logger name
|
||||||
|
logger_name = "rag-logs-" + str(uuid.uuid4())
|
||||||
|
self.logger_instance = self._create_os_logger(logger_name=logger_name)
|
||||||
|
self.logger_info = {}
|
||||||
|
|
||||||
|
def add_information(self, label: str, value):
|
||||||
|
"""Here you can add information to a given process.
|
||||||
|
Each information consists of a label and its given value.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
label (str): label
|
||||||
|
value (any): value
|
||||||
|
"""
|
||||||
|
self.logger_info[label] = value
|
||||||
|
|
||||||
|
def close_logging(self):
|
||||||
|
"""Close logging in the final process."""
|
||||||
|
|
||||||
|
self.logger_info = self._roll_out_json(original_json=self.logger_info)
|
||||||
|
|
||||||
|
if isinstance(self.logger_info, list):
|
||||||
|
for logger_json in self.logger_info:
|
||||||
|
self.logger_instance.info("Logging information", extra=logger_json)
|
||||||
|
else:
|
||||||
|
self.logger_instance.info("Logging information", extra=self.logger_info)
|
||||||
|
|
||||||
|
def _create_os_logger(self, logger_name):
|
||||||
|
"""Create a logger which logs on OpenSearch.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
logger_name (str): use some logger name.
|
||||||
|
Returns:
|
||||||
|
logger: OpenSearch logger instance.
|
||||||
|
"""
|
||||||
|
|
||||||
|
logger = logging.getLogger(logger_name)
|
||||||
|
logger.setLevel(logging.INFO)
|
||||||
|
|
||||||
|
opensearch_host = os.getenv("VECTOR_STORE_ENDPOINT", "localhost")
|
||||||
|
opensearch_port = os.getenv("VECTOR_STORE_PORT", "9200")
|
||||||
|
|
||||||
|
handler = OpenSearchHandler(
|
||||||
|
index_name="rag-logs",
|
||||||
|
hosts=[f"http://{opensearch_host}:{opensearch_port}"],
|
||||||
|
http_auth=("admin", "admin"),
|
||||||
|
use_ssl=os.getenv("OPENSEARCH_USE_SSL", "False").lower()
|
||||||
|
in ["true", "1", "yes", "y"],
|
||||||
|
verify_certs=False,
|
||||||
|
connection_class=RequestsHttpConnection,
|
||||||
|
)
|
||||||
|
|
||||||
|
logger.addHandler(handler)
|
||||||
|
|
||||||
|
return logger
|
||||||
|
|
||||||
|
def _roll_out_json(self, original_json):
|
||||||
|
"""Roll out JSON with lists into individual JSONs with additional attributes.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
original_json: The original JSON with lists in 'sources' and 'passages' attributes.
|
||||||
|
Return
|
||||||
|
rolled_out_jsons: List of rolled out JSONs.
|
||||||
|
"""
|
||||||
|
|
||||||
|
rolled_out_jsons = []
|
||||||
|
|
||||||
|
# Iterate through each item in the 'sources' and 'passages' lists
|
||||||
|
for rank, (source, passage) in enumerate(
|
||||||
|
zip(original_json["sources"], original_json["passages"]), start=1
|
||||||
|
):
|
||||||
|
# Create a new JSON for each pair of source and passage
|
||||||
|
new_json = {
|
||||||
|
"query": original_json["query"],
|
||||||
|
"source": source["source"],
|
||||||
|
"page": source["page"],
|
||||||
|
# NOTE: we need to cut here, otherwise the visualization won't work with this many characters
|
||||||
|
"passage": passage[:250],
|
||||||
|
"answer": original_json["answer"][:250],
|
||||||
|
"language": original_json["language"],
|
||||||
|
"model": original_json["model"],
|
||||||
|
"rank": rank, # Assign the rank
|
||||||
|
}
|
||||||
|
rolled_out_jsons.append(new_json)
|
||||||
|
|
||||||
|
return rolled_out_jsons
|
||||||
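A brief usage sketch of the DashboardLogger defined above, mirroring how it is driven from endpoints/llm.py. The concrete values are illustrative only, and the call needs a reachable OpenSearch instance (VECTOR_STORE_ENDPOINT/VECTOR_STORE_PORT) at close time:

```
from common_packages.dashboard_logging import DashboardLogger

dashboard_logger = DashboardLogger()  # generates a unique "rag-logs-<uuid>" logger name
dashboard_logger.add_information(label="query", value="Was sind die Vorteile von RAG?")
dashboard_logger.add_information(label="question", value="Könntest du die letzte Frage beantworten?")
dashboard_logger.add_information(label="sources", value=[{"source": "doc.pdf", "page": 3}])
dashboard_logger.add_information(label="passages", value=["RAG kombiniert Retrieval und Generierung ..."])
dashboard_logger.add_information(label="answer", value="RAG reduziert Halluzinationen ...")
dashboard_logger.add_information(label="language", value="german")
dashboard_logger.add_information(label="model", value="ollama-llm")

# _roll_out_json expects "query", "sources", "passages", "answer", "language" and "model",
# so all of them must be set before closing; one log entry is written per source/passage pair.
dashboard_logger.close_logging()
```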
rag-chat-backend/src/common_packages/logging.py (new file, 46 additions)
|
|
@@ -0,0 +1,46 @@
|
||||||
|
""" This module defines the logging component."""
|
||||||
|
|
||||||
|
import logging
|
||||||
|
|
||||||
|
|
||||||
|
def create_logger(log_level: str, logger_name: str = "custom_logger"):
|
||||||
|
"""Create a logging based on logger.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
log_level (str): Kind of logging
|
||||||
|
logger_name (str, optional): Name of logger
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
logger: returns logger
|
||||||
|
"""
|
||||||
|
logger = logging.getLogger(logger_name)
|
||||||
|
logger.setLevel(logging.DEBUG) # Set the base logging level to the lowest (DEBUG)
|
||||||
|
|
||||||
|
# If logger already has handlers, don't add a new one
|
||||||
|
if logger.hasHandlers():
|
||||||
|
logger.handlers.clear()
|
||||||
|
|
||||||
|
# Create a console handler and set the level based on the input
|
||||||
|
console_handler = logging.StreamHandler()
|
||||||
|
if log_level == "DEBUG":
|
||||||
|
console_handler.setLevel(logging.DEBUG)
|
||||||
|
elif log_level == "INFO":
|
||||||
|
console_handler.setLevel(logging.INFO)
|
||||||
|
elif log_level == "WARNING":
|
||||||
|
console_handler.setLevel(logging.WARNING)
|
||||||
|
elif log_level == "ERROR":
|
||||||
|
console_handler.setLevel(logging.ERROR)
|
||||||
|
else:
|
||||||
|
raise ValueError("Invalid log level provided")
|
||||||
|
|
||||||
|
# Create a formatter and set it for the console handler
|
||||||
|
formatter = logging.Formatter(
|
||||||
|
"%(asctime)s - %(levelname)s [%(name)s] - %(message)s",
|
||||||
|
datefmt="%Y-%m-%d %H:%M:%S",
|
||||||
|
)
|
||||||
|
console_handler.setFormatter(formatter)
|
||||||
|
|
||||||
|
# Add the console handler to the logger
|
||||||
|
logger.addHandler(console_handler)
|
||||||
|
|
||||||
|
return logger
|
||||||
|
|
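For completeness, the module above is used throughout the backend as follows; this mirrors the calls already present in the other files of this commit:

```
import os
from common_packages import logging

logger = logging.create_logger(
    log_level=os.getenv("LOGGING_LEVEL", "INFO"),
    logger_name=__name__,
)
logger.info("Backend component initialized")
```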
@@ -0,0 +1,185 @@
|
||||||
|
"""Create a connection with relevant operations to OpenSearch"""
|
||||||
|
|
||||||
|
import os
|
||||||
|
|
||||||
|
from dotenv import load_dotenv
|
||||||
|
from opensearchpy import OpenSearch, RequestsHttpConnection, exceptions
|
||||||
|
from sentence_transformers import SentenceTransformer
|
||||||
|
|
||||||
|
from fastapi import HTTPException
|
||||||
|
|
||||||
|
from connector.database_interface.utils import mappings
|
||||||
|
from connector.database_interface.utils.base_search_interface import BaseSearchInterface
|
||||||
|
|
||||||
|
from common_packages import logging
|
||||||
|
|
||||||
|
# load env-vars
|
||||||
|
load_dotenv()
|
||||||
|
|
||||||
|
# instantiate logger
|
||||||
|
logger = logging.create_logger(
|
||||||
|
log_level=os.getenv("LOGGING_LEVEL", "INFO"),
|
||||||
|
logger_name=__name__,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class OpenSearchInterface(BaseSearchInterface):
|
||||||
|
"""Client to interact with OpenSearch Instance"""
|
||||||
|
|
||||||
|
def __init__(self, index_name, embedder_name, embedding_size, language):
|
||||||
|
"""Initialize an OpenSearch interface object.
|
||||||
|
|
||||||
|
The index name is used to create an index space in OpenSearch.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
index_name (str): index name
|
||||||
|
"""
|
||||||
|
|
||||||
|
super().__init__()
|
||||||
|
self.logger_inst = logger
|
||||||
|
self.os_client = OpenSearch(
|
||||||
|
hosts=[
|
||||||
|
{
|
||||||
|
"host": os.getenv("VECTOR_STORE_ENDPOINT"),
|
||||||
|
"port": os.getenv("VECTOR_STORE_PORT"),
|
||||||
|
}
|
||||||
|
],
|
||||||
|
http_auth=("admin", "admin"),
|
||||||
|
use_ssl=os.getenv("OPENSEARCH_USE_SSL", "False").lower()
|
||||||
|
in ["true", "1", "yes", "y"],
|
||||||
|
verify_certs=False,
|
||||||
|
connection_class=RequestsHttpConnection,
|
||||||
|
)
|
||||||
|
self.index_name = index_name
|
||||||
|
self.language = language
|
||||||
|
self.document_store = None
|
||||||
|
self.embedding_size = embedding_size
|
||||||
|
self.distance_type = "l2"
|
||||||
|
self.model = SentenceTransformer(embedder_name)
|
||||||
|
|
||||||
|
self.vector_type = "knn_vector"
|
||||||
|
self.embedding_space_name = "embedding_vector"
|
||||||
|
|
||||||
|
mappings.create_index_with_mapping_passagelevel(
|
||||||
|
index_name=self.index_name,
|
||||||
|
os_client=self.os_client,
|
||||||
|
vector_type=self.vector_type,
|
||||||
|
embedding_size=self.embedding_size,
|
||||||
|
embedding_space_name=self.embedding_space_name,
|
||||||
|
distance_type=self.distance_type,
|
||||||
|
logger=self.logger_inst,
|
||||||
|
)
|
||||||
|
|
||||||
|
self.logger_inst.info(
|
||||||
|
"Mappings created. Loaded Embedding model %s", embedder_name
|
||||||
|
)
|
||||||
|
|
||||||
|
def get_unique_values(self, field_name: str = "source"):
|
||||||
|
"""Retrieve all unique values for a specified field from the OpenSearch index.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
field_name (str): The field for which to retrieve unique values.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
list: List of unique values for the specified field.
|
||||||
|
"""
|
||||||
|
|
||||||
|
try:
|
||||||
|
query = {
|
||||||
|
"size": 0,
|
||||||
|
"aggs": {
|
||||||
|
"unique_values": {
|
||||||
|
"terms": {
|
||||||
|
"field": f"metadata.{field_name}.keyword",
|
||||||
|
"size": 10000, # Number of unique values we expect.
|
||||||
|
}
|
||||||
|
}
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
# Execute the query
|
||||||
|
response = self.os_client.search(index=self.index_name, body=query)
|
||||||
|
|
||||||
|
# Extract the terms from the response
|
||||||
|
values = [
|
||||||
|
bucket["key"]
|
||||||
|
for bucket in response["aggregations"]["unique_values"]["buckets"]
|
||||||
|
]
|
||||||
|
|
||||||
|
return values
|
||||||
|
except Exception as e:
|
||||||
|
self.logger_inst.error(
|
||||||
|
"Error in retrieving unique values for %s: %s",
|
||||||
|
field_name,
|
||||||
|
e,
|
||||||
|
)
|
||||||
|
return []
|
||||||
|
|
||||||
|
def get_date_range(self):
|
||||||
|
"""Retrieve the maximum and minimum dates from the 'date-of-upload' field.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
tuple[str, str]: A tuple containing the 'min_date' and 'max_date' values as strings.
|
||||||
|
"""
|
||||||
|
|
||||||
|
try:
|
||||||
|
query = {
|
||||||
|
"size": 0,
|
||||||
|
"aggs": {
|
||||||
|
"min_date": {"min": {"field": "metadata.date-of-upload"}},
|
||||||
|
"max_date": {"max": {"field": "metadata.date-of-upload"}},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
# Execute the query
|
||||||
|
response = self.os_client.search(index=self.index_name, body=query)
|
||||||
|
|
||||||
|
# Extract the min and max dates from the response
|
||||||
|
min_date = response["aggregations"]["min_date"]["value_as_string"]
|
||||||
|
max_date = response["aggregations"]["max_date"]["value_as_string"]
|
||||||
|
|
||||||
|
return min_date, max_date
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
self.logger_inst.error("Error in retrieving date range: %s", e)
|
||||||
|
return None, None
|
||||||
|
|
||||||
|
def delete_indices_by_document(self, document_id: str):
|
||||||
|
"""Delete all indices belonging to the same document in OpenSearch.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
document_id (str): The unique identifier of the document.
|
||||||
|
"""
|
||||||
|
|
||||||
|
try:
|
||||||
|
query = {"query": {"term": {"metadata.source.keyword": document_id}}}
|
||||||
|
self.os_client.delete_by_query(index=self.index_name, body=query)
|
||||||
|
self.logger_inst.info("Deleted all indices for document")
|
||||||
|
except Exception as e:
|
||||||
|
self.logger_inst.error(
|
||||||
|
"Failed deleting indices for document with error: %s", e
|
||||||
|
)
|
||||||
|
raise e
|
||||||
|
|
||||||
|
def empty_entire_index(self):
|
||||||
|
"""Delete all entries in the used vector db index."""
|
||||||
|
|
||||||
|
try:
|
||||||
|
query = {"query": {"match_all": {}}}
|
||||||
|
response = self.os_client.delete_by_query(index=self.index_name, body=query)
|
||||||
|
|
||||||
|
self.logger_inst.info(
|
||||||
|
"Deleted all %s entries for index: %s",
|
||||||
|
response["deleted"],
|
||||||
|
self.index_name,
|
||||||
|
)
|
||||||
|
|
||||||
|
except exceptions.NotFoundError as error:
|
||||||
|
self.logger_inst.warning("Failed emptying index with error: %s", error)
|
||||||
|
raise HTTPException(status_code=404, detail="Not found") from error
|
||||||
|
|
||||||
|
except exceptions.OpenSearchException as error:
|
||||||
|
self.logger_inst.error("Failed emptying index with error: %s", error)
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=500, detail="Error while deleting from OpenSearch"
|
||||||
|
) from error
|
||||||
|
|
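The interface above only exposes maintenance operations; the actual retrieval happens in IndexSearchComponent, which is not part of this commit excerpt. As a hedged sketch only (standard OpenSearch k-NN query DSL, not code from this repository), a vector search against the `embedding_vector` field could look roughly like this, assuming `os_interface` is an OpenSearchInterface instance as constructed in the endpoint modules:

```
# Sketch: embed a query and run a k-NN search against the index created above.
query_text = "Was sind die Vorteile von RAG?"
query_vector = os_interface.model.encode(query_text).tolist()

knn_query = {
    "size": 5,
    "query": {
        "knn": {
            os_interface.embedding_space_name: {   # "embedding_vector"
                "vector": query_vector,
                "k": 5,
            }
        }
    },
}
response = os_interface.os_client.search(index=os_interface.index_name, body=knn_query)
hits = response["hits"]["hits"]
```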
@@ -0,0 +1,36 @@
|
||||||
|
"""Abstraction class for other vector storage search services and databases to be implemented.
|
||||||
|
The goal is mainly to have a common definition when a new service is added.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from typing import Tuple
|
||||||
|
from abc import ABC, abstractmethod
|
||||||
|
|
||||||
|
|
||||||
|
class BaseSearchInterface(ABC):
|
||||||
|
def __init__(self) -> None:
|
||||||
|
super().__init__()
|
||||||
|
self.model: any
|
||||||
|
self.language: str
|
||||||
|
self.document_store: str
|
||||||
|
self.embedding_size: int
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
def get_unique_values(self, field_name: str) -> list:
|
||||||
|
"""A function to retrieve all unique values for a specified field from the underlying search index.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
field_name (str): The field for which to retrieve unique values.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
list: List of unique values for the specified field.
|
||||||
|
"""
|
||||||
|
pass
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
def get_date_range(self) -> Tuple[str, str]:
|
||||||
|
"""Retrieve the maximum and minimum dates of the database entries.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
tuple[str, str]: A tuple containing the 'min_date' and 'max_date' values as strings.
|
||||||
|
"""
|
||||||
|
pass
|
||||||
|
|
@@ -0,0 +1,53 @@
|
||||||
|
def create_index_with_mapping_passagelevel(
|
||||||
|
index_name,
|
||||||
|
os_client,
|
||||||
|
vector_type,
|
||||||
|
embedding_size,
|
||||||
|
embedding_space_name,
|
||||||
|
distance_type,
|
||||||
|
logger,
|
||||||
|
):
|
||||||
|
settings = {
|
||||||
|
"knn": True,
|
||||||
|
"knn.algo_param.ef_search": 512,
|
||||||
|
"index": {"number_of_shards": 3},
|
||||||
|
}
|
||||||
|
mapping = {
|
||||||
|
"properties": {
|
||||||
|
"pdf_id": {"type": "keyword"},
|
||||||
|
"text": {
|
||||||
|
"properties": {
|
||||||
|
"page_content": {"type": "text"},
|
||||||
|
"metadata": {
|
||||||
|
"type": "object",
|
||||||
|
"properties": {
|
||||||
|
"language": {"type": "keyword"},
|
||||||
|
"date-of-upload": {"type": "date"},
|
||||||
|
"tag": {"type": "keyword"},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"page_number": {"type": "integer"},
|
||||||
|
"filename": {"type": "text"},
|
||||||
|
embedding_space_name: {
|
||||||
|
"type": vector_type,
|
||||||
|
"dimension": embedding_size,
|
||||||
|
"method": {
|
||||||
|
"name": "hnsw",
|
||||||
|
"space_type": "innerproduct",
|
||||||
|
"engine": "faiss",
|
||||||
|
"parameters": {"ef_construction": 512, "m": 48},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
# Create the index with the specified settings and mappings
|
||||||
|
if not os_client.indices.exists(index=index_name):
|
||||||
|
os_client.indices.create(
|
||||||
|
index=index_name, body={"settings": settings, "mappings": mapping}
|
||||||
|
)
|
||||||
|
logger.info(f"Successfully created document indexing space: {index_name}")
|
||||||
|
|
||||||
|
logger.info(f"Successfully created indexing space: {index_name}")
|
||||||
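To make the nested mapping above concrete, a document indexed into this space would roughly follow the shape below. The values are illustrative only; the actual indexing code lives in IndexSearchComponent, which is not shown in this commit excerpt:

```
# Illustrative document shape matching create_index_with_mapping_passagelevel.
example_passage = {
    "pdf_id": "doc-0001",
    "text": {
        "page_content": "RAG kombiniert neuronale Suche mit einem LLM ...",
        "metadata": {
            "language": "german",
            "date-of-upload": "2024-05-01",
            "tag": "handbook",
        },
    },
    "page_number": 3,
    "filename": "rag-handbook.pdf",
    "embedding_vector": [0.01] * 768,  # 768-dim vector from the German bi-encoder
}
# e.g. os_client.index(index="german", body=example_passage)
```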
rag-chat-backend/src/connector/llm/__init__.py (new file, 0 additions)
rag-chat-backend/src/connector/llm/ollama.py (new file, 289 additions)
|
|
@@ -0,0 +1,289 @@
|
||||||
|
"""Ollama LLM Module
|
||||||
|
|
||||||
|
This module provides functionalities to access the Ollama API
|
||||||
|
e.g. for calling the model for inference.
|
||||||
|
"""
|
||||||
|
|
||||||
|
import os
|
||||||
|
import json
|
||||||
|
import requests
|
||||||
|
import string
|
||||||
|
|
||||||
|
from connector.llm.utils.helpers import (
|
||||||
|
preprocess_llama_chat_into_query_instruction,
|
||||||
|
extract_first_query_dict,
|
||||||
|
)
|
||||||
|
# from connector.llm.utils.prompts import GenerativeQAPrompt, GenerativeQAPromptDE
|
||||||
|
from connector.llm.utils.base_llm import BaseLLM
|
||||||
|
from connector.llm.utils.base_prompts import BaseChatPrompts, BaseGenerativePrompts
|
||||||
|
|
||||||
|
from common_packages import logging
|
||||||
|
|
||||||
|
# instantiate logger
|
||||||
|
logger = logging.create_logger(
|
||||||
|
log_level=os.getenv("LOGGING_LEVEL", "INFO"),
|
||||||
|
logger_name=__name__,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Message:
|
||||||
|
# [{"role": "system", "content": "You are a helpful assistant."}, {"role": "user", "content": "Hello! What are some good questions to ask you?"}, {"role": "assistant", "content": "Hello! I am here to help you with any information or guidance you need."}, {"role": "user", "content": "Ok, can you list me the capital cities of all european countries?"}]
|
||||||
|
|
||||||
|
|
||||||
|
class OllamaLLM(BaseLLM):
|
||||||
|
|
||||||
|
def __init__(self, language):
|
||||||
|
# client.api_key = os.getenv("LLM_API_KEY")
|
||||||
|
self.api_key = os.getenv("LLM_API_KEY")
|
||||||
|
self.base_url = os.getenv("LLM_API_ENDPOINT")
|
||||||
|
|
||||||
|
self.modelname = "llama3"
|
||||||
|
# self.modelname = "mistral"
|
||||||
|
self.max_num_tokens = 1850
|
||||||
|
self.language = language
|
||||||
|
|
||||||
|
self.prompt_obj = GenerativeQAPromptDE
|
||||||
|
logger.debug("Innitiating OLLAMA Class")
|
||||||
|
|
||||||
|
def cut_tokens(self, context: str, cut_above=False):
|
||||||
|
return context
|
||||||
|
|
||||||
|
def llm_request(self, prompt: str) -> str:
|
||||||
|
return ""
|
||||||
|
|
||||||
|
def llm_chat_querifier(self, chat: list):
|
||||||
|
"""Creates an request to LLMs based on a prompt. It creates a query based on a
|
||||||
|
chat and especially the last user message.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
chat (list): chat
|
||||||
|
Returns:
|
||||||
|
str: querified last message with consideration of the chat.
|
||||||
|
"""
|
||||||
|
logger.debug("chat in ollama")
|
||||||
|
logger.debug(chat)
|
||||||
|
chat_template = GenerativeChatPromptDE
|
||||||
|
|
||||||
|
prompt = preprocess_llama_chat_into_query_instruction(
|
||||||
|
chat,
|
||||||
|
chat_prompt_template=chat_template.chat_querifier,
|
||||||
|
cut_tokens=self.cut_tokens,
|
||||||
|
lang=self.language,
|
||||||
|
)
|
||||||
|
|
||||||
|
logger.debug("prompt after preprocess_llama_chat_into_query_instruction in ollama")
|
||||||
|
logger.debug(prompt)
|
||||||
|
|
||||||
|
url = f"{self.base_url}" + "/api/generate"
|
||||||
|
headers = {
|
||||||
|
"Content-Type": "application/json",
|
||||||
|
}
|
||||||
|
payload = {
|
||||||
|
"model": self.modelname,
|
||||||
|
"prompt": prompt,
|
||||||
|
"stream": False,
|
||||||
|
}
|
||||||
|
response = requests.post(url, headers=headers, data=json.dumps(payload))
results = response.json()["response"].strip()

# The chat_querifier prompt instructs the model to answer with a {"query": ...} JSON object.
# Try to extract it from the generated text; fall back to the last user message.
try:
query = extract_first_query_dict(results)
return query["query"]
except Exception as e:
logger.debug("There was an error: " + str(e))
return chat[-1]["content"]
||||||
|
|
||||||
|
def llm_chat_request(
|
||||||
|
self, chat: dict, context: str, language: str, stream_response: bool
|
||||||
|
) -> str:
|
||||||
|
"""Creates an request, to the GPT-3.5-Turbo deployed in Azure, based on a prompt and the access key.
|
||||||
|
Args:
|
||||||
|
chat (dict): User chat.
|
||||||
|
context: (str): context which was triggered by the query
|
||||||
|
language (str): string of the selected language
|
||||||
|
stream_response (bool): whether the response should be streamed.
|
||||||
|
Returns:
|
||||||
|
str: returns response.
|
||||||
|
"""
|
||||||
|
logger.debug("OLLAMA llm_chat_request triggered")
|
||||||
|
chat_template = None
|
||||||
|
chat_template = GenerativeChatPromptDE
|
||||||
|
|
||||||
|
logger.debug("chat")
|
||||||
|
logger.debug(chat)
|
||||||
|
|
||||||
|
logger.debug("chat_template.llm_purpose")
|
||||||
|
logger.debug(chat_template.llm_purpose)
|
||||||
|
|
||||||
|
logger.debug("chat_template.context_command")
|
||||||
|
logger.debug(chat_template.context_command)
|
||||||
|
|
||||||
|
logger.debug("chat_template.acknowledgement_command")
|
||||||
|
logger.debug(chat_template.acknowledgement_command)
|
||||||
|
|
||||||
|
printable = set(string.printable)
|
||||||
|
context = ''.join(filter(lambda x: x in printable, context))
|
||||||
|
context = context.replace("'", "").replace("!", "")
|
||||||
|
logger.debug("context")
|
||||||
|
logger.debug(context)
|
||||||
|
|
||||||
|
chat_prompt = [
|
||||||
|
{"role": "system", "content": chat_template.llm_purpose},
|
||||||
|
{"role": "user", "content": f"{chat_template.context_command} {context}"},
|
||||||
|
{"role": "assistant", "content": chat_template.acknowledgement_command},
|
||||||
|
]
|
||||||
|
chat_prompt.extend(chat)
|
||||||
|
|
||||||
|
logger.debug("chat_prompt")
|
||||||
|
logger.debug(chat_prompt)
|
||||||
|
|
||||||
|
url = f"{self.base_url}" + "/v1/chat/completions"
|
||||||
|
headers = {
|
||||||
|
"Content-Type": "application/json",
|
||||||
|
"Authorization": f"Bearer {self.api_key}",
|
||||||
|
}
|
||||||
|
payload = {
|
||||||
|
"model": self.modelname,
|
||||||
|
"messages": chat_prompt,
|
||||||
|
}
|
||||||
|
|
||||||
|
logger.debug("payload")
|
||||||
|
logger.debug(json.dumps(payload, indent=4))
|
||||||
|
|
||||||
|
response = requests.post(url, headers=headers, data=json.dumps(payload))
|
||||||
|
logger.debug("response")
|
||||||
|
logger.debug(response.text)
|
||||||
|
|
||||||
|
try:
|
||||||
|
results = response.json()["choices"][0]["message"]["content"]
|
||||||
|
except Exception as e:
|
||||||
|
logger.debug("There was an error: " + str(e))
|
||||||
|
logger.debug("response")
|
||||||
|
logger.debug(response)
|
||||||
|
error_message = response.json()["error"]["message"]
|
||||||
|
results = f"Something went wrong. Here is the error: {error_message}"
|
||||||
|
|
||||||
|
logger.debug("results")
|
||||||
|
logger.debug(results)
|
||||||
|
|
||||||
|
return results
|
||||||
|
|
||||||
|
|
||||||
|
class GenerativeQAPromptDE(BaseGenerativePrompts):
|
||||||
|
generative_qa_prompt_reasoning = """
|
||||||
|
Kontext: {context}
|
||||||
|
|
||||||
|
Generiere eine informative und detailiierte Antwort auf der gegebenen Frage und des Kontexts.
|
||||||
|
Formuliere deine Antwort immer in deutscher Sprache.
|
||||||
|
Füge alle zusätzliche Erkenntnisse oder Perspektiven ein, die das Verständnis verbessern können.
|
||||||
|
Beantworte die Frage basierend auf den Kontext und begründe warum deine Antwort richtig ist.
|
||||||
|
|
||||||
|
Frage: {query}
|
||||||
|
|
||||||
|
Antwort:
|
||||||
|
"""
|
||||||
|
|
||||||
|
chatbot_prompt = """ Beantworte den folgenden Input basierend dem unten stehenden Beispiel.
|
||||||
|
Antworte immer in deutscher Sprache.
|
||||||
|
Versuche ein Gespräch so menschlich wie möglich zu gestalten.
|
||||||
|
Nutze für die Antwort höchstens zwei Sätze.\n
|
||||||
|
|
||||||
|
User Input:\n
|
||||||
|
Was ist deiner Meinung nach das Interessanteste am Menschen?\n
|
||||||
|
Model Output:\n
|
||||||
|
Das ist eine gute Frage. Nun, ich denke, eine der faszinierendsten Eigenschaften des Menschen ist seine Fähigkeit, etwas zu schaffen und zu erneuern.
|
||||||
|
Das finde ich am interessantesten. \n
|
||||||
|
|
||||||
|
User Input:\n
|
||||||
|
{query}\n
|
||||||
|
Model Output:\n
|
||||||
|
"""
|
||||||
|
|
||||||
|
llama_prompt_template = """
|
||||||
|
<s>[INST] <<SYS>>
|
||||||
|
Sie sind ein hilfreicher, respektvoller und ehrlicher Assistent.
|
||||||
|
Antworten Sie immer so hilfreich wie möglich, während Sie sicher bleiben.
|
||||||
|
Ihre Antworten sollten keinen schädlichen, unethischen, rassistischen, sexistischen, giftigen, gefährlichen oder illegaler Inhalt enthalten.
|
||||||
|
Stellen Sie sicher, dass Ihre Antworten sozial unabhängig und positiv sind.
|
||||||
|
Formuliere deine Antwort immer in deutscher Sprache.
|
||||||
|
|
||||||
|
Wenn eine Frage keinen Sinn ergibt oder nicht faktisch kohärent ist, erklären Sie dies anstatt eine falsche Antwort zu geben.
|
||||||
|
Wenn Sie nicht auf eine Frage antworten können, teilen Sie keine falsche Informationen.
|
||||||
|
<</SYS>>
|
||||||
|
|
||||||
|
{user_prompt} [/INST]
|
||||||
|
"""
|
||||||
|
|
||||||
|
|
||||||
|
class GenerativeChatPromptDE(BaseChatPrompts):
|
||||||
|
llm_purpose = """
|
||||||
|
Du bist ein hilfreicher Assistent, der entwickelt wurde um Fragen auf der Grundlage einer vom Benutzer bereitgestellten Wissensbasis zu beantworten.
|
||||||
|
Integriere zusätzliche Einblicke oder Perspektiven, die das Verständnis des Lesers verbessern können.
|
||||||
|
Verwende den bereitgestellten Kontext, um die Frage zu beantworten, und erläutere so ausführlich wie möglich, warum Du glaubst, dass diese Antwort korrekt ist.
|
||||||
|
Verwende nur diejenigen Dokumente, die die Frage beantworten können.
|
||||||
|
Formuliere deine Antwort immer in deutscher Sprache.
|
||||||
|
Beantworte keine Fragen die nicht zum gegeben Kontext sich beziehen!
|
||||||
|
"""
|
||||||
|
|
||||||
|
context_command = """
|
||||||
|
Hier ist der gegebene Kontext worauf du die Fragen beantworten solltest:
|
||||||
|
"""
|
||||||
|
|
||||||
|
acknowledgement_command = """
|
||||||
|
Danke dass du mir den Kontext bereitgestellt hast! Ich werde dich dabei unterstützen deine Fragen mit Referenzen zu beantworten.
|
||||||
|
"""
|
||||||
|
|
||||||
|
chat_querifier = """
|
||||||
|
Du bist ein hilfreicher Assistent, der dazu entworfen wurde, eine Anfrage basierend auf der letzten Usereingabe und den relevanten Nachrichten innerhalb des Chats zu transformieren, sodass die Anfrage genutzt werden kann, um relevante Informationen aus einer Datenbank abzurufen.
|
||||||
|
Wenn du eine neue User Nachricht erhältst, ist es deine Aufgabe, diese mit dem Kontext der vorherigen Chatnachrichten zu einem einzigen, umfassenden Query zu synthetisieren.
|
||||||
|
Dieser Query sollte alle relevanten Informationen enthalten, die für eine Abfrage mit einem neuralen Suchansatz benötigt werden.
|
||||||
|
Stelle sicher, dass deine Ausgabe als JSON-Objekt formatiert ist, das die transformierte Anfrage enthält.
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
## Erstes Beispiel
|
||||||
|
|
||||||
|
User: Wer ist Lady Gaga?
|
||||||
|
|
||||||
|
Assistent: Lady Gaga ist eine Sängerin.
|
||||||
|
|
||||||
|
User: Wie alt ist sie?
|
||||||
|
|
||||||
|
Assistent: Sie ist Mitte 20.
|
||||||
|
|
||||||
|
Neue User Nachricht: Und woher kommt sie? Bitte formatiere es in ein JSON.
|
||||||
|
|
||||||
|
## Output
|
||||||
|
{"query": "Woher kommt Lady Gaga?"}
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
## Zweites Beispiel
|
||||||
|
|
||||||
|
User: Erstelle mir Beispielanfragen.
|
||||||
|
|
||||||
|
Assistent: Sicher. Einige Fragen basierend auf dem Wissen sind:
|
||||||
|
|
||||||
|
Was ist ein RAG?
|
||||||
|
Was ist ein LLM und was bedeutet Transformers?
|
||||||
|
Wann wurden Transformers erfunden?
|
||||||
|
Was sind die Vorteile von RAG?
|
||||||
|
|
||||||
|
Neue User Nachricht: Großartig! Könntest du die letzte Frage beantworten?
|
||||||
|
|
||||||
|
## Output
|
||||||
|
{"query": "Was sind die Vorteile von RAG?"}
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
## Drittes Beispiel
|
||||||
|
|
||||||
|
User: Wie nennt man den Ansatz, der neuronale Suche und LLMs verwendet, um Antworten auf einem Wissenskorpus zu generieren?
|
||||||
|
|
||||||
|
Assistent: Dieser Ansatz wird Retrieval Augmented Generation oder kurz RAG genannt.
|
||||||
|
|
||||||
|
Neue User Nachricht: Ah, ich verstehe! Könntest du mir die Vorteile davon in Aufzählungspunkten auflisten?
|
||||||
|
|
||||||
|
## Output
|
||||||
|
{"query": "Was sind die Vorteile der Retrieval Augmented Generation (RAG)?"}
|
||||||
|
"""
|
||||||
rag-chat-backend/src/connector/llm/utils/__init__.py (new file, 0 additions)
rag-chat-backend/src/connector/llm/utils/base_llm.py (new file, 79 additions)
|
|
@@ -0,0 +1,79 @@
|
||||||
|
"""Abstraction Class for LLM Services.
|
||||||
|
|
||||||
|
Abstraction class for other LLM services to be implemented.
|
||||||
|
The goal is mainly to have a common definition when new LLM services are added.
|
||||||
|
"""
|
||||||
|
|
||||||
|
import os
|
||||||
|
|
||||||
|
from abc import ABC, abstractmethod
|
||||||
|
|
||||||
|
from common_packages import logging
|
||||||
|
|
||||||
|
# instantiate logger
|
||||||
|
logger = logging.create_logger(
|
||||||
|
log_level=os.getenv("LOGGING_LEVEL", "INFO"),
|
||||||
|
logger_name=__name__,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class BaseLLM(ABC):
|
||||||
|
def __init__(self, language: str):
|
||||||
|
self.api_key: str
|
||||||
|
self.max_num_tokens: int
|
||||||
|
self.language: str
|
||||||
|
self.modelname: str
|
||||||
|
self.prompt_obj: any
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
def cut_tokens(self, context: str) -> str:
|
||||||
|
"""A function which cuts a huge string into a smaller string. It is highly recommended to reduce the
|
||||||
|
possibility of exceeding the token limit.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
context (str): Context to be cut.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
str: cut context
|
||||||
|
"""
|
||||||
|
pass
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
def llm_request(self, prompt: str) -> str:
|
||||||
|
"""A simple LLM request allowing to take one prompt and generating an answer on it.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
prompt (str): prompt
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
str: response to the prompt
|
||||||
|
"""
|
||||||
|
pass
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
def llm_chat_request(
|
||||||
|
self, chat: dict, context: str, language: str, stream_response: bool
|
||||||
|
) -> str:
|
||||||
|
"""An requerst on the LLM based on a given chat request.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
chat (dict): A list of the chat in a turn between assistant and user.
|
||||||
|
context (str): Given context based on the search component
|
||||||
|
language (str): relevant language
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
str: A response to the chat request.
|
||||||
|
"""
|
||||||
|
pass
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
def llm_chat_querifier(self, chat: list):
|
||||||
|
"""Creates an request to LLMs based on a prompt. It creates a query based on a
|
||||||
|
chat and especially the last user message.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
chat (list): chat
|
||||||
|
Returns:
|
||||||
|
str: querified last message with consideration of the chat.
|
||||||
|
"""
|
||||||
|
pass
|
||||||
rag-chat-backend/src/connector/llm/utils/base_prompts.py (new file, 32 additions)
|
|
@@ -0,0 +1,32 @@
|
||||||
|
"""Abstraction Class for Prompts.
|
||||||
|
|
||||||
|
Abstraction class for other LLM services to be implemented.
|
||||||
|
The goal is mainly to have a common definition when new LLM services are added.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from abc import ABC
|
||||||
|
|
||||||
|
|
||||||
|
class BaseGenerativePrompts(ABC):
|
||||||
|
generative_qa_prompt_reasoning: str
|
||||||
|
chatbot_prompt: str
|
||||||
|
claim_collector: str
|
||||||
|
fact_checker: str
|
||||||
|
llama_prompt_template: str
|
||||||
|
|
||||||
|
|
||||||
|
class BaseChatPrompts(ABC):
|
||||||
|
llm_purpose: str
|
||||||
|
context_command: str
|
||||||
|
acknowledgement_command: str
|
||||||
|
|
||||||
|
|
||||||
|
class BaseGenerativePromptsObject(ABC):
|
||||||
|
"""Base object for prompt collections"""
|
||||||
|
|
||||||
|
# TODO: let's rename this to 'BaseGenerativePrompts'
|
||||||
|
# after refactoring all connectors
|
||||||
|
|
||||||
|
claim_collector: str
|
||||||
|
fact_checker: str
|
||||||
|
chat_querifier: str
|
||||||
rag-chat-backend/src/connector/llm/utils/helpers.py (new file, 212 additions)
|
|
@@ -0,0 +1,212 @@
|
||||||
|
""" A module with all the helper functions for the llm connectors
|
||||||
|
"""
|
||||||
|
|
||||||
|
import json
|
||||||
|
import re
|
||||||
|
|
||||||
|
|
||||||
|
# NOTE deprecated
|
||||||
|
def prompt_decider(prompt: str):
|
||||||
|
"""This function classifies the prompt and processes it properly.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
prompt (str): prompt
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
tuple: returns both messages
|
||||||
|
"""
|
||||||
|
if prompt.count("User Input:") == 2 and prompt.count("Model Output:") == 2:
|
||||||
|
return process_for_azure_prompt(
|
||||||
|
prompt=prompt, splitter_word="User Input:", rest_word="Model Output:"
|
||||||
|
)
|
||||||
|
|
||||||
|
else:
|
||||||
|
return process_for_azure_prompt(
|
||||||
|
prompt=prompt, splitter_word="Question:", rest_word="Answer:"
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def process_for_azure_prompt(prompt: str, splitter_word: str, rest_word: str):
|
||||||
|
"""This gives you two messages for the azure openai instance.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
prompt (str): The whole prompt
|
||||||
|
splitter_word (str): The word which splits the prompt into two.
|
||||||
|
rest_word (str): The word which is still in the second part of the message.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
tuple: both messages
|
||||||
|
"""
|
||||||
|
index = prompt.rfind(splitter_word)
|
||||||
|
|
||||||
|
part1 = prompt[:index]
|
||||||
|
part2 = prompt[index:]
|
||||||
|
|
||||||
|
part2 = part2.replace("\n", "")
|
||||||
|
part2 = part2.replace(splitter_word, "")
|
||||||
|
part2 = part2.replace(rest_word, "")
|
||||||
|
|
||||||
|
return part1, part2.strip()
|
||||||
|
|
||||||
|
|
||||||
|
def add_llama_system_prompt(user_prompt: str) -> str:
|
||||||
|
"""Wrap given prompt into Llama system prompt
|
||||||
|
|
||||||
|
Args
|
||||||
|
user_prompt (str): Prompt to be wrapped with Llama system prompt
|
||||||
|
|
||||||
|
Returns
|
||||||
|
str: Wrapped prompt
|
||||||
|
"""
|
||||||
|
|
||||||
|
# TODO: Rename user_prompt to 'inner_prompt'
|
||||||
|
|
||||||
|
from langchain.prompts import PromptTemplate
|
||||||
|
from parser import argparser
|
||||||
|
|
||||||
|
LLM = argparser.llm_component
|
||||||
|
|
||||||
|
prompt = PromptTemplate(
|
||||||
|
input_variables=["user_prompt"],
|
||||||
|
template=LLM.prompt_obj.llama_prompt_template,
|
||||||
|
)
|
||||||
|
prompt = prompt.format(user_prompt=user_prompt)
|
||||||
|
|
||||||
|
return prompt.strip()
|
||||||
|
|
||||||
|
|
||||||
|
def preprocess_llama_chat_into_query_instruction(
|
||||||
|
chat: dict, lang: str, chat_prompt_template: str, cut_tokens
|
||||||
|
):
|
||||||
|
"""A function which turns a chat into a llama prompt. This is used to turn a chat into a query.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
chat (dict): chat messages in the format [{"role": ..., "content": ...}, ...]
lang (str): language of the chat; "german" switches the role tags to German
chat_prompt_template (str): instruction template that is prepended to the serialized chat
cut_tokens (callable): function used to trim the prompt to the token limit
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
str: prompt template for query instruction
|
||||||
|
"""
|
||||||
|
user_tag = "User: "
|
||||||
|
assistant_tag = "Assistant: "
|
||||||
|
new_user_tag = "New User message: "
|
||||||
|
start_instruction_tag = "[INST] "
|
||||||
|
end_instruction_tag = " [/INST]"
|
||||||
|
prompt = ""
|
||||||
|
|
||||||
|
if lang == "german":
|
||||||
|
user_tag = "User: "
|
||||||
|
assistant_tag = "Assistent: "
|
||||||
|
new_user_tag = "Neue User Nachricht: "
|
||||||
|
|
||||||
|
# only consider the 5 last messages of the chat
|
||||||
|
chat = chat[-5:]
|
||||||
|
|
||||||
|
for message_idx in range(len(chat)):
|
||||||
|
message = chat[message_idx]
|
||||||
|
if message_idx == len(chat) - 1:
|
||||||
|
add_message = new_user_tag + message["content"]
|
||||||
|
prompt = prompt + add_message + "\n\n"
|
||||||
|
else:
|
||||||
|
if message["role"] == "user":
|
||||||
|
add_message = user_tag + message["content"]
|
||||||
|
prompt = prompt + add_message + "\n\n"
|
||||||
|
else:
|
||||||
|
add_message = assistant_tag + message["content"]
|
||||||
|
prompt = prompt + add_message + "\n\n"
|
||||||
|
|
||||||
|
# sometimes the tokens cut could leave some unnecessary chars above.
|
||||||
|
prompt = cut_tokens(prompt, cut_above=True)
|
||||||
|
user_index = prompt.find(user_tag)
|
||||||
|
if user_index != -1:
|
||||||
|
prompt = prompt[user_index:]
|
||||||
|
|
||||||
|
prompt = start_instruction_tag + prompt
|
||||||
|
prompt = prompt + end_instruction_tag
|
||||||
|
|
||||||
|
return chat_prompt_template + prompt
|
||||||
|
|
||||||
|
|
||||||
|
def preprocess_chat_into_string(
|
||||||
|
chat: dict, lang: str, chat_prompt_template: str, cut_tokens
|
||||||
|
):
|
||||||
|
user_tag = "User: "
|
||||||
|
assistant_tag = "Assistant: "
|
||||||
|
new_user_tag = "New User message: "
|
||||||
|
start_instruction_tag = "\n\n\n## Example\n"
|
||||||
|
end_instruction_tag = "## Output"
|
||||||
|
prompt = ""
|
||||||
|
|
||||||
|
if lang == "german":
|
||||||
|
start_instruction_tag = "\n\n\n## Beispiel\n"
|
||||||
|
end_instruction_tag = "## Output"
|
||||||
|
user_tag = "User: "
|
||||||
|
assistant_tag = "Assistent: "
|
||||||
|
new_user_tag = "Neue User Nachricht: "
|
||||||
|
|
||||||
|
# only consider the 5 last messages of the chat
|
||||||
|
chat = chat[-5:]
|
||||||
|
|
||||||
|
for message_idx in range(len(chat)):
|
||||||
|
message = chat[message_idx]
|
||||||
|
if message_idx == len(chat) - 1:
|
||||||
|
add_message = new_user_tag + message["content"]
|
||||||
|
prompt = prompt + add_message + "\n\n"
|
||||||
|
else:
|
||||||
|
if message["role"] == "user":
|
||||||
|
add_message = user_tag + message["content"]
|
||||||
|
prompt = prompt + add_message + "\n\n"
|
||||||
|
else:
|
||||||
|
add_message = assistant_tag + message["content"]
|
||||||
|
prompt = prompt + add_message + "\n\n"
|
||||||
|
|
||||||
|
# sometimes the tokens cut could leave some unnecessary chars above.
|
||||||
|
prompt = cut_tokens(prompt, cut_above=True)
|
||||||
|
user_index = prompt.find(user_tag)
|
||||||
|
if user_index != -1:
|
||||||
|
prompt = prompt[user_index:]
|
||||||
|
|
||||||
|
prompt = start_instruction_tag + prompt
|
||||||
|
prompt = prompt + end_instruction_tag
|
||||||
|
|
||||||
|
return chat_prompt_template + prompt
|
||||||
|
|
||||||
|
|
||||||
|
def extract_first_query_dict(input_str):
|
||||||
|
pattern = r'\{"query": ".*?"\}'
|
||||||
|
|
||||||
|
match = re.search(pattern, input_str.strip())
|
||||||
|
if match:
|
||||||
|
return json.loads(match.group(0))
|
||||||
|
|
||||||
|
return None
|
||||||
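A small usage example for extract_first_query_dict, matching the {"query": ...} convention that the chat_querifier prompt in ollama.py asks the model to produce (the sample strings are illustrative):

```
# e.g. from another module
from connector.llm.utils.helpers import extract_first_query_dict

raw_model_output = 'Hier ist der Query:\n{"query": "Was sind die Vorteile von RAG?"}'
parsed = extract_first_query_dict(raw_model_output)
print(parsed)            # {'query': 'Was sind die Vorteile von RAG?'}
print(parsed["query"])   # used as the search query downstream

# If no {"query": ...} object is found, the function returns None,
# which is why callers fall back to the last user message.
```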
|
|
||||||
|
|
||||||
|
def format_message(template, **kwargs):
|
||||||
|
"""Formats a message template with the provided keyword arguments.
|
||||||
|
|
||||||
|
This function takes a message template dictionary and formats its "content"
|
||||||
|
field by replacing placeholders with the values provided in the keyword arguments.
|
||||||
|
|
||||||
|
Parameters:
|
||||||
|
template (dict): A dictionary containing the message template. It should
|
||||||
|
have a "content" key with a string value that may contain
|
||||||
|
placeholders for formatting.
|
||||||
|
**kwargs: Arbitrary keyword arguments that will be used to replace
|
||||||
|
placeholders in the "content" string of the template.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
dict: A new dictionary with the same keys as the template, but with the
|
||||||
|
"content" field formatted with the provided keyword arguments.
|
||||||
|
|
||||||
|
Example:
|
||||||
|
>>> template = {"content": "Hello, {name}! Welcome to {place}."}
|
||||||
|
>>> format_message(template, name="Alice", place="Wonderland")
|
||||||
|
{'content': 'Hello, Alice! Welcome to Wonderland.'}
|
||||||
|
"""
|
||||||
|
|
||||||
|
message = template.copy()
|
||||||
|
message["content"] = message["content"].format(**kwargs)
|
||||||
|
return message
|
||||||
rag-chat-backend/src/endpoints/configurations.py (new file, 28 additions)
|
|
@@ -0,0 +1,28 @@
|
||||||
|
"""Endpoints to retrieve information about the backend configuration."""
|
||||||
|
|
||||||
|
import json
|
||||||
|
from fastapi import APIRouter, Response
|
||||||
|
from endpoints.llm import LLM
|
||||||
|
import os
|
||||||
|
|
||||||
|
router = APIRouter()
|
||||||
|
|
||||||
|
|
||||||
|
@router.get("/configs", tags=["configurations"])
|
||||||
|
def get_configs():
|
||||||
|
"""Get configurations of the backend"""
|
||||||
|
|
||||||
|
backend_configs = {
|
||||||
|
"llm": {
|
||||||
|
'language': LLM.language,
|
||||||
|
'max_num_tokens': LLM.max_num_tokens,
|
||||||
|
'modelname': LLM.modelname
|
||||||
|
},
|
||||||
|
'env_vars': {
|
||||||
|
'language': os.getenv('LLM_LANGUAGE'),
|
||||||
|
'llm_option': os.getenv('LLM_OPTION'),
|
||||||
|
'bucket_name': os.getenv('BUCKET_NAME'),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return Response(status_code=200, content=json.dumps(backend_configs))
|
||||||
|
|
@@ -11,13 +11,24 @@ from fastapi import APIRouter, File, UploadFile, Form, HTTPException
|
||||||
from fastapi.responses import StreamingResponse
|
from fastapi.responses import StreamingResponse
|
||||||
|
|
||||||
from core.config import settings
|
from core.config import settings
|
||||||
|
from preprocessing import pdf
|
||||||
|
from neural_search.search_component import IndexSearchComponent
|
||||||
|
from connector.database_interface.opensearch_client import OpenSearchInterface
|
||||||
|
|
||||||
import logging
|
import logging
|
||||||
import datetime as dt
|
import datetime as dt
|
||||||
from datetime import datetime
|
from datetime import datetime
|
||||||
import traceback
|
import traceback
|
||||||
|
|
||||||
|
OS_INTERFACE = OpenSearchInterface(
|
||||||
|
index_name="german",
|
||||||
|
embedder_name="PM-AI/bi-encoder_msmarco_bert-base_german",
|
||||||
|
embedding_size=768,
|
||||||
|
language="german",
|
||||||
|
)
|
||||||
|
SEARCH_COMPONENT = IndexSearchComponent(os_client=OS_INTERFACE)
|
||||||
|
DIRECTORY_NAME = "german"
|
||||||
|
|
||||||
|
|
||||||
# Setup Logging
|
# Setup Logging
|
||||||
logging.basicConfig(
|
logging.basicConfig(
|
||||||
|
|
@@ -148,12 +159,21 @@ def upload_pdf_list(tag: str = Form(...), pdf_files: List[UploadFile] = File(...
|
||||||
for pdf_file in pdf_files:
|
for pdf_file in pdf_files:
|
||||||
logging.info("Processing file: %s", pdf_file.filename)
|
logging.info("Processing file: %s", pdf_file.filename)
|
||||||
pdf_file_name = pdf_file.filename
|
pdf_file_name = pdf_file.filename
|
||||||
# pdf_file_path = f"{settings.BUCKET_FILE_PATH}/{pdf_file.filename}"
|
pdf_file_path = f"{settings.BUCKET_FILE_PATH}/{pdf_file.filename}"
|
||||||
pdf_contents = pdf_file.file.read()
|
pdf_contents = pdf_file.file.read()
|
||||||
|
|
||||||
# process pdf
|
# process pdf
|
||||||
# docs, pages_list = pdf.read_pdf(pdf_bytes=pdf_contents)
|
docs, pages_list = pdf.read_pdf(pdf_bytes=pdf_contents)
|
||||||
try:
|
try:
|
||||||
|
# Upload to Vector Storage
|
||||||
|
SEARCH_COMPONENT.set_indexes(
|
||||||
|
data=docs,
|
||||||
|
sources=pdf_file_path,
|
||||||
|
pages=pages_list,
|
||||||
|
tag=tag,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Upload to Object Storage
|
||||||
object_name = f"{settings.BUCKET_FILE_PATH}/{pdf_file_name}"
|
object_name = f"{settings.BUCKET_FILE_PATH}/{pdf_file_name}"
|
||||||
put_response = minio_client.put_object(
|
put_response = minio_client.put_object(
|
||||||
Bucket=settings.BUCKET, Key=object_name, Body=pdf_contents
|
Bucket=settings.BUCKET, Key=object_name, Body=pdf_contents
|
||||||
|
|
@@ -169,6 +189,10 @@ def upload_pdf_list(tag: str = Form(...), pdf_files: List[UploadFile] = File(...
|
||||||
else:
|
else:
|
||||||
upload_responses.append("failure")
|
upload_responses.append("failure")
|
||||||
|
|
||||||
|
docs, pages_list = pdf.read_pdf(pdf_bytes=pdf_contents)
|
||||||
|
logging.debug(docs)
|
||||||
|
logging.debug(pages_list)
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logging.error("PDF upload failed with error: %s ", e)
|
logging.error("PDF upload failed with error: %s ", e)
|
||||||
logging.error("Stacktrace: " + str(traceback.format_exc()))
|
logging.error("Stacktrace: " + str(traceback.format_exc()))
|
||||||
|
|
|
||||||
rag-chat-backend/src/endpoints/llm.py (new file, 185 additions)
|
|
@@ -0,0 +1,185 @@
|
||||||
|
"""This API incorporates all endpoints for the interaction with the Large Language Model"""
|
||||||
|
|
||||||
|
import json
|
||||||
|
import os
|
||||||
|
import warnings
|
||||||
|
from connector.llm.ollama import OllamaLLM
|
||||||
|
from connector.database_interface.opensearch_client import OpenSearchInterface
|
||||||
|
|
||||||
|
# from parser import argparser
|
||||||
|
# from parser.utils.constants import LLMFlagNames
|
||||||
|
|
||||||
|
from fastapi import APIRouter, HTTPException
|
||||||
|
from fastapi.responses import StreamingResponse
|
||||||
|
|
||||||
|
from common_packages import logging
|
||||||
|
|
||||||
|
from common_packages.dashboard_logging import DashboardLogger
|
||||||
|
from neural_search.search_component import IndexSearchComponent
|
||||||
|
from preprocessing.commons import combine_content, format_results
|
||||||
|
|
||||||
|
# instantiate logger
|
||||||
|
logger = logging.create_logger(
|
||||||
|
log_level=os.getenv("LOGGING_LEVEL", "INFO"),
|
||||||
|
logger_name=__name__,
|
||||||
|
)
|
||||||
|
|
||||||
|
router = APIRouter()
|
||||||
|
|
||||||
|
# parses the chosen neural search component based on flags.
|
||||||
|
LLM = OllamaLLM("german")
|
||||||
|
OS_INTERFACE = OpenSearchInterface(
|
||||||
|
index_name="german",
|
||||||
|
embedder_name="PM-AI/bi-encoder_msmarco_bert-base_german",
|
||||||
|
embedding_size=768,
|
||||||
|
language="german",
|
||||||
|
)
|
||||||
|
SEARCH_COMPONENT = IndexSearchComponent(os_client=OS_INTERFACE)
|
||||||
|
PROMPTS = LLM.prompt_obj
|
||||||
|
llm_component_name = "ollama-llm"
|
||||||
|
selected_language = "german"
|
||||||
|
lite_llm = os.getenv("LITE_LLM")
|
||||||
|
stream_llms = ["fu"]
|
||||||
|
|
||||||
|
|
||||||
|
@router.post("/fact-checking")
|
||||||
|
def factchecking(answer: str, query: str):
|
||||||
|
|
||||||
|
print(query)
|
||||||
|
data = SEARCH_COMPONENT.search(query)
|
||||||
|
context = [item["content"] for item in data]
|
||||||
|
context = "\n$ $ $ $".join([item[0] for item in context])
|
||||||
|
try:
|
||||||
|
context = LLM.cut_tokens(context)
|
||||||
|
k = {"answer": answer}
|
||||||
|
claims = LLM.llm_request(PROMPTS.claim_collector.format(**k))
|
||||||
|
f = {"context": context, "claims": claims}
|
||||||
|
response = LLM.llm_request(PROMPTS.fact_checker.format(**f))
|
||||||
|
return {"result": response}
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logger.error("Calling fact checking failed with error: '%s'", e)
|
||||||
|
raise HTTPException("Calling fact checking failed") from e
|
||||||
|
|
||||||
|
|
||||||
|
@router.get("/chat")
|
||||||
|
def chat(stream_response, messages, tags=None, languages: str = None, start_date: str = None, end_date: str = None):
|
||||||
|
"""Takes a dictionary of the message as an input
|
||||||
|
|
||||||
|
Args:
|
||||||
|
messages (list): messages Must be in the following format: [{"role": "user", "content": "Hello!"}, ...]
|
||||||
|
tags (list[str], optional): List of tags to filter by.
|
||||||
|
languages (list[str], optional): List of languages to filter by.
|
||||||
|
start_date (str, optional): Start date for date-of-upload range.
|
||||||
|
end_date (str, optional): End date for date-of-upload range.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
dict: returns the results of the chat.
|
||||||
|
"""
|
||||||
|
# Parse stringified arrays back into lists
|
||||||
|
logger.info("Entering Router /chat...")
|
||||||
|
if tags is not None:
|
||||||
|
try:
|
||||||
|
tags = json.loads(tags)
|
||||||
|
except json.JSONDecodeError:
|
||||||
|
raise HTTPException(status_code=400, detail="Invalid format for 'tags'")
|
||||||
|
|
||||||
|
if languages is not None:
|
||||||
|
try:
|
||||||
|
languages = json.loads(languages)
|
||||||
|
except json.JSONDecodeError:
|
||||||
|
raise HTTPException(status_code=400, detail="Invalid format for 'languages'")
|
||||||
|
|
||||||
|
if messages is not None:
|
||||||
|
try:
|
||||||
|
messages = json.loads(messages)
|
||||||
|
except json.JSONDecodeError:
|
||||||
|
raise HTTPException(status_code=400, detail="Invalid format for 'messages'")
|
||||||
|
|
||||||
|
# provides you a query based on the whole chat
|
||||||
|
query = LLM.llm_chat_querifier(chat=messages)
|
||||||
|
logger.info("Generated query: '%s'", query)
|
||||||
|
|
||||||
|
# creating a new dashboard logger at the start of the project
|
||||||
|
dashboardlogger = DashboardLogger()
|
||||||
|
dashboardlogger.add_information(label="query", value=query)
|
||||||
|
dashboardlogger.add_information(label="question", value=messages[-1]["content"])
|
||||||
|
# get the context relevant to a query
|
||||||
|
data = SEARCH_COMPONENT.search(
|
||||||
|
query=query,
|
||||||
|
tags=tags,
|
||||||
|
languages=languages,
|
||||||
|
start_date=start_date,
|
||||||
|
end_date=end_date,
|
||||||
|
)
|
||||||
|
context = [item["content"] for item in data]
|
||||||
|
|
||||||
|
sources = list(([item["source"] for item in data]))
|
||||||
|
metadata = [{"source": item.get("source"), "page": item.get("page")} for item in data]
|
||||||
|
|
||||||
|
# log information
|
||||||
|
dashboardlogger.add_information(label="sources", value=metadata)
|
||||||
|
dashboardlogger.add_information(label="passages", value=context)
|
||||||
|
|
||||||
|
data = [[num, char] for num, char in zip(context, sources)]
|
||||||
|
# here we transform the data (which is a list in list structure) and then cuts it and brings it back to the
|
||||||
|
# previous structure.
|
||||||
|
# we use a unique identifier $ $ $ $ to know when to combine back.
|
||||||
|
combined_content = "\n$ $ $ $".join(context)
|
||||||
|
combined_cut_content = LLM.cut_tokens(combined_content)
|
||||||
|
|
||||||
|
resulting_data = combine_content(combined_cut_content, data)
|
||||||
|
|
||||||
|
# Now we bring it in a appropriate structure where we distinguish between documents.
|
||||||
|
formatted_context = format_results(resulting_data)
|
||||||
|
|
||||||
|
# from here on we start the llm requests!
|
||||||
|
# check if chat is allowed for selected llm
|
||||||
|
logger.info("Stream response: %s", stream_response)
|
||||||
|
if stream_response == "true":
|
||||||
|
if lite_llm == "true" or llm_component_name in stream_llms:
|
||||||
|
answer = LLM.llm_chat_request(
|
||||||
|
chat=messages,
|
||||||
|
context=formatted_context,
|
||||||
|
language=selected_language,
|
||||||
|
stream_response=True,
|
||||||
|
)
|
||||||
|
|
||||||
|
dashboardlogger.add_information(label="language", value=selected_language)
|
||||||
|
dashboardlogger.add_information(label="model", value=llm_component_name)
|
||||||
|
|
||||||
|
# set logger info in generator
|
||||||
|
return StreamingResponse(
|
||||||
|
LLM.generate_stream_answer(answer, dashboardlogger),
|
||||||
|
headers={
|
||||||
|
"Access-Control-Expose-Headers": "metadata",
|
||||||
|
"metadata": str(metadata),
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
|
else:
|
||||||
|
raise HTTPException(status_code=405, detail="Stream not allowed for selected LLM.")
|
||||||
|
|
||||||
|
else:
|
||||||
|
logger.info("Send Chat Request without streaming...")
|
||||||
|
answer = LLM.llm_chat_request(
|
||||||
|
chat=messages,
|
||||||
|
context=formatted_context,
|
||||||
|
language=selected_language,
|
||||||
|
stream_response=False,
|
||||||
|
)
|
||||||
|
|
||||||
|
dashboardlogger.add_information(label="answer", value=answer)
|
||||||
|
dashboardlogger.add_information(label="language", value=selected_language)
|
||||||
|
dashboardlogger.add_information(label="model", value=llm_component_name)
|
||||||
|
|
||||||
|
# close logging so it can save the logs
|
||||||
|
dashboardlogger.close_logging()
|
||||||
|
|
||||||
|
# Return LLM chat response
|
||||||
|
return {
|
||||||
|
"answer": answer,
|
||||||
|
"sources": list(set(sources)),
|
||||||
|
"metadata": metadata,
|
||||||
|
"query": query,
|
||||||
|
}
|
||||||
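Since /chat is a GET endpoint whose list parameters arrive as JSON-encoded strings, a client call would look roughly like the sketch below. The host, port, and "/api/v1" prefix are assumptions; the actual value of settings.API_V1_STR is not shown in this commit:

```
import json
import requests

BASE_URL = "http://localhost:8000/api/v1"  # assumed host/port and API_V1_STR prefix

params = {
    "stream_response": "false",
    "messages": json.dumps([
        {"role": "user", "content": "Was sind die Vorteile von RAG?"},
    ]),
    "tags": json.dumps(["handbook"]),       # optional filters
    "languages": json.dumps(["german"]),
}

response = requests.get(f"{BASE_URL}/chat", params=params)
print(response.json()["answer"])
print(response.json()["sources"])
```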
rag-chat-backend/src/endpoints/search.py (new file, 130 additions)
|
|
@@ -0,0 +1,130 @@
|
||||||
|
"""This module incorporates all API endpoints for the search part of the application."""
|
||||||
|
|
||||||
|
import os
|
||||||
|
import json
|
||||||
|
|
||||||
|
from fastapi import APIRouter, HTTPException
|
||||||
|
|
||||||
|
from connector.database_interface.opensearch_client import OpenSearchInterface
|
||||||
|
from neural_search.search_component import IndexSearchComponent
|
||||||
|
|
||||||
|
from common_packages import logging
|
||||||
|
|
||||||
|
logger = logging.create_logger(
|
||||||
|
log_level=os.getenv("LOGGING_LEVEL", "INFO"),
|
||||||
|
logger_name=__name__,
|
||||||
|
)
|
||||||
|
|
||||||
|
router = APIRouter()
|
||||||
|
|
||||||
|
# parses the chosen neural search component based on flags.
|
||||||
|
OS_INTERFACE = OpenSearchInterface(
|
||||||
|
index_name="german",
|
||||||
|
embedder_name="PM-AI/bi-encoder_msmarco_bert-base_german",
|
||||||
|
embedding_size=768,
|
||||||
|
language="german",
|
||||||
|
)
|
||||||
|
SEARCH_COMPONENT = IndexSearchComponent(os_client=OS_INTERFACE)
|
||||||
|
|
||||||
|
|
||||||
|
@router.post(f"/search-engine")
|
||||||
|
def search_engine(
|
||||||
|
query: str,
|
||||||
|
tags: list = None,
|
||||||
|
languages: list = None,
|
||||||
|
start_date: str = None,
|
||||||
|
end_date: str = None,
|
||||||
|
):
|
||||||
|
"""Takes an query and returns all relevant documents for it.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
query (str): The search query string.
|
||||||
|
tags (list[str], optional): List of tags to filter by.
|
||||||
|
languages (list[str], optional): List of languages to filter by.
|
||||||
|
start_date (str, optional): Start date for date-of-upload range.
|
||||||
|
end_date (str, optional): End date for date-of-upload range.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
dict: returns the results of the prompt.
|
||||||
|
"""
|
||||||
|
# Parse stringified arrays back into lists
|
||||||
|
if tags is not None:
|
||||||
|
try:
|
||||||
|
tags = json.loads(tags)
|
||||||
|
except json.JSONDecodeError as e:
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=400, detail="Invalid format for parameter 'tags'"
|
||||||
|
) from e
|
||||||
|
|
||||||
|
if languages is not None:
|
||||||
|
try:
|
||||||
|
languages = json.loads(languages)
|
||||||
|
except json.JSONDecodeError as e:
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=400, detail="Invalid format for parameter 'languages'"
|
||||||
|
) from e
|
||||||
|
logger.info(
|
||||||
|
"Received parameters for search. tags: %s | languages: %s", tags, languages
|
||||||
|
)
|
||||||
|
|
||||||
|
try:
|
||||||
|
search_engine_results = SEARCH_COMPONENT.get_search_engine_results(
|
||||||
|
search_query=query,
|
||||||
|
tags=tags,
|
||||||
|
languages=languages,
|
||||||
|
start_date=start_date,
|
||||||
|
end_date=end_date,
|
||||||
|
)
|
||||||
|
search_engine_results = search_engine_results["aggregations"][
|
||||||
|
"group_by_source"
|
||||||
|
]["buckets"]
|
||||||
|
collected_entries = []
|
||||||
|
|
||||||
|
for entry in search_engine_results:
|
||||||
|
hits = entry["top_entries"]["hits"]["hits"]
|
||||||
|
collected_entries.extend(hits)
|
||||||
|
|
||||||
|
# Create a list to store the merged dictionaries
|
||||||
|
merged_entries = []
|
||||||
|
for entry in collected_entries:
|
||||||
|
# Extract the "_source" dictionary
|
||||||
|
source_dict = entry.pop("_source")
|
||||||
|
|
||||||
|
# Merge the dictionaries
|
||||||
|
result_dict = entry.copy()
|
||||||
|
result_dict.update(source_dict)
|
||||||
|
|
||||||
|
# Append the merged dictionary to the merged_entries list
|
||||||
|
merged_entries.append(result_dict)
|
||||||
|
logger.info("Number of entries found: %s", len(merged_entries))
|
||||||
|
|
||||||
|
return merged_entries
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logger.error("Calling search engine failed with error: '%s'", e)
|
||||||
|
raise HTTPException("Calling search engine failed") from e
|
||||||
|
|
||||||
|
|
||||||
|
@router.get(f"/get-query-possibilities")
|
||||||
|
def get_query_possibilities():
|
||||||
|
"""Returns all possible query attributes
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
dict: dict of attributes and unique values.
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
tag_result = OS_INTERFACE.get_unique_values(field_name="tag")
|
||||||
|
language_result = OS_INTERFACE.get_unique_values(field_name="language")
|
||||||
|
date_result = OS_INTERFACE.get_date_range()
|
||||||
|
|
||||||
|
results = {
|
||||||
|
"tags": tag_result,
|
||||||
|
"languages": language_result,
|
||||||
|
"daterage": date_result, # TODO: is this supposes to mean 'date_range'?
|
||||||
|
}
|
||||||
|
|
||||||
|
return results
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logger.error("Calling OpenSearch failed with error: '%s'", e)
|
||||||
|
raise HTTPException("Calling OpenSearch failed") from e
|
||||||
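Because `tags` and `languages` are decoded with `json.loads`, a client has to send them as JSON-encoded strings. A hedged sketch of calling the endpoint with `httpx`, assuming the filters travel as query-string parameters (which the `json.loads` parsing suggests) and assuming a local base URL with an `/api/v1` prefix; the real prefix comes from `settings.API_V1_STR`:

```python
import json
import httpx

# Assumed base URL and prefix for a local run; adjust to settings.API_V1_STR.
BASE_URL = "http://localhost:8000/api/v1"

response = httpx.post(
    f"{BASE_URL}/search-engine",
    params={
        "query": "maintenance schedule",
        "tags": json.dumps(["manual"]),        # stringified list, parsed server-side
        "languages": json.dumps(["german"]),
        "start_date": "2024-01-01",
        "end_date": "2024-12-31",
    },
)
response.raise_for_status()
print(response.json())
```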
0
rag-chat-backend/src/neural_search/__init__.py
Normal file
283
rag-chat-backend/src/neural_search/search_component.py
Normal file
@@ -0,0 +1,283 @@
"""Module to provide functionalities for neural search"""
|
||||||
|
|
||||||
|
# TODO: is this really necessary?
|
||||||
|
# allows to import neighboring directory
|
||||||
|
import sys
|
||||||
|
|
||||||
|
sys.path.append("..")
|
||||||
|
|
||||||
|
import os
|
||||||
|
|
||||||
|
from tqdm import tqdm
|
||||||
|
from datetime import datetime
|
||||||
|
|
||||||
|
from preprocessing import commons
|
||||||
|
|
||||||
|
from common_packages import logging
|
||||||
|
|
||||||
|
# TODO: should the logger instantiation go into the class constructor?
|
||||||
|
# instantiate logger
|
||||||
|
logger = logging.create_logger(
|
||||||
|
log_level=os.getenv("LOGGING_LEVEL", "INFO"),
|
||||||
|
logger_name=__name__,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class IndexSearchComponent:
|
||||||
|
""" "Index Search Client for Neural Search"""
|
||||||
|
|
||||||
|
def __init__(self, os_client):
|
||||||
|
# define embedder which embeds passages in index store.
|
||||||
|
self.logger_inst = logger
|
||||||
|
self.os_client = os_client
|
||||||
|
self.language = os_client.language
|
||||||
|
self.model = os_client.model
|
||||||
|
self.index_name = os_client.index_name
|
||||||
|
self.set_sources = [] # NOTE by default it is empty, so we search on everything
|
||||||
|
self.vector_type = "knn_vector"
|
||||||
|
self.embedding_space_name = "embedding_vector"
|
||||||
|
|
||||||
|
def get_top_sources(self, search_result):
|
||||||
|
"""Get the top sources regarding a result
|
||||||
|
|
||||||
|
Args:
|
||||||
|
search_result (dict): search results
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
list: list of top sources
|
||||||
|
"""
|
||||||
|
|
||||||
|
results = [
|
||||||
|
{
|
||||||
|
"content": hit["_source"]["content"],
|
||||||
|
"page": hit["_source"]["metadata"]["page"],
|
||||||
|
"score": hit["_score"],
|
||||||
|
"source": hit["_source"]["metadata"]["source"],
|
||||||
|
}
|
||||||
|
for hit in search_result["hits"]["hits"]
|
||||||
|
]
|
||||||
|
sources = []
|
||||||
|
seen_combinations = set() # Set to keep track of source-page combinations
|
||||||
|
if isinstance(results, list) and len(results) > 0:
|
||||||
|
max_score = max(result["score"] for result in results)
|
||||||
|
for i in results:
|
||||||
|
if abs(i["score"] - max_score) <= 3:
|
||||||
|
combination = (
|
||||||
|
i["source"],
|
||||||
|
i["page"],
|
||||||
|
) # Create a tuple of source and page
|
||||||
|
if (
|
||||||
|
combination not in seen_combinations
|
||||||
|
): # Check if this combination is already added
|
||||||
|
sources.append(
|
||||||
|
{
|
||||||
|
"source": i["source"],
|
||||||
|
"page": i["page"],
|
||||||
|
"content": i["content"],
|
||||||
|
}
|
||||||
|
)
|
||||||
|
seen_combinations.add(
|
||||||
|
combination
|
||||||
|
) # Add the combination to the set
|
||||||
|
|
||||||
|
return sources
|
||||||
|
|
||||||
|
def set_metadata(self, sources: list):
|
||||||
|
"""Set the metadata so we can search on certain type of documents.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
sources (list): sources we want to filter on
|
||||||
|
"""
|
||||||
|
self.set_sources = sources
|
||||||
|
|
||||||
|
def _transform_data(self, list_data: list, sources: str, pages: list, tag: str):
|
||||||
|
"""A helper-function to transform the list of passages into a readable input for LangChain.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
list_data (list): list of passages in a readable input for LangChain.
|
||||||
|
metadata (str): a string containing metadata
|
||||||
|
"""
|
||||||
|
transformed_data = []
|
||||||
|
|
||||||
|
# transforms data into a readable input for retriever
|
||||||
|
for i in range(0, len(list_data)):
|
||||||
|
passage = commons.DotDict()
|
||||||
|
passage["content"] = list_data[i]
|
||||||
|
current_date = datetime.now().strftime("%Y-%m-%d")
|
||||||
|
passage["metadata"] = {
|
||||||
|
"source": sources,
|
||||||
|
"date-of-upload": current_date,
|
||||||
|
"language": self.language,
|
||||||
|
"page": pages[i],
|
||||||
|
"tag": tag,
|
||||||
|
}
|
||||||
|
|
||||||
|
transformed_data.append(passage)
|
||||||
|
|
||||||
|
return transformed_data
|
||||||
|
|
||||||
|
def set_indexes(self, data: list, sources, pages, tag):
|
||||||
|
"""Take a list of embedded documents and uploads them to OpenSearch
|
||||||
|
|
||||||
|
Args:
|
||||||
|
list_docs (list): List of documents
|
||||||
|
"""
|
||||||
|
list_docs = self._transform_data(
|
||||||
|
list_data=data, sources=sources, pages=pages, tag=tag
|
||||||
|
)
|
||||||
|
|
||||||
|
self.logger_inst.info("Indexing is running...")
|
||||||
|
# create embedding
|
||||||
|
list_embeddings = self.model.encode(list_docs)
|
||||||
|
self.logger_inst.info("Indexing completed.")
|
||||||
|
# Insert embeddings into OpenSearch iteratively.
|
||||||
|
for idx, (doc, embedding) in enumerate(
|
||||||
|
tqdm(zip(list_docs, list_embeddings), total=len(list_docs))
|
||||||
|
):
|
||||||
|
doc[self.embedding_space_name] = embedding.tolist()
|
||||||
|
self.os_client.os_client.index(index=self.index_name, body=doc)
|
||||||
|
|
||||||
|
self.logger_inst.info(
|
||||||
|
"Successfully indexed all documents into the indexing space: %s",
|
||||||
|
self.index_name,
|
||||||
|
)
|
||||||
|
|
||||||
|
def search(
|
||||||
|
self,
|
||||||
|
query: str,
|
||||||
|
tags=None,
|
||||||
|
languages=None,
|
||||||
|
start_date=None,
|
||||||
|
end_date=None,
|
||||||
|
k=20,
|
||||||
|
):
|
||||||
|
"""Neural Search based on intelligent embeddings, optionally filtered by specific metadata.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
query (str): Query.
|
||||||
|
tags (list[str], optional): List of tags to filter by.
|
||||||
|
languages (list[str], optional): List of languages to filter by.
|
||||||
|
start_date (str, optional): Start date for date-of-upload range.
|
||||||
|
end_date (str, optional): End date for date-of-upload range.
|
||||||
|
k (int, optional): Top k results.
|
||||||
|
"""
|
||||||
|
query_vector = self.model.encode(query)
|
||||||
|
|
||||||
|
# Construct the basic query structure
|
||||||
|
search_body = {
|
||||||
|
"size": k,
|
||||||
|
"query": {
|
||||||
|
"bool": {
|
||||||
|
"must": {
|
||||||
|
"knn": {
|
||||||
|
self.embedding_space_name: {
|
||||||
|
"vector": query_vector.tolist(),
|
||||||
|
"k": k,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"filter": [],
|
||||||
|
}
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
# Add filters if they are provided
|
||||||
|
if tags:
|
||||||
|
search_body["query"]["bool"]["filter"].append(
|
||||||
|
{"terms": {"metadata.tag.keyword": tags}}
|
||||||
|
)
|
||||||
|
|
||||||
|
if languages:
|
||||||
|
search_body["query"]["bool"]["filter"].append(
|
||||||
|
{"terms": {"metadata.language.keyword": languages}}
|
||||||
|
)
|
||||||
|
|
||||||
|
if start_date or end_date:
|
||||||
|
date_range_filter = {"range": {"metadata.date-of-upload": {}}}
|
||||||
|
if start_date:
|
||||||
|
date_range_filter["range"]["metadata.date-of-upload"][
|
||||||
|
"gte"
|
||||||
|
] = start_date
|
||||||
|
if end_date:
|
||||||
|
date_range_filter["range"]["metadata.date-of-upload"]["lte"] = end_date
|
||||||
|
search_body["query"]["bool"]["filter"].append(date_range_filter)
|
||||||
|
|
||||||
|
response = self.os_client.os_client.search(
|
||||||
|
index=self.index_name, body=search_body
|
||||||
|
)
|
||||||
|
# Process the response to extract top passages
|
||||||
|
top_passages = self.get_top_sources(search_result=response)
|
||||||
|
|
||||||
|
return top_passages
|
||||||
|
|
||||||
|
def get_search_engine_results(
|
||||||
|
self,
|
||||||
|
search_query: str,
|
||||||
|
tags=None,
|
||||||
|
languages=None,
|
||||||
|
start_date=None,
|
||||||
|
end_date=None,
|
||||||
|
):
|
||||||
|
"""Execute a custom search with specific query structure and optional filters.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
search_query (str): The search query string.
|
||||||
|
tags (list[str], optional): List of tags to filter by.
|
||||||
|
languages (list[str], optional): List of languages to filter by.
|
||||||
|
start_date (str, optional): Start date for date-of-upload range.
|
||||||
|
end_date (str, optional): End date for date-of-upload range.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
dict: Search results.
|
||||||
|
"""
|
||||||
|
# Basic query structure
|
||||||
|
query_body = {"bool": {"must": [{"match": {"content": search_query}}]}}
|
||||||
|
|
||||||
|
# Add filters if they are provided
|
||||||
|
filter_clauses = []
|
||||||
|
|
||||||
|
if tags:
|
||||||
|
filter_clauses.append({"terms": {"metadata.tag.keyword": tags}})
|
||||||
|
|
||||||
|
if languages:
|
||||||
|
filter_clauses.append({"terms": {"metadata.language.keyword": languages}})
|
||||||
|
|
||||||
|
if start_date or end_date:
|
||||||
|
date_range_filter = {"range": {"metadata.date-of-upload": {}}}
|
||||||
|
if start_date:
|
||||||
|
date_range_filter["range"]["metadata.date-of-upload"][
|
||||||
|
"gte"
|
||||||
|
] = start_date
|
||||||
|
if end_date:
|
||||||
|
date_range_filter["range"]["metadata.date-of-upload"]["lte"] = end_date
|
||||||
|
filter_clauses.append(date_range_filter)
|
||||||
|
|
||||||
|
if filter_clauses:
|
||||||
|
query_body["bool"]["filter"] = filter_clauses
|
||||||
|
|
||||||
|
# Final query with aggregations
|
||||||
|
custom_query_body = {
|
||||||
|
"size": 0,
|
||||||
|
"query": query_body,
|
||||||
|
"aggs": {
|
||||||
|
"group_by_source": {
|
||||||
|
"terms": {"field": "metadata.source.keyword", "size": 100000},
|
||||||
|
"aggs": {
|
||||||
|
"top_entries": {
|
||||||
|
"top_hits": {
|
||||||
|
"size": 3,
|
||||||
|
"sort": [{"_score": {"order": "desc"}}],
|
||||||
|
"_source": {"excludes": ["embedding_vector"]},
|
||||||
|
}
|
||||||
|
}
|
||||||
|
},
|
||||||
|
}
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
# Execute the search query
|
||||||
|
response = self.os_client.os_client.search(
|
||||||
|
index=self.index_name, body=custom_query_body
|
||||||
|
)
|
||||||
|
|
||||||
|
return response
|
||||||
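For orientation, a hedged sketch of wiring this component up and indexing a couple of chunks. The constructor arguments mirror the module-level setup in `endpoints/search.py`, while the chunk data, file name, and tag are made up:

```python
from connector.database_interface.opensearch_client import OpenSearchInterface
from neural_search.search_component import IndexSearchComponent

os_interface = OpenSearchInterface(
    index_name="german",
    embedder_name="PM-AI/bi-encoder_msmarco_bert-base_german",
    embedding_size=768,
    language="german",
)
search_component = IndexSearchComponent(os_client=os_interface)

# Index two toy chunks from one (made-up) document, then run a filtered neural search.
search_component.set_indexes(
    data=["First chunk of text.", "Second chunk of text."],
    sources="example.pdf",
    pages=[1, 2],
    tag="manual",
)
passages = search_component.search(query="first chunk", tags=["manual"], k=5)
print(passages)
```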
0
rag-chat-backend/src/preprocessing/__init__.py
Normal file
78
rag-chat-backend/src/preprocessing/commons.py
Normal file
@@ -0,0 +1,78 @@
"""Module for all kinds of preprocessing utility units"""
|
||||||
|
|
||||||
|
import os
|
||||||
|
|
||||||
|
from common_packages import logging
|
||||||
|
|
||||||
|
# instantiate logger
|
||||||
|
logger = logging.create_logger(
|
||||||
|
log_level=os.getenv("LOGGING_LEVEL", "INFO"),
|
||||||
|
logger_name=__name__,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class DotDict(dict):
|
||||||
|
"""Allows to be dict called with a dot notation.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
dict (_type_): _description_
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __getattr__(self, attr):
|
||||||
|
try:
|
||||||
|
return self[attr]
|
||||||
|
except KeyError as e:
|
||||||
|
logger.error("'DotDict' object has no attribute '%s'", attr)
|
||||||
|
raise AttributeError(f"'DotDict' object has no attribute '{attr}'") from e
|
||||||
|
|
||||||
|
|
||||||
|
def del_local_object(file_path: str):
|
||||||
|
"""Deletes an object inside a local directory
|
||||||
|
|
||||||
|
Args:
|
||||||
|
file_path (str): File to the object to be deleted.
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
os.remove(file_path)
|
||||||
|
logger.info("File '%s' deleted successfully", file_path)
|
||||||
|
except OSError as e:
|
||||||
|
logger.error("Deleting local file failed with error: %s", e)
|
||||||
|
|
||||||
|
|
||||||
|
def combine_content(chunks, data):
|
||||||
|
transformed_data = []
|
||||||
|
chunks = chunks.split("\n$ $ $ $")
|
||||||
|
start = 0
|
||||||
|
for chunk in chunks:
|
||||||
|
end = start + len(chunk)
|
||||||
|
|
||||||
|
# Find the original item that corresponds to this chunk
|
||||||
|
original_item = None
|
||||||
|
for item in data:
|
||||||
|
if chunk.startswith(item[0]):
|
||||||
|
original_item = item
|
||||||
|
break
|
||||||
|
|
||||||
|
if original_item:
|
||||||
|
content = chunk
|
||||||
|
source = original_item[1]
|
||||||
|
|
||||||
|
transformed_data.append([content, source])
|
||||||
|
start = end
|
||||||
|
|
||||||
|
return transformed_data
|
||||||
|
|
||||||
|
def format_results(documents):
|
||||||
|
all_sources = list(set([i[1] for i in documents]))
|
||||||
|
all_sources_content = []
|
||||||
|
|
||||||
|
for source in all_sources:
|
||||||
|
get_all_content_from_one_source = []
|
||||||
|
for document in documents:
|
||||||
|
if document[1] == source:
|
||||||
|
get_all_content_from_one_source.append(document[0])
|
||||||
|
|
||||||
|
combined = "\n".join(get_all_content_from_one_source)
|
||||||
|
all_sources_content.append(f"Document [{source}]: {combined}\n\n\n\n")
|
||||||
|
|
||||||
|
return "\n".join(all_sources_content)
|
||||||
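A small sketch of what `format_results` produces for two made-up sources (the order of the `Document [...]` blocks can vary because the sources pass through a `set`):

```python
documents = [
    ["Chunk one from A", "a.pdf"],
    ["Chunk two from A", "a.pdf"],
    ["Chunk one from B", "b.pdf"],
]

print(format_results(documents))
# Possible output (source order may differ, blank lines separate documents):
# Document [a.pdf]: Chunk one from A
# Chunk two from A
#
#
#
#
# Document [b.pdf]: Chunk one from B
```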
97
rag-chat-backend/src/preprocessing/pdf.py
Normal file
@@ -0,0 +1,97 @@
"""Module for tools to process PDF documents"""
|
||||||
|
|
||||||
|
import os
|
||||||
|
import sys
|
||||||
|
import io
|
||||||
|
import PyPDF2
|
||||||
|
|
||||||
|
from langchain.text_splitter import RecursiveCharacterTextSplitter
|
||||||
|
|
||||||
|
import logging
|
||||||
|
import datetime as dt
|
||||||
|
from datetime import datetime
|
||||||
|
import traceback
|
||||||
|
|
||||||
|
# Setup Logging
|
||||||
|
logging.basicConfig(
|
||||||
|
level=logging.DEBUG,
|
||||||
|
# level=logging.INFO,
|
||||||
|
format="Start: " + str(dt.datetime.now()).replace(" ", "_") + " | %(asctime)s [%(levelname)s] %(message)s",
|
||||||
|
handlers=[
|
||||||
|
logging.FileHandler("/<path>-_" + str(datetime.today().strftime('%Y-%m-%d')) + "_-_debug.log"),
|
||||||
|
logging.StreamHandler(sys.stdout)
|
||||||
|
]
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class PDFProcessor:
|
||||||
|
"""Class to handle the PDF processing"""
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _calculate_spaces_length(chunk: str) -> int:
|
||||||
|
"""Calculates the length based on spaces after splitting the chunk.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
chunk (str): Given text chunk.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
int: Length based on spaces.
|
||||||
|
"""
|
||||||
|
return len(chunk.split())
|
||||||
|
|
||||||
|
def chunk_text(self, text_pages: list) -> list:
|
||||||
|
"""Preprocess text from a PDF file, keeping track of page numbers.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
text_pages (list): List of tuples with text and corresponding page numbers.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
list: List containing tuples of the preprocessed text and their page numbers from the PDF.
|
||||||
|
"""
|
||||||
|
splitter = RecursiveCharacterTextSplitter(
|
||||||
|
chunk_size=100,
|
||||||
|
chunk_overlap=10,
|
||||||
|
length_function=self._calculate_spaces_length,
|
||||||
|
)
|
||||||
|
|
||||||
|
chunks_with_pages = []
|
||||||
|
for text, page_number in text_pages:
|
||||||
|
chunks = splitter.create_documents([text])
|
||||||
|
for chunk in chunks:
|
||||||
|
chunks_with_pages.append((chunk.page_content, page_number))
|
||||||
|
|
||||||
|
return chunks_with_pages
|
||||||
|
|
||||||
|
|
||||||
|
def read_pdf(pdf_bytes: io.BytesIO) -> tuple:
|
||||||
|
"""Reads a pdf and returns two lists: one of text chunks and another
|
||||||
|
of their respective page numbers.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
pdf_bytes (io.BytesIO): PDF file in BytesIO format.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
tuple of lists: (List of chunked text, List of corresponding page numbers).
|
||||||
|
"""
|
||||||
|
logging.info("Reading PDF document")
|
||||||
|
pdf_reader = PyPDF2.PdfReader(io.BytesIO(pdf_bytes))
|
||||||
|
|
||||||
|
num_pages = len(pdf_reader.pages)
|
||||||
|
logging.info("Read PDF document with '%s' pages", num_pages)
|
||||||
|
|
||||||
|
text_pages = []
|
||||||
|
for i in range(num_pages):
|
||||||
|
page = pdf_reader.pages[i]
|
||||||
|
text = page.extract_text()
|
||||||
|
if text:
|
||||||
|
text_pages.append((text, i + 1))
|
||||||
|
|
||||||
|
logging.info("Processing PDF content")
|
||||||
|
pdf_processor = PDFProcessor()
|
||||||
|
processed_chunks = pdf_processor.chunk_text(text_pages)
|
||||||
|
|
||||||
|
chunks = [chunk for chunk, _ in processed_chunks]
|
||||||
|
pages = [page for _, page in processed_chunks]
|
||||||
|
logging.info("PDF processed. Number of chunks: %s", len(chunks))
|
||||||
|
|
||||||
|
return chunks, pages
|
||||||
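A hedged sketch of driving `read_pdf` from a local file and handing the result to the indexing component from the earlier sketch (the file name and tag are made up):

```python
with open("example.pdf", "rb") as pdf_file:
    pdf_bytes = pdf_file.read()

chunks, pages = read_pdf(pdf_bytes)

# Each chunk is capped at roughly 100 whitespace-separated tokens (overlap 10),
# and pages[i] is the 1-based page number the chunk came from.
search_component.set_indexes(data=chunks, sources="example.pdf", pages=pages, tag="manual")
```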
0
rag-chat-backend/src/search/__init__.py
Normal file
237
rag-chat-backend/src/search/opensearch_client.py
Normal file
@@ -0,0 +1,237 @@
"""Create a connection with relevant operations to OpenSearch"""
|
||||||
|
|
||||||
|
import os
|
||||||
|
|
||||||
|
from dotenv import load_dotenv
|
||||||
|
from opensearchpy import OpenSearch, RequestsHttpConnection, exceptions
|
||||||
|
from sentence_transformers import SentenceTransformer
|
||||||
|
|
||||||
|
from fastapi import HTTPException
|
||||||
|
|
||||||
|
from connector.database_interface.utils import mappings
|
||||||
|
|
||||||
|
from common_packages import logging
|
||||||
|
|
||||||
|
# load env-vars
|
||||||
|
load_dotenv()
|
||||||
|
|
||||||
|
# instantiate logger
|
||||||
|
logger = logging.create_logger(
|
||||||
|
log_level=os.getenv("LOGGING_LEVEL", "INFO"),
|
||||||
|
logger_name=__name__,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class OpenSearchInterface():
|
||||||
|
"""Client to interact with OpenSearch Instance"""
|
||||||
|
|
||||||
|
def __init__(self, index_name, embedder_name, embedding_size, language):
|
||||||
|
"""Initialize an OpenSearch interface object.
|
||||||
|
|
||||||
|
Use index name needed to create an index space in OS.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
index_name (str): index name
|
||||||
|
"""
|
||||||
|
|
||||||
|
# super().__init__()
|
||||||
|
self.logger_inst = logger
|
||||||
|
self.os_client = OpenSearch(
|
||||||
|
hosts=[
|
||||||
|
{
|
||||||
|
"host": os.getenv("VECTOR_STORE_ENDPOINT"),
|
||||||
|
"port": os.getenv("VECTOR_STORE_PORT"),
|
||||||
|
}
|
||||||
|
],
|
||||||
|
http_auth=("admin", "admin"),
|
||||||
|
use_ssl=os.getenv("OPENSEARCH_USE_SSL", "False").lower() in ["true", "1", "yes", "y"],
|
||||||
|
verify_certs=False,
|
||||||
|
connection_class=RequestsHttpConnection,
|
||||||
|
)
|
||||||
|
self.index_name = index_name
|
||||||
|
self.language = language
|
||||||
|
self.document_store = None
|
||||||
|
self.embedding_size = embedding_size
|
||||||
|
self.distance_type = "l2"
|
||||||
|
self.model = SentenceTransformer(embedder_name)
|
||||||
|
|
||||||
|
self.vector_type = "knn_vector"
|
||||||
|
self.embedding_space_name = "embedding_vector"
|
||||||
|
|
||||||
|
mappings.create_index_with_mapping_passagelevel(
|
||||||
|
index_name=self.index_name,
|
||||||
|
os_client=self.os_client,
|
||||||
|
vector_type=self.vector_type,
|
||||||
|
embedding_size=self.embedding_size,
|
||||||
|
embedding_space_name=self.embedding_space_name,
|
||||||
|
distance_type=self.distance_type,
|
||||||
|
logger=self.logger_inst,
|
||||||
|
)
|
||||||
|
|
||||||
|
self.logger_inst.info(
|
||||||
|
"Mappings created. Loaded Embedding model %s", embedder_name
|
||||||
|
)
|
||||||
|
|
||||||
|
def create_index_with_mapping_passagelevel(
|
||||||
|
index_name,
|
||||||
|
os_client,
|
||||||
|
vector_type,
|
||||||
|
embedding_size,
|
||||||
|
embedding_space_name,
|
||||||
|
distance_type,
|
||||||
|
logger,
|
||||||
|
):
|
||||||
|
settings = {
|
||||||
|
"knn": True,
|
||||||
|
"knn.algo_param.ef_search": 512,
|
||||||
|
"index": {"number_of_shards": 3},
|
||||||
|
}
|
||||||
|
mapping = {
|
||||||
|
"properties": {
|
||||||
|
"pdf_id": {"type": "keyword"},
|
||||||
|
"text": {
|
||||||
|
"properties": {
|
||||||
|
"page_content": {"type": "text"},
|
||||||
|
"metadata": {
|
||||||
|
"type": "object",
|
||||||
|
"properties": {
|
||||||
|
"language": {"type": "keyword"},
|
||||||
|
"date-of-upload": {"type": "date"},
|
||||||
|
"tag": {"type": "keyword"},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"page_number": {"type": "integer"},
|
||||||
|
"filename": {"type": "text"},
|
||||||
|
embedding_space_name: {
|
||||||
|
"type": vector_type,
|
||||||
|
"dimension": embedding_size,
|
||||||
|
"method": {
|
||||||
|
"name": "hnsw",
|
||||||
|
"space_type": "innerproduct",
|
||||||
|
"engine": "faiss",
|
||||||
|
"parameters": {"ef_construction": 512, "m": 48},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
# Create the index with the specified settings and mappings
|
||||||
|
if not os_client.indices.exists(index=index_name):
|
||||||
|
os_client.indices.create(
|
||||||
|
index=index_name, body={"settings": settings, "mappings": mapping}
|
||||||
|
)
|
||||||
|
logger.info(f"Successfully created document indexing space: {index_name}")
|
||||||
|
|
||||||
|
logger.info(f"Successfully created indexing space: {index_name}")
|
||||||
|
|
||||||
|
def get_unique_values(self, field_name: str = "source"):
|
||||||
|
"""Retrieve all unique values for a specified field from the OpenSearch index.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
field_name (str): The field for which to retrieve unique values.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
list: List of unique values for the specified field.
|
||||||
|
"""
|
||||||
|
|
||||||
|
try:
|
||||||
|
query = {
|
||||||
|
"size": 0,
|
||||||
|
"aggs": {
|
||||||
|
"unique_values": {
|
||||||
|
"terms": {
|
||||||
|
"field": f"metadata.{field_name}.keyword",
|
||||||
|
"size": 10000, # Number of unique values we expect.
|
||||||
|
}
|
||||||
|
}
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
# Execute the query
|
||||||
|
response = self.os_client.search(index=self.index_name, body=query)
|
||||||
|
|
||||||
|
# Extract the terms from the response
|
||||||
|
values = [
|
||||||
|
bucket["key"]
|
||||||
|
for bucket in response["aggregations"]["unique_values"]["buckets"]
|
||||||
|
]
|
||||||
|
|
||||||
|
return values
|
||||||
|
except Exception as e:
|
||||||
|
self.logger_inst.error(
|
||||||
|
"Error in retrieving unique values for %s: %s",
|
||||||
|
field_name,
|
||||||
|
e,
|
||||||
|
)
|
||||||
|
return []
|
||||||
|
|
||||||
|
def get_date_range(self):
|
||||||
|
"""Retrieve the maximum and minimum dates from the 'date-of-upload' field.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
tuple[str, str]: A tuple containing the 'min_date' and 'max_date' values as strings.
|
||||||
|
"""
|
||||||
|
|
||||||
|
try:
|
||||||
|
query = {
|
||||||
|
"size": 0,
|
||||||
|
"aggs": {
|
||||||
|
"min_date": {"min": {"field": "metadata.date-of-upload"}},
|
||||||
|
"max_date": {"max": {"field": "metadata.date-of-upload"}},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
# Execute the query
|
||||||
|
response = self.os_client.search(index=self.index_name, body=query)
|
||||||
|
|
||||||
|
# Extract the min and max dates from the response
|
||||||
|
min_date = response["aggregations"]["min_date"]["value_as_string"]
|
||||||
|
max_date = response["aggregations"]["max_date"]["value_as_string"]
|
||||||
|
|
||||||
|
return min_date, max_date
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
self.logger_inst.error("Error in retrieving date range: %s", e)
|
||||||
|
return None, None
|
||||||
|
|
||||||
|
def delete_indices_by_document(self, document_id: str):
|
||||||
|
"""Delete all indices belonging to the same document in OpenSearch.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
document_id (str): The unique identifier of the document.
|
||||||
|
"""
|
||||||
|
|
||||||
|
try:
|
||||||
|
query = {"query": {"term": {"metadata.source.keyword": document_id}}}
|
||||||
|
self.os_client.delete_by_query(index=self.index_name, body=query)
|
||||||
|
self.logger_inst.info("Deleted all indices for document")
|
||||||
|
except Exception as e:
|
||||||
|
self.logger_inst.error(
|
||||||
|
"Failed deleting indices for document with error: %s", e
|
||||||
|
)
|
||||||
|
raise e
|
||||||
|
|
||||||
|
def empty_entire_index(self):
|
||||||
|
"""Delete all entries in the used vector db index."""
|
||||||
|
|
||||||
|
try:
|
||||||
|
query = {"query": {"match_all": {}}}
|
||||||
|
response = self.os_client.delete_by_query(index=self.index_name, body=query)
|
||||||
|
|
||||||
|
self.logger_inst.info(
|
||||||
|
"Deleted all %s entries for index: %s",
|
||||||
|
response["deleted"],
|
||||||
|
self.index_name,
|
||||||
|
)
|
||||||
|
|
||||||
|
except exceptions.NotFoundError as error:
|
||||||
|
self.logger_inst.warning("Failed emptying index with error: %s", error)
|
||||||
|
raise HTTPException(status_code=404, detail="Not found") from error
|
||||||
|
|
||||||
|
except exceptions.OpenSearchException as error:
|
||||||
|
self.logger_inst.error("Failed emptying index with error: %s", error)
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=500, detail="Error while deleting from OpenSearch"
|
||||||
|
) from error
|
||||||
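The client takes its connection settings from environment variables loaded via `load_dotenv`. A hedged sketch of the variables it reads and of querying the metadata helpers; the values are illustrative, and the index/embedder arguments mirror `endpoints/search.py`:

```python
# Expected environment variables (e.g. in a local .env file); values are illustrative:
#   VECTOR_STORE_ENDPOINT=localhost
#   VECTOR_STORE_PORT=9200
#   OPENSEARCH_USE_SSL=False
#   LOGGING_LEVEL=INFO

client = OpenSearchInterface(
    index_name="german",
    embedder_name="PM-AI/bi-encoder_msmarco_bert-base_german",
    embedding_size=768,
    language="german",
)

print(client.get_unique_values(field_name="tag"))   # e.g. ["manual", "report"]
print(client.get_date_range())                      # e.g. ("2024-01-01", "2024-03-31")
```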