arxiv_audio_summary/vibe/rerank.py

45 lines
1.6 KiB
Python
Raw Normal View History

2025-03-02 03:22:35 +00:00
import json
import re
import logging
2025-03-02 12:47:03 +00:00
from .llm import chat_llm
2025-03-02 03:22:35 +00:00
logger = logging.getLogger(__name__)
2025-03-02 12:47:03 +00:00
def rerank_articles(articles, user_info, llm_level="medium"):
    """
    Ask the LLM to reorder ``articles`` by relevance and return the new order.

    The prompt contains ``user_info`` followed by the ID, title and abstract of
    every article, and asks for a JSON object of the form
    ``{"ranking": ["id1", "id2", ...]}``.

    Args:
        articles: Sequence of mappings, each providing the keys ``'id'``,
            ``'title'`` and ``'abstract'``.
        user_info: Description of the user; interpolated into the prompt as-is.
        llm_level: Forwarded to ``chat_llm`` as its ``level`` argument.

    Returns:
        ``[]`` when ``articles`` is empty. Otherwise a new list with the ranked
        articles first (in ranking order, each at most once) followed by every
        article the ranking did not mention, in their original order. If the
        LLM call fails or no usable ranking can be extracted from its reply,
        ``articles`` itself is returned unchanged.
    """
    if not articles:
        return []

    logger.info("Starting rerank for %d articles.", len(articles))

    prompt_lines = [
        f"User info: {user_info}\n",
        ('Please rank the following articles from most relevant to least relevant. '
         'Return your answer as valid JSON in the format: { "ranking": [ "id1", "id2", ... ] }.'),
    ]
    prompt_lines.extend(
        f"Article ID: {article['id']}\nTitle: {article['title']}\nAbstract: {article['abstract']}\n"
        for article in articles
    )
    prompt = "\n".join(prompt_lines)

    # Reranking is best-effort: the LLM call and its free-form reply can fail
    # in many ways, and any failure must fall back to the original order
    # instead of propagating to the caller.
    try:
        response_text = chat_llm(prompt, level=llm_level)
        ranking = _extract_ranking(response_text)
        if ranking is None:
            logger.error("No valid JSON found in rerank response.")
            return articles
        return _apply_ranking(articles, ranking)
    except Exception as e:
        logger.exception("Error during rerank: %s", e)
        return articles


def _extract_ranking(response_text):
    """Return the first ``'ranking'`` list embedded in ``response_text``, or None."""
    decoder = json.JSONDecoder()
    # Try every '{' as a possible start of the JSON object. raw_decode stops at
    # the end of the first complete document, so chatter (even chatter that
    # contains braces) before or after the object does not break parsing.
    for brace in re.finditer(r"\{", response_text):
        try:
            candidate, _ = decoder.raw_decode(response_text[brace.start():])
        except json.JSONDecodeError:
            continue
        if isinstance(candidate, dict) and isinstance(candidate.get("ranking"), list):
            return candidate["ranking"]
    return None


def _apply_ranking(articles, ranking):
    """Order ``articles`` by ``ranking``; unranked articles keep their order at the end."""
    by_id = {article["id"]: article for article in articles}
    # The prompt asks for quoted IDs, so the reply may carry "12" for an
    # article whose id is the int 12. Fall back to comparing text forms.
    by_text = {str(key): key for key in by_id}

    placed = set()
    ordered = []
    for entry in ranking:
        if isinstance(entry, (list, dict)):
            # Unhashable JSON values can never be an article id; skip them
            # rather than discarding the whole ranking.
            continue
        if entry in by_id:
            key = entry
        elif str(entry) in by_text:
            key = by_text[str(entry)]
        else:
            continue  # unknown / hallucinated id
        if key in placed:
            continue  # the LLM repeated an id; keep only its first position
        placed.add(key)
        ordered.append(by_id[key])

    ordered.extend(article for article in articles if article["id"] not in placed)
    return ordered