We have to go more low-level, as the pipeline function is not appropriate for what you are trying to do.
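After you pass your sequence to AutoModelForCausalLM, the first element of the output holds, for each position, an unnormalized score (logit) for every token in the vocabulary. The scores at the last position describe the next token, and applying softmax to them yields the probability of each token being the next one. In the code below, I call this last-position tensor next_token_candidates_tensor. After that, you simply need to select the indices of the top-k candidates and decode them back to words.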
import torch
from transformers import AutoModelForCausalLM , AutoTokenizer
class LMHeadModel:
    """Pair a causal language model with its tokenizer to rank next-token candidates."""

    def __init__(self, model_name):
        """Load the model and the tokenizer registered under ``model_name``."""
        self.model = AutoModelForCausalLM.from_pretrained(model_name)
        self.tokenizer = AutoTokenizer.from_pretrained(model_name)

    def get_predictions(self, sentence):
        """Encode ``sentence`` and return the first element of the model output.

        The forward pass runs inside ``torch.no_grad()`` so no gradients are
        tracked. The caller below indexes the result as ``[0, -1, :]``, i.e.
        it is treated as a rank-3 tensor.
        """
        inputs = self.tokenizer.encode(sentence, return_tensors="pt")
        with torch.no_grad():
            outputs = self.model(inputs)
        return outputs[0]

    def get_next_word_probabilities(self, sentence, top_k=500):
        """Return the ``top_k`` most likely next tokens after ``sentence``.

        Args:
            sentence: Text whose continuation is predicted.
            top_k: Maximum number of candidates to return. Values larger than
                the number of available candidates are capped instead of
                raising.

        Returns:
            A list of ``(token_text, probability)`` tuples ordered from most
            to least likely. Token text is stripped of surrounding
            whitespace; probabilities come from a softmax over all
            candidates, so they sum to 1 only when every candidate is
            returned.
        """
        predictions = self.get_predictions(sentence)

        # Scores for the last position of the first sequence.
        next_token_candidates_tensor = predictions[0, -1, :]

        # torch.topk raises when k exceeds the size of the dimension, so cap
        # the request at the number of candidates actually available.
        top_k = min(top_k, next_token_candidates_tensor.shape[-1])
        topk_candidates_indexes = torch.topk(
            next_token_candidates_tensor, top_k).indices

        # Normalize over the whole vocabulary first, then pick out the top-k
        # entries, so each value is a probability over all candidates.
        all_candidates_probabilities = torch.nn.functional.softmax(
            next_token_candidates_tensor, dim=-1)
        topk_candidates_probabilities = \
            all_candidates_probabilities[topk_candidates_indexes].tolist()

        # Decode each candidate id on its own to get one string per token.
        topk_candidates_tokens = [
            self.tokenizer.decode([idx]).strip()
            for idx in topk_candidates_indexes.tolist()
        ]

        return list(zip(topk_candidates_tokens, topk_candidates_probabilities))
if __name__ == "__main__":
    # Guarded so that importing this module does not load the model.
    sentence = "I enjoy walking in the"
    model = LMHeadModel("gpt2")
    # Print explicitly: a bare expression is only echoed in a REPL/notebook,
    # so a plain script run would otherwise show nothing.
    print(model.get_next_word_probabilities(sentence, top_k=500))
# [('park', 0.15904344618320465),
# ('woods', 0.10028065741062164),
# ('streets', 0.0418376550078392),
# ('dark', 0.03117542900145054),
# ('door', 0.029618268832564354),
# ('street', 0.02388935722410679),
# ('rain', 0.021733922883868217),
# ...