rag-chat/rag-chat-backend/src/endpoints/search.py

"""This module incorporates all API endpoints for the search part of the application."""
import os
import json
from typing import Optional

from fastapi import APIRouter, HTTPException

from connector.database_interface.opensearch_client import OpenSearchInterface
from neural_search.search_component import IndexSearchComponent
from common_packages import logging

logger = logging.create_logger(
    log_level=os.getenv("LOGGING_LEVEL", "INFO"),
    logger_name=__name__,
)

router = APIRouter()

# Instantiate the OpenSearch client and the neural search component shared by
# all endpoints in this module.
OS_INTERFACE = OpenSearchInterface(
    index_name="german",
    embedder_name="PM-AI/bi-encoder_msmarco_bert-base_german",
    embedding_size=768,
    language="german",
)
SEARCH_COMPONENT = IndexSearchComponent(os_client=OS_INTERFACE)


@router.post("/search-engine")
def search_engine(
    query: str,
    tags: Optional[str] = None,
    languages: Optional[str] = None,
    start_date: Optional[str] = None,
    end_date: Optional[str] = None,
):
"""Takes an query and returns all relevant documents for it.
Args:
query (str): The search query string.
tags (list[str], optional): List of tags to filter by.
languages (list[str], optional): List of languages to filter by.
start_date (str, optional): Start date for date-of-upload range.
end_date (str, optional): End date for date-of-upload range.
Returns:
dict: returns the results of the prompt.
"""
    # Parse stringified arrays back into lists.
    if tags is not None:
        try:
            tags = json.loads(tags)
        except json.JSONDecodeError as e:
            raise HTTPException(
                status_code=400, detail="Invalid format for parameter 'tags'"
            ) from e
    if languages is not None:
        try:
            languages = json.loads(languages)
        except json.JSONDecodeError as e:
            raise HTTPException(
                status_code=400, detail="Invalid format for parameter 'languages'"
            ) from e

    logger.info(
        "Received parameters for search. tags: %s | languages: %s", tags, languages
    )
    try:
        search_engine_results = SEARCH_COMPONENT.get_search_engine_results(
            search_query=query,
            tags=tags,
            languages=languages,
            start_date=start_date,
            end_date=end_date,
        )
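
        # Assumed shape of the aggregation response, inferred from the
        # unpacking below (the actual query lives in IndexSearchComponent,
        # which is not shown in this file):
        # {
        #   "aggregations": {
        #     "group_by_source": {
        #       "buckets": [
        #         {"top_entries": {"hits": {"hits": [{..., "_source": {...}}]}}},
        #         ...
        #       ]
        #     }
        #   }
        # }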
        search_engine_results = search_engine_results["aggregations"][
            "group_by_source"
        ]["buckets"]
        collected_entries = []
        for entry in search_engine_results:
            hits = entry["top_entries"]["hits"]["hits"]
            collected_entries.extend(hits)

        # Flatten each hit by merging its "_source" dictionary into the
        # surrounding hit metadata (index, id, score, ...).
        merged_entries = []
        for entry in collected_entries:
            source_dict = entry.pop("_source")
            result_dict = entry.copy()
            result_dict.update(source_dict)
            merged_entries.append(result_dict)

        logger.info("Number of entries found: %s", len(merged_entries))
        return merged_entries
    except Exception as e:
        logger.error("Calling search engine failed with error: '%s'", e)
        raise HTTPException(
            status_code=500, detail="Calling search engine failed"
        ) from e
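
# Example request (a sketch; assumes the router is mounted at the root path
# and served on localhost:8000 -- neither is confirmed by this file). Note
# that 'tags' and 'languages' are sent as JSON-encoded strings, matching the
# json.loads parsing above:
#
#   import requests
#
#   response = requests.post(
#       "http://localhost:8000/search-engine",
#       params={
#           "query": "Reisekostenabrechnung",
#           "tags": '["hr", "policy"]',
#           "languages": '["german"]',
#       },
#   )
#   response.raise_for_status()
#   for hit in response.json():
#       print(hit["_id"], hit["_score"])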


@router.get("/get-query-possibilities")
def get_query_possibilities():
    """Returns all possible query attributes.

    Returns:
        dict: Dict of attributes and their unique values.
    """
    try:
        tag_result = OS_INTERFACE.get_unique_values(field_name="tag")
        language_result = OS_INTERFACE.get_unique_values(field_name="language")
        date_result = OS_INTERFACE.get_date_range()
        results = {
            "tags": tag_result,
            "languages": language_result,
            "date_range": date_result,
        }
        return results
    except Exception as e:
        logger.error("Calling OpenSearch failed with error: '%s'", e)
        raise HTTPException(
            status_code=500, detail="Calling OpenSearch failed"
        ) from e
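
# Example response (a sketch; the exact value shapes depend on
# OpenSearchInterface.get_unique_values and get_date_range, which are not
# shown in this file):
#
#   {
#       "tags": ["hr", "policy", ...],
#       "languages": ["german", ...],
#       "date_range": {"min": "2023-01-01", "max": "2024-06-01"},
#   }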