"""This module incorporates all API endpoints for the search part of the application."""

import json
import os

from fastapi import APIRouter, HTTPException

from connector.database_interface.opensearch_client import OpenSearchInterface
from neural_search.search_component import IndexSearchComponent
from common_packages import logging

logger = logging.create_logger(
    log_level=os.getenv("LOGGING_LEVEL", "INFO"),
    logger_name=__name__,
)

router = APIRouter()

# Module-level OpenSearch client and search component. Both are created once,
# at import time, and shared by every request handled by this router.
# NOTE(review): index, embedder and language are hard-coded to German here.
OS_INTERFACE = OpenSearchInterface(
    index_name="german",
    embedder_name="PM-AI/bi-encoder_msmarco_bert-base_german",
    embedding_size=768,
    language="german",
)
SEARCH_COMPONENT = IndexSearchComponent(os_client=OS_INTERFACE)


def _parse_list_param(value, param_name):
    """Normalize an optional list-valued request parameter.

    Clients may send the value either as an already-decoded list or as a
    JSON-encoded ("stringified") array; both are accepted.

    Args:
        value: ``None``, a list, or a JSON string that encodes a list.
        param_name (str): Name of the parameter, used in the error message.

    Returns:
        list | None: The decoded list, or ``None`` if ``value`` is ``None``.

    Raises:
        HTTPException: 400 if ``value`` cannot be decoded as JSON or does not
            decode to a list.
    """
    if value is None or isinstance(value, list):
        # Already in the desired shape; json.loads would raise TypeError here.
        return value

    detail = f"Invalid format for parameter '{param_name}'"
    try:
        parsed = json.loads(value)
    except (json.JSONDecodeError, TypeError) as e:
        raise HTTPException(status_code=400, detail=detail) from e

    # A valid JSON scalar/object (e.g. "5" or '{"a": 1}') is still a bad input.
    if not isinstance(parsed, list):
        raise HTTPException(status_code=400, detail=detail)
    return parsed


def _merge_hits(search_response):
    """Flatten the grouped search response into one dict per hit.

    Walks ``aggregations -> group_by_source -> buckets`` and, for every hit in
    each bucket's ``top_entries``, merges the hit's ``_source`` fields into
    the hit's other top-level keys. ``_source`` values win on key clashes,
    and the ``_source`` key itself is omitted. The input is not mutated.

    Args:
        search_response (dict): Raw result of
            ``SEARCH_COMPONENT.get_search_engine_results``.

    Returns:
        list[dict]: The merged hit dictionaries, in bucket order.

    Raises:
        KeyError: If the response does not have the expected structure.
    """
    buckets = search_response["aggregations"]["group_by_source"]["buckets"]
    merged_entries = []
    for bucket in buckets:
        for hit in bucket["top_entries"]["hits"]["hits"]:
            metadata = {key: val for key, val in hit.items() if key != "_source"}
            merged_entries.append({**metadata, **hit["_source"]})
    return merged_entries


@router.post("/search-engine")
def search_engine(
    query: str,
    tags: list = None,
    languages: list = None,
    start_date: str = None,
    end_date: str = None,
):
    """Takes a query and returns all relevant documents for it.

    Args:
        query (str): The search query string.
        tags (list[str], optional): List of tags to filter by. May also be
            sent as a JSON-encoded array string.
        languages (list[str], optional): List of languages to filter by. May
            also be sent as a JSON-encoded array string.
        start_date (str, optional): Start date for date-of-upload range.
        end_date (str, optional): End date for date-of-upload range.

    Returns:
        list[dict]: One dict per matching document: the hit's metadata
        merged with its ``_source`` fields.

    Raises:
        HTTPException: 400 if ``tags`` or ``languages`` is malformed;
            500 if the search call or result processing fails.
    """
    # Parse stringified arrays back into lists.
    tags = _parse_list_param(tags, "tags")
    languages = _parse_list_param(languages, "languages")

    logger.info(
        "Received parameters for search. tags: %s | languages: %s", tags, languages
    )

    try:
        search_engine_results = SEARCH_COMPONENT.get_search_engine_results(
            search_query=query,
            tags=tags,
            languages=languages,
            start_date=start_date,
            end_date=end_date,
        )
        merged_entries = _merge_hits(search_engine_results)
    except Exception as e:
        # Top-level request boundary: log and convert to a proper 500 response.
        logger.error("Calling search engine failed with error: '%s'", e)
        raise HTTPException(
            status_code=500, detail="Calling search engine failed"
        ) from e

    logger.info("Number of entries found: %s", len(merged_entries))
    return merged_entries


@router.get("/get-query-possibilities")
def get_query_possibilities():
    """Returns all possible query attributes.

    Returns:
        dict: ``tags`` and ``languages`` hold the unique values returned by
        ``OS_INTERFACE.get_unique_values`` for the ``tag`` and ``language``
        fields; ``daterage`` holds the result of
        ``OS_INTERFACE.get_date_range()``.

    Raises:
        HTTPException: 500 if any OpenSearch call fails.
    """
    try:
        tag_result = OS_INTERFACE.get_unique_values(field_name="tag")
        language_result = OS_INTERFACE.get_unique_values(field_name="language")
        date_result = OS_INTERFACE.get_date_range()
    except Exception as e:
        # Top-level request boundary: log and convert to a proper 500 response.
        logger.error("Calling OpenSearch failed with error: '%s'", e)
        raise HTTPException(status_code=500, detail="Calling OpenSearch failed") from e

    return {
        "tags": tag_result,
        "languages": language_result,
        # NOTE(review): "daterage" looks like a typo of "date_range", but it is
        # part of the response contract; renaming it would break clients.
        "daterage": date_result,
    }