130 lines
4 KiB
Python
130 lines
4 KiB
Python
"""This module incorporates all API endpoints for the search part of the application."""
|
|
|
|
import os
|
|
import json
|
|
|
|
from fastapi import APIRouter, HTTPException
|
|
|
|
from connector.database_interface.opensearch_client import OpenSearchInterface
|
|
from neural_search.search_component import IndexSearchComponent
|
|
|
|
from common_packages import logging
|
|
|
|
# Module-level logger named after this module. Verbosity is read from the
# LOGGING_LEVEL environment variable and falls back to "INFO" when unset.
logger = logging.create_logger(
    logger_name=__name__,
    log_level=os.getenv("LOGGING_LEVEL", "INFO"),
)

# Router that collects the search endpoints declared in this module.
router = APIRouter()
|
|
|
|
# Shared OpenSearch client used by every endpoint in this module. The index,
# embedder and language are fixed to the German configuration here.
# NOTE(review): embedding_size presumably has to match the output dimension of
# the configured embedder - confirm before changing either value.
OS_INTERFACE = OpenSearchInterface(
    language="german",
    index_name="german",
    embedding_size=768,
    embedder_name="PM-AI/bi-encoder_msmarco_bert-base_german",
)

# Search component that runs its queries through the client above.
SEARCH_COMPONENT = IndexSearchComponent(os_client=OS_INTERFACE)
|
|
|
|
|
|
def _parse_json_param(raw_value, param_name):
    """Decode a filter parameter that may arrive as a JSON-encoded string.

    Args:
        raw_value: The raw parameter value. A ``str`` is decoded with
            ``json.loads``; any other value (including ``None``) is returned
            unchanged, so an already-decoded list is passed through.
        param_name (str): Name of the parameter, used in the error detail.

    Returns:
        The decoded JSON value, or ``raw_value`` itself if it is not a string.

    Raises:
        HTTPException: With status 400 if the string is not valid JSON.
    """
    if not isinstance(raw_value, str):
        return raw_value
    try:
        return json.loads(raw_value)
    except json.JSONDecodeError as e:
        raise HTTPException(
            status_code=400, detail=f"Invalid format for parameter '{param_name}'"
        ) from e


@router.post("/search-engine")
def search_engine(
    query: str,
    tags: list = None,
    languages: list = None,
    start_date: str = None,
    end_date: str = None,
):
    """Take a query and return all relevant documents for it.

    Args:
        query (str): The search query string.
        tags (list[str], optional): Tags to filter by. May be passed as a
            JSON-encoded string, which is decoded before use.
        languages (list[str], optional): Languages to filter by. May be passed
            as a JSON-encoded string, which is decoded before use.
        start_date (str, optional): Start date for date-of-upload range.
        end_date (str, optional): End date for date-of-upload range.

    Returns:
        list[dict]: One dictionary per hit, collected across all
        "group_by_source" buckets. Each dictionary is the hit with the fields
        of its "_source" merged into the top level; on a key clash the
        "_source" value wins.

    Raises:
        HTTPException: With status 400 if ``tags`` or ``languages`` is a string
            that is not valid JSON, or with status 500 if querying the search
            component or processing its response fails.
    """
    # Parse stringified arrays back into lists
    tags = _parse_json_param(tags, "tags")
    languages = _parse_json_param(languages, "languages")
    logger.info(
        "Received parameters for search. tags: %s | languages: %s", tags, languages
    )

    try:
        response = SEARCH_COMPONENT.get_search_engine_results(
            search_query=query,
            tags=tags,
            languages=languages,
            start_date=start_date,
            end_date=end_date,
        )
        buckets = response["aggregations"]["group_by_source"]["buckets"]

        merged_entries = []
        for bucket in buckets:
            for hit in bucket["top_entries"]["hits"]["hits"]:
                # Lift the "_source" fields to the top level of the hit.
                source_fields = hit.pop("_source")
                merged_entries.append({**hit, **source_fields})
        logger.info("Number of entries found: %s", len(merged_entries))

        return merged_entries

    except Exception as e:
        logger.error("Calling search engine failed with error: '%s'", e)
        # HTTPException takes the status code first; the message goes in `detail`.
        raise HTTPException(
            status_code=500, detail="Calling search engine failed"
        ) from e
|
|
|
|
|
|
@router.get("/get-query-possibilities")
def get_query_possibilities():
    """Return all possible query attributes.

    Returns:
        dict: Maps "tags" and "languages" to the unique values of the "tag"
        and "language" fields, and "daterage" to the result of
        ``OS_INTERFACE.get_date_range()``.

    Raises:
        HTTPException: With status 500 if any of the OpenSearch calls fails.
    """
    try:
        return {
            "tags": OS_INTERFACE.get_unique_values(field_name="tag"),
            "languages": OS_INTERFACE.get_unique_values(field_name="language"),
            # TODO: is this supposed to mean 'date_range'? The key is kept
            # unchanged because clients may already read it under this name.
            "daterage": OS_INTERFACE.get_date_range(),
        }

    except Exception as e:
        logger.error("Calling OpenSearch failed with error: '%s'", e)
        # HTTPException takes the status code first; the message goes in `detail`.
        raise HTTPException(status_code=500, detail="Calling OpenSearch failed") from e
|