basic backend
This commit is contained in:
parent
16e5004228
commit
89ec0476ca
29 changed files with 2125 additions and 13 deletions
130
rag-chat-backend/src/endpoints/search.py
Normal file
130
rag-chat-backend/src/endpoints/search.py
Normal file
|
|
@ -0,0 +1,130 @@
|
|||
"""This module incorporates all API endpoints for the search part of the application."""
|
||||
|
||||
import os
|
||||
import json
|
||||
|
||||
from fastapi import APIRouter, HTTPException
|
||||
|
||||
from connector.database_interface.opensearch_client import OpenSearchInterface
|
||||
from neural_search.search_component import IndexSearchComponent
|
||||
|
||||
from common_packages import logging
|
||||
|
||||
# Module-level logger; level is taken from the LOGGING_LEVEL env var (default INFO).
logger = logging.create_logger(
    log_level=os.getenv("LOGGING_LEVEL", "INFO"),
    logger_name=__name__,
)

# Router collecting all search-related endpoints; mounted by the application elsewhere.
router = APIRouter()

# OpenSearch connection configured for the German index with a German BERT
# bi-encoder (768-dim embeddings).
# NOTE(review): index name, embedder and language are hard-coded here — presumably
# intended to become configurable ("based on flags"); confirm before reuse.
OS_INTERFACE = OpenSearchInterface(
    index_name="german",
    embedder_name="PM-AI/bi-encoder_msmarco_bert-base_german",
    embedding_size=768,
    language="german",
)
# Neural search component backed by the OpenSearch client above; shared by all requests.
SEARCH_COMPONENT = IndexSearchComponent(os_client=OS_INTERFACE)
|
||||
|
||||
|
||||
@router.post("/search-engine")
def search_engine(
    query: str,
    tags: list = None,
    languages: list = None,
    start_date: str = None,
    end_date: str = None,
):
    """Takes a query and returns all relevant documents for it.

    Args:
        query (str): The search query string.
        tags (list[str], optional): List of tags to filter by. Arrives as a
            JSON-stringified array and is parsed below.
        languages (list[str], optional): List of languages to filter by.
            Arrives as a JSON-stringified array and is parsed below.
        start_date (str, optional): Start date for date-of-upload range.
        end_date (str, optional): End date for date-of-upload range.

    Returns:
        list[dict]: One flat dict per hit, with OpenSearch metadata and the
        document's ``_source`` fields merged together.

    Raises:
        HTTPException: 400 if ``tags``/``languages`` are not valid JSON,
        500 if the search backend call fails.
    """
    # Parse stringified arrays back into lists.
    # TypeError is included so a non-string value is also reported as a 400
    # instead of escaping as an unhandled 500.
    if tags is not None:
        try:
            tags = json.loads(tags)
        except (TypeError, json.JSONDecodeError) as e:
            raise HTTPException(
                status_code=400, detail="Invalid format for parameter 'tags'"
            ) from e

    if languages is not None:
        try:
            languages = json.loads(languages)
        except (TypeError, json.JSONDecodeError) as e:
            raise HTTPException(
                status_code=400, detail="Invalid format for parameter 'languages'"
            ) from e
    logger.info(
        "Received parameters for search. tags: %s | languages: %s", tags, languages
    )

    try:
        search_engine_results = SEARCH_COMPONENT.get_search_engine_results(
            search_query=query,
            tags=tags,
            languages=languages,
            start_date=start_date,
            end_date=end_date,
        )
        # Results are grouped by source document; collect the top hits of
        # every bucket into one flat list.
        search_engine_results = search_engine_results["aggregations"][
            "group_by_source"
        ]["buckets"]
        collected_entries = []

        for entry in search_engine_results:
            hits = entry["top_entries"]["hits"]["hits"]
            collected_entries.extend(hits)

        # Flatten each hit: merge its "_source" payload into the hit's
        # metadata (score, id, ...) so the frontend gets one flat dict.
        merged_entries = []
        for entry in collected_entries:
            source_dict = entry.pop("_source")

            result_dict = entry.copy()
            result_dict.update(source_dict)

            merged_entries.append(result_dict)
        logger.info("Number of entries found: %s", len(merged_entries))

        return merged_entries

    except Exception as e:
        logger.error("Calling search engine failed with error: '%s'", e)
        # HTTPException's first positional argument is status_code — passing
        # the message there (as before) produced a malformed error response.
        raise HTTPException(
            status_code=500, detail="Calling search engine failed"
        ) from e
|
||||
|
||||
|
||||
@router.get("/get-query-possibilities")
def get_query_possibilities():
    """Returns all possible query attributes.

    Collects the unique values of the "tag" and "language" fields plus the
    available date-of-upload range from OpenSearch.

    Returns:
        dict: Dict of attributes and their unique values
        (keys: "tags", "languages", "daterage").

    Raises:
        HTTPException: 500 if the OpenSearch call fails.
    """
    try:
        tag_result = OS_INTERFACE.get_unique_values(field_name="tag")
        language_result = OS_INTERFACE.get_unique_values(field_name="language")
        date_result = OS_INTERFACE.get_date_range()

        results = {
            "tags": tag_result,
            "languages": language_result,
            # TODO: likely a typo for 'date_range' — kept as-is because the
            # frontend consumes this key; rename in both places together.
            "daterage": date_result,
        }

        return results

    except Exception as e:
        logger.error("Calling OpenSearch failed with error: '%s'", e)
        # HTTPException's first positional argument is status_code — passing
        # the message there (as before) produced a malformed error response.
        raise HTTPException(
            status_code=500, detail="Calling OpenSearch failed"
        ) from e
|
||||
Loading…
Add table
Add a link
Reference in a new issue