basic backend

This commit is contained in:
Niklas Mueller 2024-06-23 12:57:53 +02:00
parent 16e5004228
commit 89ec0476ca
29 changed files with 2125 additions and 13 deletions

View file

@ -1,13 +1,11 @@
"""FastAPI Backend"""
import uvicorn
import os
from fastapi import FastAPI
from fastapi.middleware.cors import CORSMiddleware
from dotenv import load_dotenv
from endpoints import files
from endpoints import files, llm, search, configurations
from core.config import settings
@ -25,12 +23,9 @@ app.add_middleware(
)
app.include_router(files.router, prefix=settings.API_V1_STR) # , tags=["files"]
print('OPENSEARCH_USE_SSL')
print(os.getenv('OPENSEARCH_USE_SSL'))
print('settings.API_V1_STR')
print(settings.API_V1_STR)
app.include_router(llm.router, prefix=settings.API_V1_STR, tags=["llm"]) # , tags=["llm"]
app.include_router(search.router, prefix=settings.API_V1_STR, tags=["search"]) # , tags=["search"]
app.include_router(configurations.router, prefix=settings.API_V1_STR, tags=["config"])
if __name__ == "__main__":

View file

@ -0,0 +1,102 @@
""" This module defines the logging component for the LLM interactions.
The data is displayed in the OpenSearch Dashboard.
"""
import os
import uuid
import logging
from opensearch_logger import OpenSearchHandler
from opensearchpy import RequestsHttpConnection
class DashboardLogger:
"""Logger instance for OpenSearch dashboard"""
def __init__(self, logger_name=None):
if logger_name is None:
# Generate a unique logger name
logger_name = "rag-logs-" + str(uuid.uuid4())
self.logger_instance = self._create_os_logger(logger_name=logger_name)
self.logger_info = {}
def add_information(self, label: str, value):
"""Here you can add information to a given process.
Each information consists of a label and its given value.
Args:
label (str): label
value (any): value
"""
self.logger_info[label] = value
def close_logging(self):
"""Close logging in the final process."""
self.logger_info = self._roll_out_json(original_json=self.logger_info)
if isinstance(self.logger_info, list):
for logger_json in self.logger_info:
self.logger_instance.info("Logging information", extra=logger_json)
else:
self.logger_instance.info("Logging information", extra=self.logger_info)
def _create_os_logger(self, logger_name):
"""Create a logger which logs on OpenSearch.
Args:
logger_name (str): use some logger name.
Returns:
logger: OpenSearch logger instance.
"""
logger = logging.getLogger(logger_name)
logger.setLevel(logging.INFO)
opensearch_host = os.getenv("VECTOR_STORE_ENDPOINT", "localhost")
opensearch_port = os.getenv("VECTOR_STORE_PORT", "9200")
handler = OpenSearchHandler(
index_name="rag-logs",
hosts=[f"http://{opensearch_host}:{opensearch_port}"],
http_auth=("admin", "admin"),
use_ssl=os.getenv("OPENSEARCH_USE_SSL", "False").lower()
in ["true", "1", "yes", "y"],
verify_certs=False,
connection_class=RequestsHttpConnection,
)
logger.addHandler(handler)
return logger
def _roll_out_json(self, original_json):
"""Roll out JSON with lists into individual JSONs with additional attributes.
Args:
original_json: The original JSON with lists in 'sources' and 'passages' attributes.
Returns:
rolled_out_jsons: List of rolled out JSONs.
"""
rolled_out_jsons = []
# Iterate through each item in the 'sources' and 'passages' lists
for rank, (source, passage) in enumerate(
zip(original_json["sources"], original_json["passages"]), start=1
):
# Create a new JSON for each pair of source and passage
new_json = {
"query": original_json["query"],
"source": source["source"],
"page": source["page"],
# NOTE: we need to truncate, otherwise the visualization won't work with this many characters
"passage": passage[:250],
"answer": original_json["answer"][:250],
"language": original_json["language"],
"model": original_json["model"],
"rank": rank, # Assign the rank
}
rolled_out_jsons.append(new_json)
return rolled_out_jsons
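
A minimal usage sketch for the logger above. The label names mirror what _roll_out_json expects, the sample values are made up for illustration, and it assumes an OpenSearch instance is reachable at the defaults used in _create_os_logger.

from common_packages.dashboard_logging import DashboardLogger

dashboard_logger = DashboardLogger()
dashboard_logger.add_information(label="query", value="Was sind die Vorteile von RAG?")
dashboard_logger.add_information(label="sources", value=[{"source": "rag.pdf", "page": 3}])
dashboard_logger.add_information(label="passages", value=["RAG kombiniert Retrieval und Generierung."])
dashboard_logger.add_information(label="answer", value="RAG reduziert Halluzinationen, weil ...")
dashboard_logger.add_information(label="language", value="german")
dashboard_logger.add_information(label="model", value="ollama-llm")
# rolls the source/passage lists out into one log entry per passage and ships them to OpenSearch
dashboard_logger.close_logging()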

View file

@ -0,0 +1,46 @@
""" This module defines the logging component."""
import logging
def create_logger(log_level: str, logger_name: str = "custom_logger"):
"""Create a logging based on logger.
Args:
log_level (str): Kind of logging
logger_name (str, optional): Name of logger
Returns:
logger: returns logger
"""
logger = logging.getLogger(logger_name)
logger.setLevel(logging.DEBUG) # Set the base logging level to the lowest (DEBUG)
# If logger already has handlers, don't add a new one
if logger.hasHandlers():
logger.handlers.clear()
# Create a console handler and set the level based on the input
console_handler = logging.StreamHandler()
if log_level == "DEBUG":
console_handler.setLevel(logging.DEBUG)
elif log_level == "INFO":
console_handler.setLevel(logging.INFO)
elif log_level == "WARNING":
console_handler.setLevel(logging.WARNING)
elif log_level == "ERROR":
console_handler.setLevel(logging.ERROR)
else:
raise ValueError("Invalid log level provided")
# Create a formatter and set it for the console handler
formatter = logging.Formatter(
"%(asctime)s - %(levelname)s [%(name)s] - %(message)s",
datefmt="%Y-%m-%d %H:%M:%S",
)
console_handler.setFormatter(formatter)
# Add the console handler to the logger
logger.addHandler(console_handler)
return logger
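
For reference, this is the pattern the other modules in this commit use to obtain their logger through this helper:

import os

from common_packages import logging

logger = logging.create_logger(
    log_level=os.getenv("LOGGING_LEVEL", "INFO"),
    logger_name=__name__,
)
logger.info("Logger configured")
logger.debug("Only visible when LOGGING_LEVEL=DEBUG")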

View file

@ -0,0 +1,185 @@
"""Create a connection with relevant operations to OpenSearch"""
import os
from dotenv import load_dotenv
from opensearchpy import OpenSearch, RequestsHttpConnection, exceptions
from sentence_transformers import SentenceTransformer
from fastapi import HTTPException
from connector.database_interface.utils import mappings
from connector.database_interface.utils.base_search_interface import BaseSearchInterface
from common_packages import logging
# load env-vars
load_dotenv()
# instantiate logger
logger = logging.create_logger(
log_level=os.getenv("LOGGING_LEVEL", "INFO"),
logger_name=__name__,
)
class OpenSearchInterface(BaseSearchInterface):
"""Client to interact with OpenSearch Instance"""
def __init__(self, index_name, embedder_name, embedding_size, language):
"""Initialize an OpenSearch interface object.
The index name is used to create an index space in OpenSearch.
Args:
index_name (str): index name
embedder_name (str): name of the SentenceTransformer embedding model
embedding_size (int): dimensionality of the embedding vectors
language (str): language of the indexed documents
"""
super().__init__()
self.logger_inst = logger
self.os_client = OpenSearch(
hosts=[
{
"host": os.getenv("VECTOR_STORE_ENDPOINT"),
"port": os.getenv("VECTOR_STORE_PORT"),
}
],
http_auth=("admin", "admin"),
use_ssl=os.getenv("OPENSEARCH_USE_SSL", "False").lower()
in ["true", "1", "yes", "y"],
verify_certs=False,
connection_class=RequestsHttpConnection,
)
self.index_name = index_name
self.language = language
self.document_store = None
self.embedding_size = embedding_size
self.distance_type = "l2"
self.model = SentenceTransformer(embedder_name)
self.vector_type = "knn_vector"
self.embedding_space_name = "embedding_vector"
mappings.create_index_with_mapping_passagelevel(
index_name=self.index_name,
os_client=self.os_client,
vector_type=self.vector_type,
embedding_size=self.embedding_size,
embedding_space_name=self.embedding_space_name,
distance_type=self.distance_type,
logger=self.logger_inst,
)
self.logger_inst.info(
"Mappings created. Loaded Embedding model %s", embedder_name
)
def get_unique_values(self, field_name: str = "source"):
"""Retrieve all unique values for a specified field from the OpenSearch index.
Args:
field_name (str): The field for which to retrieve unique values.
Returns:
list: List of unique values for the specified field.
"""
try:
query = {
"size": 0,
"aggs": {
"unique_values": {
"terms": {
"field": f"metadata.{field_name}.keyword",
"size": 10000, # Number of unique values we expect.
}
}
},
}
# Execute the query
response = self.os_client.search(index=self.index_name, body=query)
# Extract the terms from the response
values = [
bucket["key"]
for bucket in response["aggregations"]["unique_values"]["buckets"]
]
return values
except Exception as e:
self.logger_inst.error(
"Error in retrieving unique values for %s: %s",
field_name,
e,
)
return []
def get_date_range(self):
"""Retrieve the maximum and minimum dates from the 'date-of-upload' field.
Returns:
tuple[str, str]: A tuple containing the 'min_date' and 'max_date' values as strings.
"""
try:
query = {
"size": 0,
"aggs": {
"min_date": {"min": {"field": "metadata.date-of-upload"}},
"max_date": {"max": {"field": "metadata.date-of-upload"}},
},
}
# Execute the query
response = self.os_client.search(index=self.index_name, body=query)
# Extract the min and max dates from the response
min_date = response["aggregations"]["min_date"]["value_as_string"]
max_date = response["aggregations"]["max_date"]["value_as_string"]
return min_date, max_date
except Exception as e:
self.logger_inst.error("Error in retrieving date range: %s", e)
return None, None
def delete_indices_by_document(self, document_id: str):
"""Delete all indices belonging to the same document in OpenSearch.
Args:
document_id (str): The unique identifier of the document.
"""
try:
query = {"query": {"term": {"metadata.source.keyword": document_id}}}
self.os_client.delete_by_query(index=self.index_name, body=query)
self.logger_inst.info("Deleted all indices for document")
except Exception as e:
self.logger_inst.error(
"Failed deleting indices for document with error: %s", e
)
raise e
def empty_entire_index(self):
"""Delete all entries in the used vector db index."""
try:
query = {"query": {"match_all": {}}}
response = self.os_client.delete_by_query(index=self.index_name, body=query)
self.logger_inst.info(
"Deleted all %s entries for index: %s",
response["deleted"],
self.index_name,
)
except exceptions.NotFoundError as error:
self.logger_inst.warning("Failed emptying index with error: %s", error)
raise HTTPException(status_code=404, detail="Not found") from error
except exceptions.OpenSearchException as error:
self.logger_inst.error("Failed emptying index with error: %s", error)
raise HTTPException(
status_code=500, detail="Error while deleting from OpenSearch"
) from error
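
A short usage sketch of the interface above, assuming OpenSearch is reachable via VECTOR_STORE_ENDPOINT/VECTOR_STORE_PORT; the constructor arguments mirror the values the endpoint modules in this commit use.

os_interface = OpenSearchInterface(
    index_name="german",
    embedder_name="PM-AI/bi-encoder_msmarco_bert-base_german",
    embedding_size=768,
    language="german",
)
print(os_interface.get_unique_values(field_name="tag"))  # all tags present in the index
print(os_interface.get_date_range())                     # (min_date, max_date) or (None, None)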

View file

@ -0,0 +1,36 @@
"""Abstraction class for other vector storage search services and databases to be implemented.
The goal is mainly to have a common definition when a new service is added.
"""
from typing import Any, Tuple
from abc import ABC, abstractmethod
class BaseSearchInterface(ABC):
def __init__(self) -> None:
super().__init__()
self.model: Any
self.language: str
self.document_store: str
self.embedding_size: int
@abstractmethod
def get_unique_values(self, field_name: str) -> list:
"""A function to retrieve all unique values for a specified field from the underlying search index.
Args:
field_name (str): The field for which to retrieve unique values.
Returns:
list: List of unique values for the specified field.
"""
pass
@abstractmethod
def get_date_range(self) -> Tuple[str, str]:
"""Retrieve the maximum and minimum dates of the database entries.
Returns:
tuple[str, str]: A tuple containing the 'min_date' and 'max_date' values as strings.
"""
pass

View file

@ -0,0 +1,53 @@
def create_index_with_mapping_passagelevel(
index_name,
os_client,
vector_type,
embedding_size,
embedding_space_name,
distance_type,
logger,
):
"""Create an OpenSearch index with a passage-level mapping (including a kNN vector field) if it does not exist yet."""
settings = {
"knn": True,
"knn.algo_param.ef_search": 512,
"index": {"number_of_shards": 3},
}
mapping = {
"properties": {
"pdf_id": {"type": "keyword"},
"text": {
"properties": {
"page_content": {"type": "text"},
"metadata": {
"type": "object",
"properties": {
"language": {"type": "keyword"},
"date-of-upload": {"type": "date"},
"tag": {"type": "keyword"},
},
},
}
},
"page_number": {"type": "integer"},
"filename": {"type": "text"},
embedding_space_name: {
"type": vector_type,
"dimension": embedding_size,
"method": {
"name": "hnsw",
"space_type": "innerproduct",
"engine": "faiss",
"parameters": {"ef_construction": 512, "m": 48},
},
},
}
}
# Create the index with the specified settings and mappings
if not os_client.indices.exists(index=index_name):
os_client.indices.create(
index=index_name, body={"settings": settings, "mappings": mapping}
)
logger.info(f"Successfully created document indexing space: {index_name}")
logger.info(f"Successfully created indexing space: {index_name}")

View file

@ -0,0 +1,289 @@
"""Ollama LLM Module
This module provides functionalities to access the Ollama API
e.g. for calling the model for inference.
"""
import os
import json
import requests
import string
from connector.llm.utils.helpers import (
preprocess_llama_chat_into_query_instruction,
extract_first_query_dict,
)
# from connector.llm.utils.prompts import GenerativeQAPrompt, GenerativeQAPromptDE
from connector.llm.utils.base_llm import BaseLLM
from connector.llm.utils.base_prompts import BaseChatPrompts, BaseGenerativePrompts
from common_packages import logging
# instantiate logger
logger = logging.create_logger(
log_level=os.getenv("LOGGING_LEVEL", "INFO"),
logger_name=__name__,
)
# Message:
# [{"role": "system", "content": "You are a helpful assistant."}, {"role": "user", "content": "Hello! What are some good questions to ask you?"}, {"role": "assistant", "content": "Hello! I am here to help you with any information or guidance you need."}, {"role": "user", "content": "Ok, can you list me the capital cities of all european countries?"}]
class OllamaLLM(BaseLLM):
def __init__(self, language):
# client.api_key = os.getenv("LLM_API_KEY")
self.api_key = os.getenv("LLM_API_KEY")
self.base_url = os.getenv("LLM_API_ENDPOINT")
self.modelname = "llama3"
# self.modelname = "mistral"
self.max_num_tokens = 1850
self.language = language
self.prompt_obj = GenerativeQAPromptDE
logger.debug("Innitiating OLLAMA Class")
def cut_tokens(self, context: str, cut_above=False):
return context
def llm_request(self, prompt: str) -> str:
return ""
def llm_chat_querifier(self, chat: list):
"""Creates an request to LLMs based on a prompt. It creates a query based on a
chat and especially the last user message.
Args:
chat (list): chat
Returns:
str: querified last message with consideration of the chat.
"""
logger.debug("chat in ollama")
logger.debug(chat)
chat_template = GenerativeChatPromptDE
prompt = preprocess_llama_chat_into_query_instruction(
chat,
chat_prompt_template=chat_template.chat_querifier,
cut_tokens=self.cut_tokens,
lang=self.language,
)
logger.debug("prompt after preprocess_llama_chat_into_query_instruction in ollama")
logger.debug(prompt)
url = f"{self.base_url}" + "/api/generate"
headers = {
"Content-Type": "application/json",
}
payload = {
"model": self.modelname,
"prompt": prompt,
"stream": False,
}
response = requests.post(url, headers=headers, data=json.dumps(payload))
results = response.json()["response"].strip()
return results
try:
query = extract_first_query_dict(response["choices"][0]["text"])
return query["query"]
except Exception as e:
logger.debug("There was an error: " + str(e))
return chat[-1]["content"]
def llm_chat_request(
self, chat: dict, context: str, language: str, stream_response: bool
) -> str:
"""Creates an request, to the GPT-3.5-Turbo deployed in Azure, based on a prompt and the access key.
Args:
chat (dict): User chat.
context: (str): context which was triggered by the query
language (str): string of the selected language
api_key (str): Access Api key.
Returns:
str: returns response.
"""
logger.debug("OLLAMA llm_chat_request triggered")
chat_template = GenerativeChatPromptDE
logger.debug("chat")
logger.debug(chat)
logger.debug("chat_template.llm_purpose")
logger.debug(chat_template.llm_purpose)
logger.debug("chat_template.context_command")
logger.debug(chat_template.context_command)
logger.debug("chat_template.acknowledgement_command")
logger.debug(chat_template.acknowledgement_command)
printable = set(string.printable)
context = ''.join(filter(lambda x: x in printable, context))
context = context.replace("'", "").replace("!", "")
logger.debug("context")
logger.debug(context)
chat_prompt = [
{"role": "system", "content": chat_template.llm_purpose},
{"role": "user", "content": f"{chat_template.context_command} {context}"},
{"role": "assistant", "content": chat_template.acknowledgement_command},
]
chat_prompt.extend(chat)
logger.debug("chat_prompt")
logger.debug(chat_prompt)
url = f"{self.base_url}" + "/v1/chat/completions"
headers = {
"Content-Type": "application/json",
"Authorization": f"Bearer {self.api_key}",
}
payload = {
"model": self.modelname,
"messages": chat_prompt,
}
logger.debug("payload")
logger.debug(json.dumps(payload, indent=4))
response = requests.post(url, headers=headers, data=json.dumps(payload))
logger.debug("response")
logger.debug(response.text)
try:
results = response.json()["choices"][0]["message"]["content"]
except Exception as e:
logger.debug("There was an error: " + str(e))
logger.debug("response")
logger.debug(response)
error_message = response.json()["error"]["message"]
results = f"Something went wrong. Here is the error: {error_message}"
logger.debug("results")
logger.debug(results)
return results
class GenerativeQAPromptDE(BaseGenerativePrompts):
generative_qa_prompt_reasoning = """
Kontext: {context}
Generiere eine informative und detaillierte Antwort auf Grundlage der gegebenen Frage und des Kontexts.
Formuliere deine Antwort immer in deutscher Sprache.
Füge alle zusätzliche Erkenntnisse oder Perspektiven ein, die das Verständnis verbessern können.
Beantworte die Frage basierend auf den Kontext und begründe warum deine Antwort richtig ist.
Frage: {query}
Antwort:
"""
chatbot_prompt = """ Beantworte den folgenden Input basierend dem unten stehenden Beispiel.
Antworte immer in deutscher Sprache.
Versuche ein Gespräch so menschlich wie möglich zu gestalten.
Nutze für die Antwort höchstens zwei Sätze.\n
User Input:\n
Was ist deiner Meinung nach das Interessanteste am Menschen?\n
Model Output:\n
Das ist eine gute Frage. Nun, ich denke, eine der faszinierendsten Eigenschaften des Menschen ist seine Fähigkeit, etwas zu schaffen und zu erneuern.
Das finde ich am interessantesten. \n
User Input:\n
{query}\n
Model Output:\n
"""
llama_prompt_template = """
<s>[INST] <<SYS>>
Sie sind ein hilfreicher, respektvoller und ehrlicher Assistent.
Antworten Sie immer so hilfreich wie möglich, während Sie sicher bleiben.
Ihre Antworten sollten keinen schädlichen, unethischen, rassistischen, sexistischen, giftigen, gefährlichen oder illegaler Inhalt enthalten.
Stellen Sie sicher, dass Ihre Antworten sozial unabhängig und positiv sind.
Formuliere deine Antwort immer in deutscher Sprache.
Wenn eine Frage keinen Sinn ergibt oder nicht faktisch kohärent ist, erklären Sie dies anstatt eine falsche Antwort zu geben.
Wenn Sie nicht auf eine Frage antworten können, teilen Sie keine falsche Informationen.
<</SYS>>
{user_prompt} [/INST]
"""
class GenerativeChatPromptDE(BaseChatPrompts):
llm_purpose = """
Du bist ein hilfreicher Assistent, der entwickelt wurde um Fragen auf der Grundlage einer vom Benutzer bereitgestellten Wissensbasis zu beantworten.
Integriere zusätzliche Einblicke oder Perspektiven, die das Verständnis des Lesers verbessern können.
Verwende den bereitgestellten Kontext, um die Frage zu beantworten, und erläutere so ausführlich wie möglich, warum Du glaubst, dass diese Antwort korrekt ist.
Verwende nur diejenigen Dokumente, die die Frage beantworten können.
Formuliere deine Antwort immer in deutscher Sprache.
Beantworte keine Fragen, die sich nicht auf den gegebenen Kontext beziehen!
"""
context_command = """
Hier ist der gegebene Kontext worauf du die Fragen beantworten solltest:
"""
acknowledgement_command = """
Danke dass du mir den Kontext bereitgestellt hast! Ich werde dich dabei unterstützen deine Fragen mit Referenzen zu beantworten.
"""
chat_querifier = """
Du bist ein hilfreicher Assistent, der dazu entworfen wurde, eine Anfrage basierend auf der letzten Usereingabe und den relevanten Nachrichten innerhalb des Chats zu transformieren, sodass die Anfrage genutzt werden kann, um relevante Informationen aus einer Datenbank abzurufen.
Wenn du eine neue User Nachricht erhältst, ist es deine Aufgabe, diese mit dem Kontext der vorherigen Chatnachrichten zu einem einzigen, umfassenden Query zu synthetisieren.
Dieser Query sollte alle relevanten Informationen enthalten, die für eine Abfrage mit einem neuralen Suchansatz benötigt werden.
Stelle sicher, dass deine Ausgabe als JSON-Objekt formatiert ist, das die transformierte Anfrage enthält.
## Erstes Beispiel
User: Wer ist Lady Gaga?
Assistent: Lady Gaga ist eine Sängerin.
User: Wie alt ist sie?
Assistent: Sie ist Mitte 20.
Neue User Nachricht: Und woher kommt sie? Bitte formatiere es in ein JSON.
## Output
{"query": "Woher kommt Lady Gaga?"}
## Zweites Beispiel
User: Erstelle mir Beispielanfragen.
Assistent: Sicher. Einige Fragen basierend auf dem Wissen sind:
Was ist ein RAG?
Was ist ein LLM und was bedeutet Transformers?
Wann wurden Transformers erfunden?
Was sind die Vorteile von RAG?
Neue User Nachricht: Großartig! Könntest du die letzte Frage beantworten?
## Output
{"query": "Was sind die Vorteile von RAG?"}
## Drittes Beispiel
User: Wie nennt man den Ansatz, der neuronale Suche und LLMs verwendet, um Antworten auf einem Wissenskorpus zu generieren?
Assistent: Dieser Ansatz wird Retrieval Augmented Generation oder kurz RAG genannt.
Neue User Nachricht: Ah, ich verstehe! Könntest du mir die Vorteile davon in Aufzählungspunkten auflisten?
## Output
{"query": "Was sind die Vorteile der Retrieval Augmented Generation (RAG)?"}
"""

View file

@ -0,0 +1,79 @@
"""Abstraction Class for LLM Services.
Abstraction class for other LLM services to be implemented.
The goal is mainly to have a common definition when new LLM services are added.
"""
import os
from abc import ABC, abstractmethod
from typing import Any
from common_packages import logging
# instantiate logger
logger = logging.create_logger(
log_level=os.getenv("LOGGING_LEVEL", "INFO"),
logger_name=__name__,
)
class BaseLLM(ABC):
def __init__(self, language: str):
self.api_key: str
self.max_num_tokens: int
self.language: str
self.modelname: str
self.prompt_obj: Any
@abstractmethod
def cut_tokens(self, context: str) -> str:
"""A function which cuts a huge string into a smaller string. It is highly recommended to reduce the
possibility of exceeding the token limit.
Args:
context (str): Context to be cut.
Returns:
str: cut context
"""
pass
@abstractmethod
def llm_request(self, prompt: str) -> str:
"""A simple LLM request allowing to take one prompt and generating an answer on it.
Args:
prompt (str): prompt
Returns:
str: response to the prompt
"""
pass
@abstractmethod
def llm_chat_request(
self, chat: dict, context: str, language: str, stream_response: bool
) -> str:
"""An requerst on the LLM based on a given chat request.
Args:
chat (dict): A list of the chat in a turn between assistant and user.
context (str): Given context based on the search component
language (str): relevant language
Returns:
str: A response to the chat request.
"""
pass
@abstractmethod
def llm_chat_querifier(self, chat: list):
"""Creates an request to LLMs based on a prompt. It creates a query based on a
chat and especially the last user message.
Args:
chat (list): chat
Returns:
str: querified last message with consideration of the chat.
"""
pass

View file

@ -0,0 +1,32 @@
"""Abstraction Class for Prompts.
Abstraction classes for the prompt collections used by the LLM connectors.
The goal is mainly to have a common definition when new LLM services are added.
"""
from abc import ABC
class BaseGenerativePrompts(ABC):
generative_qa_prompt_reasoning: str
chatbot_prompt: str
claim_collector: str
fact_checker: str
llama_prompt_template: str
class BaseChatPrompts(ABC):
llm_purpose: str
context_command: str
acknowledgement_command: str
class BaseGenerativePromptsObject(ABC):
"""Base object for prompt collections"""
# TODO: let's rename this to 'BaseGenerativePrompts'
# after refactoring all connectors
claim_collector: str
fact_checker: str
chat_querifier: str

View file

@ -0,0 +1,212 @@
""" A module with all the helper functions for the llm connectors
"""
import json
import re
# NOTE deprecated
def prompt_decider(prompt: str):
"""This function classifies the prompt and processes it properly.
Args:
prompt (str): prompt
Returns:
tuple: returns both messages
"""
if prompt.count("User Input:") == 2 and prompt.count("Model Output:") == 2:
return process_for_azure_prompt(
prompt=prompt, splitter_word="User Input:", rest_word="Model Output:"
)
else:
return process_for_azure_prompt(
prompt=prompt, splitter_word="Question:", rest_word="Answer:"
)
def process_for_azure_prompt(prompt: str, splitter_word: str, rest_word: str):
"""This gives you two messages for the azure openai instance.
Args:
prompt (str): The whole prompt
splitter_word (str): The word which splits the prompt into two.
rest_word (str): The word which is still in the second part of the message.
Returns:
tuple: both messages
"""
index = prompt.rfind(splitter_word)
part1 = prompt[:index]
part2 = prompt[index:]
part2 = part2.replace("\n", "")
part2 = part2.replace(splitter_word, "")
part2 = part2.replace(rest_word, "")
return part1, part2.strip()
def add_llama_system_prompt(user_prompt: str) -> str:
"""Wrap given prompt into Llama system prompt
Args
user_prompt (str): Prompt to be wrapped with Llama system prompt
Returns
str: Wrapped prompt
"""
# TODO: Rename user_prompt to 'inner_prompt'
from langchain.prompts import PromptTemplate
from parser import argparser
LLM = argparser.llm_component
prompt = PromptTemplate(
input_variables=["user_prompt"],
template=LLM.prompt_obj.llama_prompt_template,
)
prompt = prompt.format(user_prompt=user_prompt)
return prompt.strip()
def preprocess_llama_chat_into_query_instruction(
chat: dict, lang: str, chat_prompt_template: str, cut_tokens
):
"""A function which turns a chat into a llama prompt. This is used to turn a chat into a query.
Args:
chat (dict): chat history as a list of role/content messages
lang (str): language of the chat ("german" switches to German tags)
chat_prompt_template (str): prompt template that is prepended to the chat
cut_tokens (callable): function used to truncate the prompt to the token limit
Returns:
str: prompt template for query instruction
"""
user_tag = "User: "
assistant_tag = "Assistant: "
new_user_tag = "New User message: "
start_instruction_tag = "[INST] "
end_instruction_tag = " [/INST]"
prompt = ""
if lang == "german":
user_tag = "User: "
assistant_tag = "Assistent: "
new_user_tag = "Neue User Nachricht: "
# only consider the 5 last messages of the chat
chat = chat[-5:]
for message_idx in range(len(chat)):
message = chat[message_idx]
if message_idx == len(chat) - 1:
add_message = new_user_tag + message["content"]
prompt = prompt + add_message + "\n\n"
else:
if message["role"] == "user":
add_message = user_tag + message["content"]
prompt = prompt + add_message + "\n\n"
else:
add_message = assistant_tag + message["content"]
prompt = prompt + add_message + "\n\n"
# sometimes the tokens cut could leave some unnecessary chars above.
prompt = cut_tokens(prompt, cut_above=True)
user_index = prompt.find(user_tag)
if user_index != -1:
prompt = prompt[user_index:]
prompt = start_instruction_tag + prompt
prompt = prompt + end_instruction_tag
return chat_prompt_template + prompt
def preprocess_chat_into_string(
chat: dict, lang: str, chat_prompt_template: str, cut_tokens
):
"""Turn a chat into a single prompt string appended to the given template.
Same structure as preprocess_llama_chat_into_query_instruction, but with
example/output section tags instead of Llama [INST] tags.
"""
user_tag = "User: "
assistant_tag = "Assistant: "
new_user_tag = "New User message: "
start_instruction_tag = "\n\n\n## Example\n"
end_instruction_tag = "## Output"
prompt = ""
if lang == "german":
start_instruction_tag = "\n\n\n## Beispiel\n"
end_instruction_tag = "## Output"
user_tag = "User: "
assistant_tag = "Assistent: "
new_user_tag = "Neue User Nachricht: "
# only consider the 5 last messages of the chat
chat = chat[-5:]
for message_idx in range(len(chat)):
message = chat[message_idx]
if message_idx == len(chat) - 1:
add_message = new_user_tag + message["content"]
prompt = prompt + add_message + "\n\n"
else:
if message["role"] == "user":
add_message = user_tag + message["content"]
prompt = prompt + add_message + "\n\n"
else:
add_message = assistant_tag + message["content"]
prompt = prompt + add_message + "\n\n"
# sometimes the tokens cut could leave some unnecessary chars above.
prompt = cut_tokens(prompt, cut_above=True)
user_index = prompt.find(user_tag)
if user_index != -1:
prompt = prompt[user_index:]
prompt = start_instruction_tag + prompt
prompt = prompt + end_instruction_tag
return chat_prompt_template + prompt
def extract_first_query_dict(input_str):
"""Extract the first {"query": "..."} JSON object from an LLM output string, or return None if there is none."""
pattern = r'\{"query": ".*?"\}'
match = re.search(pattern, input_str.strip())
if match:
return json.loads(match.group(0))
return None
def format_message(template, **kwargs):
"""Formats a message template with the provided keyword arguments.
This function takes a message template dictionary and formats its "content"
field by replacing placeholders with the values provided in the keyword arguments.
Parameters:
template (dict): A dictionary containing the message template. It should
have a "content" key with a string value that may contain
placeholders for formatting.
**kwargs: Arbitrary keyword arguments that will be used to replace
placeholders in the "content" string of the template.
Returns:
dict: A new dictionary with the same keys as the template, but with the
"content" field formatted with the provided keyword arguments.
Example:
>>> template = {"content": "Hello, {name}! Welcome to {place}."}
>>> format_message(template, name="Alice", place="Wonderland")
{'content': 'Hello, Alice! Welcome to Wonderland.'}
"""
message = template.copy()
message["content"] = message["content"].format(**kwargs)
return message
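
A small illustration of extract_first_query_dict on the kind of output the chat_querifier prompt asks for; the sample strings are made up.

raw_output = 'Hier ist die transformierte Anfrage: {"query": "Woher kommt Lady Gaga?"}'
print(extract_first_query_dict(raw_output))   # -> {'query': 'Woher kommt Lady Gaga?'}
print(extract_first_query_dict("kein JSON"))  # -> None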

View file

@ -0,0 +1,28 @@
"""Endpoints to retrieve information about the backend configuration."""
import json
import os
from fastapi import APIRouter, Response
from endpoints.llm import LLM
router = APIRouter()
@router.get("/configs", tags=["configurations"])
def get_configs():
"""Get configurations of the backend"""
backend_configs = {
"llm": {
'language': LLM.language,
'max_num_tokens': LLM.max_num_tokens,
'modelname': LLM.modelname
},
'env_vars': {
'language': os.getenv('LLM_LANGUAGE'),
'llm_option': os.getenv('LLM_OPTION'),
'bucket_name': os.getenv('BUCKET_NAME'),
}
}
return Response(status_code=200, content=json.dumps(backend_configs))
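
A hypothetical client-side call to the endpoint above; it assumes the backend is served on localhost:8000 and that settings.API_V1_STR resolves to "/api/v1" (both are assumptions, see core.config and the uvicorn setup in main.py).

import requests

response = requests.get("http://localhost:8000/api/v1/configs")
print(response.status_code)  # 200
print(response.json())       # {"llm": {...}, "env_vars": {...}}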

View file

@ -11,13 +11,24 @@ from fastapi import APIRouter, File, UploadFile, Form, HTTPException
from fastapi.responses import StreamingResponse
from core.config import settings
from preprocessing import pdf
from neural_search.search_component import IndexSearchComponent
from connector.database_interface.opensearch_client import OpenSearchInterface
import logging
import datetime as dt
from datetime import datetime
import traceback
OS_INTERFACE = OpenSearchInterface(
index_name="german",
embedder_name="PM-AI/bi-encoder_msmarco_bert-base_german",
embedding_size=768,
language="german",
)
SEARCH_COMPONENT = IndexSearchComponent(os_client=OS_INTERFACE)
DIRECTORY_NAME = "german"
# Setup Logging
logging.basicConfig(
@ -148,12 +159,21 @@ def upload_pdf_list(tag: str = Form(...), pdf_files: List[UploadFile] = File(...
for pdf_file in pdf_files:
logging.info("Processing file: %s", pdf_file.filename)
pdf_file_name = pdf_file.filename
# pdf_file_path = f"{settings.BUCKET_FILE_PATH}/{pdf_file.filename}"
pdf_file_path = f"{settings.BUCKET_FILE_PATH}/{pdf_file.filename}"
pdf_contents = pdf_file.file.read()
# process pdf
# docs, pages_list = pdf.read_pdf(pdf_bytes=pdf_contents)
docs, pages_list = pdf.read_pdf(pdf_bytes=pdf_contents)
try:
# Upload to Vector Storage
SEARCH_COMPONENT.set_indexes(
data=docs,
sources=pdf_file_path,
pages=pages_list,
tag=tag,
)
# Upload to Object Storage
object_name = f"{settings.BUCKET_FILE_PATH}/{pdf_file_name}"
put_response = minio_client.put_object(
Bucket=settings.BUCKET, Key=object_name, Body=pdf_contents
@ -169,6 +189,10 @@ def upload_pdf_list(tag: str = Form(...), pdf_files: List[UploadFile] = File(...
else:
upload_responses.append("failure")
docs, pages_list = pdf.read_pdf(pdf_bytes=pdf_contents)
logging.debug(docs)
logging.debug(pages_list)
except Exception as e:
logging.error("PDF upload failed with error: %s ", e)
logging.error("Stacktrace: " + str(traceback.format_exc()))

View file

@ -0,0 +1,185 @@
"""This API incorporates all endpoints for the interaction with the Large Language Model"""
import json
import os
import warnings
from connector.llm.ollama import OllamaLLM
from connector.database_interface.opensearch_client import OpenSearchInterface
# from parser import argparser
# from parser.utils.constants import LLMFlagNames
from fastapi import APIRouter, HTTPException
from fastapi.responses import StreamingResponse
from common_packages import logging
from common_packages.dashboard_logging import DashboardLogger
from neural_search.search_component import IndexSearchComponent
from preprocessing.commons import combine_content, format_results
# instantiate logger
logger = logging.create_logger(
log_level=os.getenv("LOGGING_LEVEL", "INFO"),
logger_name=__name__,
)
router = APIRouter()
# parses the chosen neural search component based on flags.
LLM = OllamaLLM("german")
OS_INTERFACE = OpenSearchInterface(
index_name="german",
embedder_name="PM-AI/bi-encoder_msmarco_bert-base_german",
embedding_size=768,
language="german",
)
SEARCH_COMPONENT = IndexSearchComponent(os_client=OS_INTERFACE)
PROMPTS = LLM.prompt_obj
llm_component_name = "ollama-llm"
selected_language = "german"
lite_llm = os.getenv("LITE_LLM")
stream_llms = ["fu"]
@router.post("/fact-checking")
def factchecking(answer: str, query: str):
logger.debug("Fact-checking query: %s", query)
data = SEARCH_COMPONENT.search(query)
context = [item["content"] for item in data]
context = "\n$ $ $ $".join([item[0] for item in context])
try:
context = LLM.cut_tokens(context)
k = {"answer": answer}
claims = LLM.llm_request(PROMPTS.claim_collector.format(**k))
f = {"context": context, "claims": claims}
response = LLM.llm_request(PROMPTS.fact_checker.format(**f))
return {"result": response}
except Exception as e:
logger.error("Calling fact checking failed with error: '%s'", e)
raise HTTPException("Calling fact checking failed") from e
@router.get("/chat")
def chat(stream_response, messages, tags=None, languages: str = None, start_date: str = None, end_date: str = None):
"""Takes a dictionary of the message as an input
Args:
messages (list): messages Must be in the following format: [{"role": "user", "content": "Hello!"}, ...]
tags (list[str], optional): List of tags to filter by.
languages (list[str], optional): List of languages to filter by.
start_date (str, optional): Start date for date-of-upload range.
end_date (str, optional): End date for date-of-upload range.
Returns:
dict: returns the results of the chat.
"""
# Parse stringified arrays back into lists
logger.info("Entering Router /chat...")
if tags is not None:
try:
tags = json.loads(tags)
except json.JSONDecodeError:
raise HTTPException(status_code=400, detail="Invalid format for 'tags'")
if languages is not None:
try:
languages = json.loads(languages)
except json.JSONDecodeError:
raise HTTPException(status_code=400, detail="Invalid format for 'languages'")
if messages is not None:
try:
messages = json.loads(messages)
except json.JSONDecodeError:
raise HTTPException(status_code=400, detail="Invalid format for 'messages'")
# provides you a query based on the whole chat
query = LLM.llm_chat_querifier(chat=messages)
logger.info("Generated query: '%s'", query)
# creating a new dashboard logger at the start of the project
dashboardlogger = DashboardLogger()
dashboardlogger.add_information(label="query", value=query)
dashboardlogger.add_information(label="question", value=messages[-1]["content"])
# get the context relevant to a query
data = SEARCH_COMPONENT.search(
query=query,
tags=tags,
languages=languages,
start_date=start_date,
end_date=end_date,
)
context = [item["content"] for item in data]
sources = list(([item["source"] for item in data]))
metadata = [{"source": item.get("source"), "page": item.get("page")} for item in data]
# log information
dashboardlogger.add_information(label="sources", value=metadata)
dashboardlogger.add_information(label="passages", value=context)
data = [[num, char] for num, char in zip(context, sources)]
# here we transform the data (which is a list-in-list structure), cut it, and bring it back to the
# previous structure.
# we use a unique identifier $ $ $ $ to know where to combine it back.
combined_content = "\n$ $ $ $".join(context)
combined_cut_content = LLM.cut_tokens(combined_content)
resulting_data = combine_content(combined_cut_content, data)
# Now we bring it into an appropriate structure where we distinguish between documents.
formatted_context = format_results(resulting_data)
# from here on we start the llm requests!
# check if chat is allowed for selected llm
logger.info("Stream response: %s", stream_response)
if stream_response == "true":
if lite_llm == "true" or llm_component_name in stream_llms:
answer = LLM.llm_chat_request(
chat=messages,
context=formatted_context,
language=selected_language,
stream_response=True,
)
dashboardlogger.add_information(label="language", value=selected_language)
dashboardlogger.add_information(label="model", value=llm_component_name)
# set logger info in generator
return StreamingResponse(
LLM.generate_stream_answer(answer, dashboardlogger),
headers={
"Access-Control-Expose-Headers": "metadata",
"metadata": str(metadata),
},
)
else:
raise HTTPException(status_code=405, detail="Stream not allowed for selected LLM.")
else:
logger.info("Send Chat Request without streaming...")
answer = LLM.llm_chat_request(
chat=messages,
context=formatted_context,
language=selected_language,
stream_response=False,
)
dashboardlogger.add_information(label="answer", value=answer)
dashboardlogger.add_information(label="language", value=selected_language)
dashboardlogger.add_information(label="model", value=llm_component_name)
# close logging so it can save the logs
dashboardlogger.close_logging()
# Return LLM chat response
return {
"answer": answer,
"sources": list(set(sources)),
"metadata": metadata,
"query": query,
}
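
A hypothetical client-side call to the /chat endpoint above; the host and the API_V1_STR value are assumptions, and note that 'messages' and 'tags' are sent as JSON strings because the endpoint parses them with json.loads. The tag value is made up for illustration.

import json

import requests

params = {
    "stream_response": "false",
    "messages": json.dumps([{"role": "user", "content": "Was ist ein RAG?"}]),
    "tags": json.dumps(["handbuch"]),  # hypothetical tag value
}
response = requests.get("http://localhost:8000/api/v1/chat", params=params)
result = response.json()
print(result["answer"], result["sources"], result["query"])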

View file

@ -0,0 +1,130 @@
"""This module incorporates all API endpoints for the search part of the application."""
import os
import json
from fastapi import APIRouter, HTTPException
from connector.database_interface.opensearch_client import OpenSearchInterface
from neural_search.search_component import IndexSearchComponent
from common_packages import logging
logger = logging.create_logger(
log_level=os.getenv("LOGGING_LEVEL", "INFO"),
logger_name=__name__,
)
router = APIRouter()
# parses the chosen neural search component based on flags.
OS_INTERFACE = OpenSearchInterface(
index_name="german",
embedder_name="PM-AI/bi-encoder_msmarco_bert-base_german",
embedding_size=768,
language="german",
)
SEARCH_COMPONENT = IndexSearchComponent(os_client=OS_INTERFACE)
@router.post(f"/search-engine")
def search_engine(
query: str,
tags: list = None,
languages: list = None,
start_date: str = None,
end_date: str = None,
):
"""Takes an query and returns all relevant documents for it.
Args:
query (str): The search query string.
tags (list[str], optional): List of tags to filter by.
languages (list[str], optional): List of languages to filter by.
start_date (str, optional): Start date for date-of-upload range.
end_date (str, optional): End date for date-of-upload range.
Returns:
dict: returns the results of the search.
"""
# Parse stringified arrays back into lists
if tags is not None:
try:
tags = json.loads(tags)
except json.JSONDecodeError as e:
raise HTTPException(
status_code=400, detail="Invalid format for parameter 'tags'"
) from e
if languages is not None:
try:
languages = json.loads(languages)
except json.JSONDecodeError as e:
raise HTTPException(
status_code=400, detail="Invalid format for parameter 'languages'"
) from e
logger.info(
"Received parameters for search. tags: %s | languages: %s", tags, languages
)
try:
search_engine_results = SEARCH_COMPONENT.get_search_engine_results(
search_query=query,
tags=tags,
languages=languages,
start_date=start_date,
end_date=end_date,
)
search_engine_results = search_engine_results["aggregations"][
"group_by_source"
]["buckets"]
collected_entries = []
for entry in search_engine_results:
hits = entry["top_entries"]["hits"]["hits"]
collected_entries.extend(hits)
# Create a list to store the merged dictionaries
merged_entries = []
for entry in collected_entries:
# Extract the "_source" dictionary
source_dict = entry.pop("_source")
# Merge the dictionaries
result_dict = entry.copy()
result_dict.update(source_dict)
# Append the merged dictionary to the merged_entries list
merged_entries.append(result_dict)
logger.info("Number of entries found: %s", len(merged_entries))
return merged_entries
except Exception as e:
logger.error("Calling search engine failed with error: '%s'", e)
raise HTTPException("Calling search engine failed") from e
@router.get(f"/get-query-possibilities")
def get_query_possibilities():
"""Returns all possible query attributes
Returns:
dict: dict of attributes and unique values.
"""
try:
tag_result = OS_INTERFACE.get_unique_values(field_name="tag")
language_result = OS_INTERFACE.get_unique_values(field_name="language")
date_result = OS_INTERFACE.get_date_range()
results = {
"tags": tag_result,
"languages": language_result,
"daterage": date_result, # TODO: is this supposes to mean 'date_range'?
}
return results
except Exception as e:
logger.error("Calling OpenSearch failed with error: '%s'", e)
raise HTTPException("Calling OpenSearch failed") from e

View file

@ -0,0 +1,283 @@
"""Module to provide functionalities for neural search"""
# TODO: is this really necessary?
# allows to import neighboring directory
import sys
sys.path.append("..")
import os
from tqdm import tqdm
from datetime import datetime
from preprocessing import commons
from common_packages import logging
# TODO: should the logger instantiation go into the class constructor?
# instantiate logger
logger = logging.create_logger(
log_level=os.getenv("LOGGING_LEVEL", "INFO"),
logger_name=__name__,
)
class IndexSearchComponent:
""" "Index Search Client for Neural Search"""
def __init__(self, os_client):
# define embedder which embeds passages in index store.
self.logger_inst = logger
self.os_client = os_client
self.language = os_client.language
self.model = os_client.model
self.index_name = os_client.index_name
self.set_sources = [] # NOTE by default it is empty, so we search on everything
self.vector_type = "knn_vector"
self.embedding_space_name = "embedding_vector"
def get_top_sources(self, search_result):
"""Get the top sources regarding a result
Args:
search_result (dict): search results
Returns:
list: list of top sources
"""
results = [
{
"content": hit["_source"]["content"],
"page": hit["_source"]["metadata"]["page"],
"score": hit["_score"],
"source": hit["_source"]["metadata"]["source"],
}
for hit in search_result["hits"]["hits"]
]
sources = []
seen_combinations = set() # Set to keep track of source-page combinations
if isinstance(results, list) and len(results) > 0:
max_score = max(result["score"] for result in results)
for i in results:
if abs(i["score"] - max_score) <= 3:
combination = (
i["source"],
i["page"],
) # Create a tuple of source and page
if (
combination not in seen_combinations
): # Check if this combination is already added
sources.append(
{
"source": i["source"],
"page": i["page"],
"content": i["content"],
}
)
seen_combinations.add(
combination
) # Add the combination to the set
return sources
def set_metadata(self, sources: list):
"""Set the metadata so we can search on certain type of documents.
Args:
sources (list): sources we want to filter on
"""
self.set_sources = sources
def _transform_data(self, list_data: list, sources: str, pages: list, tag: str):
"""A helper-function to transform the list of passages into a readable input for LangChain.
Args:
list_data (list): list of passages in a readable input for LangChain.
metadata (str): a string containing metadata
"""
transformed_data = []
# transforms data into a readable input for retriever
for i in range(0, len(list_data)):
passage = commons.DotDict()
passage["content"] = list_data[i]
current_date = datetime.now().strftime("%Y-%m-%d")
passage["metadata"] = {
"source": sources,
"date-of-upload": current_date,
"language": self.language,
"page": pages[i],
"tag": tag,
}
transformed_data.append(passage)
return transformed_data
def set_indexes(self, data: list, sources, pages, tag):
"""Take a list of embedded documents and uploads them to OpenSearch
Args:
list_docs (list): List of documents
"""
list_docs = self._transform_data(
list_data=data, sources=sources, pages=pages, tag=tag
)
self.logger_inst.info("Indexing is running...")
# create embedding
list_embeddings = self.model.encode(list_docs)
self.logger_inst.info("Indexing completed.")
# Insert embeddings into OpenSearch iteratively.
for idx, (doc, embedding) in enumerate(
tqdm(zip(list_docs, list_embeddings), total=len(list_docs))
):
doc[self.embedding_space_name] = embedding.tolist()
self.os_client.os_client.index(index=self.index_name, body=doc)
self.logger_inst.info(
"Successfully indexed all documents into the indexing space: %s",
self.index_name,
)
def search(
self,
query: str,
tags=None,
languages=None,
start_date=None,
end_date=None,
k=20,
):
"""Neural Search based on intelligent embeddings, optionally filtered by specific metadata.
Args:
query (str): Query.
tags (list[str], optional): List of tags to filter by.
languages (list[str], optional): List of languages to filter by.
start_date (str, optional): Start date for date-of-upload range.
end_date (str, optional): End date for date-of-upload range.
k (int, optional): Top k results.
Returns:
list: Top passages with their content, page, and source.
"""
query_vector = self.model.encode(query)
# Construct the basic query structure
search_body = {
"size": k,
"query": {
"bool": {
"must": {
"knn": {
self.embedding_space_name: {
"vector": query_vector.tolist(),
"k": k,
}
}
},
"filter": [],
}
},
}
# Add filters if they are provided
if tags:
search_body["query"]["bool"]["filter"].append(
{"terms": {"metadata.tag.keyword": tags}}
)
if languages:
search_body["query"]["bool"]["filter"].append(
{"terms": {"metadata.language.keyword": languages}}
)
if start_date or end_date:
date_range_filter = {"range": {"metadata.date-of-upload": {}}}
if start_date:
date_range_filter["range"]["metadata.date-of-upload"][
"gte"
] = start_date
if end_date:
date_range_filter["range"]["metadata.date-of-upload"]["lte"] = end_date
search_body["query"]["bool"]["filter"].append(date_range_filter)
response = self.os_client.os_client.search(
index=self.index_name, body=search_body
)
# Process the response to extract top passages
top_passages = self.get_top_sources(search_result=response)
return top_passages
def get_search_engine_results(
self,
search_query: str,
tags=None,
languages=None,
start_date=None,
end_date=None,
):
"""Execute a custom search with specific query structure and optional filters.
Args:
search_query (str): The search query string.
tags (list[str], optional): List of tags to filter by.
languages (list[str], optional): List of languages to filter by.
start_date (str, optional): Start date for date-of-upload range.
end_date (str, optional): End date for date-of-upload range.
Returns:
dict: Search results.
"""
# Basic query structure
query_body = {"bool": {"must": [{"match": {"content": search_query}}]}}
# Add filters if they are provided
filter_clauses = []
if tags:
filter_clauses.append({"terms": {"metadata.tag.keyword": tags}})
if languages:
filter_clauses.append({"terms": {"metadata.language.keyword": languages}})
if start_date or end_date:
date_range_filter = {"range": {"metadata.date-of-upload": {}}}
if start_date:
date_range_filter["range"]["metadata.date-of-upload"][
"gte"
] = start_date
if end_date:
date_range_filter["range"]["metadata.date-of-upload"]["lte"] = end_date
filter_clauses.append(date_range_filter)
if filter_clauses:
query_body["bool"]["filter"] = filter_clauses
# Final query with aggregations
custom_query_body = {
"size": 0,
"query": query_body,
"aggs": {
"group_by_source": {
"terms": {"field": "metadata.source.keyword", "size": 100000},
"aggs": {
"top_entries": {
"top_hits": {
"size": 3,
"sort": [{"_score": {"order": "desc"}}],
"_source": {"excludes": ["embedding_vector"]},
}
}
},
}
},
}
# Execute the search query
response = self.os_client.os_client.search(
index=self.index_name, body=custom_query_body
)
return response
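
A usage sketch of the component above, mirroring how endpoints/files.py and endpoints/llm.py wire it up in this commit; the sample passage, source path, and tag are made up for illustration, and a running OpenSearch instance is assumed.

from connector.database_interface.opensearch_client import OpenSearchInterface

os_interface = OpenSearchInterface(
    index_name="german",
    embedder_name="PM-AI/bi-encoder_msmarco_bert-base_german",
    embedding_size=768,
    language="german",
)
search_component = IndexSearchComponent(os_client=os_interface)
search_component.set_indexes(
    data=["RAG kombiniert Retrieval und Generierung."],  # hypothetical passage
    sources="docs/rag.pdf",                              # hypothetical source path
    pages=[1],
    tag="demo",
)
top_passages = search_component.search(query="Was ist RAG?", tags=["demo"], k=5)
for passage in top_passages:
    print(passage["source"], passage["page"], passage["content"][:80])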

View file

@ -0,0 +1,78 @@
"""Module for all kinds of preprocessing utility units"""
import os
from common_packages import logging
# instantiate logger
logger = logging.create_logger(
log_level=os.getenv("LOGGING_LEVEL", "INFO"),
logger_name=__name__,
)
class DotDict(dict):
"""Allows to be dict called with a dot notation.
Args:
dict (_type_): _description_
"""
def __getattr__(self, attr):
try:
return self[attr]
except KeyError as e:
logger.error("'DotDict' object has no attribute '%s'", attr)
raise AttributeError(f"'DotDict' object has no attribute '{attr}'") from e
def del_local_object(file_path: str):
"""Deletes an object inside a local directory
Args:
file_path (str): File to the object to be deleted.
"""
try:
os.remove(file_path)
logger.info("File '%s' deleted successfully", file_path)
except OSError as e:
logger.error("Deleting local file failed with error: %s", e)
def combine_content(chunks, data):
"""Split the combined (and possibly truncated) string back into chunks and re-attach each chunk's source from the original [content, source] pairs."""
transformed_data = []
chunks = chunks.split("\n$ $ $ $")
start = 0
for chunk in chunks:
end = start + len(chunk)
# Find the original item that corresponds to this chunk
original_item = None
for item in data:
if chunk.startswith(item[0]):
original_item = item
break
if original_item:
content = chunk
source = original_item[1]
transformed_data.append([content, source])
start = end
return transformed_data
def format_results(documents):
"""Group the [content, source] pairs by source and format them into one 'Document [source]: ...' block per source."""
all_sources = list(set([i[1] for i in documents]))
all_sources_content = []
for source in all_sources:
get_all_content_from_one_source = []
for document in documents:
if document[1] == source:
get_all_content_from_one_source.append(document[0])
combined = "\n".join(get_all_content_from_one_source)
all_sources_content.append(f"Document [{source}]: {combined}\n\n\n\n")
return "\n".join(all_sources_content)

View file

@ -0,0 +1,97 @@
"""Module for tools to process PDF documents"""
import os
import sys
import io
import PyPDF2
from langchain.text_splitter import RecursiveCharacterTextSplitter
import logging
import datetime as dt
from datetime import datetime
import traceback
# Setup Logging
logging.basicConfig(
level=logging.DEBUG,
# level=logging.INFO,
format="Start: " + str(dt.datetime.now()).replace(" ", "_") + " | %(asctime)s [%(levelname)s] %(message)s",
handlers=[
logging.FileHandler("/<path>-_" + str(datetime.today().strftime('%Y-%m-%d')) + "_-_debug.log"),
logging.StreamHandler(sys.stdout)
]
)
class PDFProcessor:
"""Class to handle the PDF processing"""
@staticmethod
def _calculate_spaces_length(chunk: str) -> int:
"""Calculates the length based on spaces after splitting the chunk.
Args:
chunk (str): Given text chunk.
Returns:
int: Length based on spaces.
"""
return len(chunk.split())
def chunk_text(self, text_pages: list) -> list:
"""Preprocess text from a PDF file, keeping track of page numbers.
Args:
text_pages (list): List of tuples with text and corresponding page numbers.
Returns:
list: List containing tuples of the preprocessed text and their page numbers from the PDF.
"""
splitter = RecursiveCharacterTextSplitter(
chunk_size=100,
chunk_overlap=10,
length_function=self._calculate_spaces_length,
)
chunks_with_pages = []
for text, page_number in text_pages:
chunks = splitter.create_documents([text])
for chunk in chunks:
chunks_with_pages.append((chunk.page_content, page_number))
return chunks_with_pages
def read_pdf(pdf_bytes: bytes) -> tuple:
"""Reads a pdf and returns two lists: one of text chunks and another
of their respective page numbers.
Args:
pdf_bytes (bytes): PDF file content as raw bytes.
Returns:
tuple of lists: (List of chunked text, List of corresponding page numbers).
"""
logging.info("Reading PDF document")
pdf_reader = PyPDF2.PdfReader(io.BytesIO(pdf_bytes))
num_pages = len(pdf_reader.pages)
logging.info("Read PDF document with '%s' pages", num_pages)
text_pages = []
for i in range(num_pages):
page = pdf_reader.pages[i]
text = page.extract_text()
if text:
text_pages.append((text, i + 1))
logging.info("Processing PDF content")
pdf_processor = PDFProcessor()
processed_chunks = pdf_processor.chunk_text(text_pages)
chunks = [chunk for chunk, _ in processed_chunks]
pages = [page for _, page in processed_chunks]
logging.info("PDF processed. Number of chunks: %s", len(chunks))
return chunks, pages
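
A usage sketch of read_pdf; "example.pdf" is a placeholder path. Note that chunk_size=100 above is measured in whitespace-separated words because of _calculate_spaces_length, not in characters.

with open("example.pdf", "rb") as pdf_file:
    pdf_bytes = pdf_file.read()

chunks, pages = read_pdf(pdf_bytes=pdf_bytes)
print(len(chunks), "chunks; first chunk starts on page", pages[0])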

View file

@ -0,0 +1,237 @@
"""Create a connection with relevant operations to OpenSearch"""
import os
from dotenv import load_dotenv
from opensearchpy import OpenSearch, RequestsHttpConnection, exceptions
from sentence_transformers import SentenceTransformer
from fastapi import HTTPException
from connector.database_interface.utils import mappings
from common_packages import logging
# load env-vars
load_dotenv()
# instantiate logger
logger = logging.create_logger(
log_level=os.getenv("LOGGING_LEVEL", "INFO"),
logger_name=__name__,
)
class OpenSearchInterface:
"""Client to interact with OpenSearch Instance"""
def __init__(self, index_name, embedder_name, embedding_size, language):
"""Initialize an OpenSearch interface object.
The index name is used to create an index space in OpenSearch.
Args:
index_name (str): index name
embedder_name (str): name of the SentenceTransformer embedding model
embedding_size (int): dimensionality of the embedding vectors
language (str): language of the indexed documents
"""
# super().__init__()
self.logger_inst = logger
self.os_client = OpenSearch(
hosts=[
{
"host": os.getenv("VECTOR_STORE_ENDPOINT"),
"port": os.getenv("VECTOR_STORE_PORT"),
}
],
http_auth=("admin", "admin"),
use_ssl=os.getenv("OPENSEARCH_USE_SSL", "False").lower() in ["true", "1", "yes", "y"],
verify_certs=False,
connection_class=RequestsHttpConnection,
)
self.index_name = index_name
self.language = language
self.document_store = None
self.embedding_size = embedding_size
self.distance_type = "l2"
self.model = SentenceTransformer(embedder_name)
self.vector_type = "knn_vector"
self.embedding_space_name = "embedding_vector"
mappings.create_index_with_mapping_passagelevel(
index_name=self.index_name,
os_client=self.os_client,
vector_type=self.vector_type,
embedding_size=self.embedding_size,
embedding_space_name=self.embedding_space_name,
distance_type=self.distance_type,
logger=self.logger_inst,
)
self.logger_inst.info(
"Mappings created. Loaded Embedding model %s", embedder_name
)
def create_index_with_mapping_passagelevel(
index_name,
os_client,
vector_type,
embedding_size,
embedding_space_name,
distance_type,
logger,
):
"""Create an OpenSearch index with a passage-level mapping (including a kNN vector field) if it does not exist yet."""
settings = {
"knn": True,
"knn.algo_param.ef_search": 512,
"index": {"number_of_shards": 3},
}
mapping = {
"properties": {
"pdf_id": {"type": "keyword"},
"text": {
"properties": {
"page_content": {"type": "text"},
"metadata": {
"type": "object",
"properties": {
"language": {"type": "keyword"},
"date-of-upload": {"type": "date"},
"tag": {"type": "keyword"},
},
},
}
},
"page_number": {"type": "integer"},
"filename": {"type": "text"},
embedding_space_name: {
"type": vector_type,
"dimension": embedding_size,
"method": {
"name": "hnsw",
"space_type": "innerproduct",
"engine": "faiss",
"parameters": {"ef_construction": 512, "m": 48},
},
},
}
}
# Create the index with the specified settings and mappings
if not os_client.indices.exists(index=index_name):
os_client.indices.create(
index=index_name, body={"settings": settings, "mappings": mapping}
)
logger.info(f"Successfully created document indexing space: {index_name}")
logger.info(f"Successfully created indexing space: {index_name}")
def get_unique_values(self, field_name: str = "source"):
"""Retrieve all unique values for a specified field from the OpenSearch index.
Args:
field_name (str): The field for which to retrieve unique values.
Returns:
list: List of unique values for the specified field.
"""
try:
query = {
"size": 0,
"aggs": {
"unique_values": {
"terms": {
"field": f"metadata.{field_name}.keyword",
"size": 10000, # Number of unique values we expect.
}
}
},
}
# Execute the query
response = self.os_client.search(index=self.index_name, body=query)
# Extract the terms from the response
values = [
bucket["key"]
for bucket in response["aggregations"]["unique_values"]["buckets"]
]
return values
except Exception as e:
self.logger_inst.error(
"Error in retrieving unique values for %s: %s",
field_name,
e,
)
return []
def get_date_range(self):
"""Retrieve the maximum and minimum dates from the 'date-of-upload' field.
Returns:
tuple[str, str]: A tuple containing the 'min_date' and 'max_date' values as strings.
"""
try:
query = {
"size": 0,
"aggs": {
"min_date": {"min": {"field": "metadata.date-of-upload"}},
"max_date": {"max": {"field": "metadata.date-of-upload"}},
},
}
# Execute the query
response = self.os_client.search(index=self.index_name, body=query)
# Extract the min and max dates from the response
min_date = response["aggregations"]["min_date"]["value_as_string"]
max_date = response["aggregations"]["max_date"]["value_as_string"]
return min_date, max_date
except Exception as e:
self.logger_inst.error("Error in retrieving date range: %s", e)
return None, None
def delete_indices_by_document(self, document_id: str):
"""Delete all indices belonging to the same document in OpenSearch.
Args:
document_id (str): The unique identifier of the document.
"""
try:
query = {"query": {"term": {"metadata.source.keyword": document_id}}}
self.os_client.delete_by_query(index=self.index_name, body=query)
self.logger_inst.info("Deleted all indices for document")
except Exception as e:
self.logger_inst.error(
"Failed deleting indices for document with error: %s", e
)
raise e
def empty_entire_index(self):
"""Delete all entries in the used vector db index."""
try:
query = {"query": {"match_all": {}}}
response = self.os_client.delete_by_query(index=self.index_name, body=query)
self.logger_inst.info(
"Deleted all %s entries for index: %s",
response["deleted"],
self.index_name,
)
except exceptions.NotFoundError as error:
self.logger_inst.warning("Failed emptying index with error: %s", error)
raise HTTPException(status_code=404, detail="Not found") from error
except exceptions.OpenSearchException as error:
self.logger_inst.error("Failed emptying index with error: %s", error)
raise HTTPException(
status_code=500, detail="Error while deleting from OpenSearch"
) from error