import torch
from transformers import AutoTokenizer, AutoModelForSequenceClassification

#一定の閾値score_thresholdを超えた文書のみをrerankして再表示
def rerank(query,documents,score_threshold=1.0):
    model_name = "hotchpotch/japanese-reranker-cross-encoder-large-v1"
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    model = AutoModelForSequenceClassification.from_pretrained(model_name)

    inputs = tokenizer(
        [f"Query: {query} Document: {doc}" for doc in documents],
        return_tensors="pt",
        padding=True,
        truncation=True,
        max_length=512
    )

    # モデルでスコアを計算
    with torch.no_grad():
        outputs = model(**inputs)
        scores = outputs.logits.squeeze().tolist()

    # スコアが閾値を超えるドキュメントをフィルタリング
    filtered_docs = [(doc, score) for doc, score in zip(documents, scores) if score > score_threshold]
    
    ranked_docs = sorted(filtered_docs, key=lambda x: x[1], reverse=False)
    
    # リランクされた結果を表示
    print("Ranked Documents:")
    for doc, score in ranked_docs:
        print(f"Score: {score:.4f}, Document: {doc}")
    
    return ranked_docs
     

# question = "年収の高いものはどれ？"
# documents = [
#     "月給2000万円を稼ぐ仕事です。待遇が良い職場です。",
#     "スープの味が美味しいと評判のレストランの話です。",
#     "年収100万円を得られる高収入のポジションです。",
#     "時給5000円で働けるアルバイトの求人です。"
# ]

# rag_rerank(question,documents,score_threshold=-10.0)