I've been following some tutorials on training GPT-2, and I've scraped together some code that works in Google Colab, but when I move it over to a Vertex AI Workbench notebook, it just seems to sit there and do nothing when I run the training code. I have the GPU quotas set up, I have a billing account, and I've enabled all the relevant APIs. Here is the code I'm using for the tokenizer:
def tokenize_function(examples):
    return base_tokenizer(examples['Prompt_Final'], padding=True)
    
# Split into train and validation sets
df_train, df_val = train_test_split(df, train_size=0.8)
# load the dataset from the dataframes
train_dataset = Dataset.from_pandas(df_train[['Prompt_Final']])
val_dataset = Dataset.from_pandas(df_val[['Prompt_Final']])
# tokenize the training and validation datasets
tokenized_train_dataset = train_dataset.map(
    tokenize_function,
    batched=True,
    num_proc=1
)
tokenized_val_dataset = val_dataset.map(
    tokenize_function,
    batched=True,
    num_proc=1
)
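For reference, the earlier cells with the imports, tokenizer, and data loading look roughly like this (the CSV path here is just a placeholder, and I may have loaded the tokenizer with AutoTokenizer instead):
import pandas as pd
import torch
from datasets import Dataset
from sklearn.model_selection import train_test_split
from transformers import (
    AutoConfig,
    DataCollatorForLanguageModeling,
    GPT2LMHeadModel,
    GPT2Tokenizer,
    Trainer,
    TrainingArguments,
)

# base GPT-2 tokenizer (the special tokens get added further down)
base_tokenizer = GPT2Tokenizer.from_pretrained('gpt2')

# roughly 1000 rows with a 'Prompt_Final' text column
df = pd.read_csv('data/prompts.csv')  # placeholder path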
And here is the code I'm using for the model:
bos = '<|startoftext|>'
eos = '<|endoftext|>'
body = '<|body|>'
special_tokens_dict = {'eos_token': eos, 'bos_token': bos, 'pad_token': '<pad>',
                       'sep_token': body} 
# the new tokens are added to the tokenizer
num_added_toks = base_tokenizer.add_special_tokens(special_tokens_dict)
# model configuration
config = AutoConfig.from_pretrained('gpt2', 
                                    bos_token_id=base_tokenizer.bos_token_id,
                                    eos_token_id=base_tokenizer.eos_token_id,
                                    pad_token_id=base_tokenizer.pad_token_id,
                                    sep_token_id=base_tokenizer.sep_token_id,
                                    output_hidden_states=False)
# we load the pre-trained model with custom settings
base_model = GPT2LMHeadModel.from_pretrained('gpt2', config=config)
# resize the model embeddings to account for the added special tokens
base_model.resize_token_embeddings(len(base_tokenizer))
# make sure it's using the GPU
base_model = base_model.to(device)
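The device variable comes from a quick GPU check earlier in the notebook, something along these lines:
import torch

# pick the GPU if PyTorch can see one, otherwise fall back to the CPU
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(device)                          # prints 'cuda' on the Vertex AI instance
print(torch.cuda.get_device_name(0))   # Tesla P100-PCIE-16GB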
And here is the code I'm using for the model path, the training arguments, the data collator, and the Trainer:
model_articles_path = r'Model/Model_Path'
training_args = TrainingArguments(
    output_dir=model_articles_path,  # output directory
    num_train_epochs=1,              # total number of training epochs
    per_device_train_batch_size=10,  # batch size per device during training
    per_device_eval_batch_size=10,   # batch size per device for evaluation
    warmup_steps=200,                # number of warmup steps for learning rate scheduler
    weight_decay=0.01,               # strength of weight decay
    logging_dir=model_articles_path, # directory for storing logs
    prediction_loss_only=True,
    evaluation_strategy="steps",
    save_steps=10,
    gradient_accumulation_steps=1,
    # gradient_checkpointing=True,
    eval_accumulation_steps=1,
    fp16=True
)
data_collator = DataCollatorForLanguageModeling(
    tokenizer=base_tokenizer,
    mlm=False  # causal language modeling, not masked LM
)
trainer = Trainer(
    model=base_model,                      # the instantiated Transformers model to be trained
    args=training_args,                    # training arguments, defined above
    data_collator=data_collator,
    train_dataset=tokenized_train_dataset, # training dataset
    eval_dataset=tokenized_val_dataset     # validation dataset
)
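For completeness, the training cell itself is nothing more than the call below (the timer is just something I wrapped around it while trying to see whether it ever returns):
import time

start = time.time()
trainer.train()
print(f"Training took {time.time() - start:.1f} seconds")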
When I run trainer.train(), though, it starts with this warning, which I'm very used to seeing:
/opt/conda/lib/python3.7/site-packages/transformers/optimization.py:395: FutureWarning: This implementation of AdamW is deprecated and will be removed in a future version. Use the PyTorch implementation torch.optim.AdamW instead, or set `no_deprecation_warning=True` to disable this warning
  FutureWarning,
And then the code just sits there. I can tell it's running, and when I check nvidia-smi in the terminal I can see it's using the GPU, but it never makes any visible progress. I'm using a Tesla P100-PCIE-16GB GPU and the GPT-2 small model, so it should make quick work of only 1000 rows of data. I'm hopeful that I've just made a dumb mistake somewhere, but if anyone has experience in this department, any help would be greatly appreciated.
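In case it helps with debugging, here is a minimal single-step sanity check I can run against the same objects defined above (my own sketch, not from the tutorial), to see whether a single forward/backward pass completes on the GPU:
import time

# take a few tokenized examples, keeping only the columns the collator expects
features = [
    {k: tokenized_train_dataset[i][k] for k in ('input_ids', 'attention_mask')}
    for i in range(4)
]
batch = data_collator(features)                      # pads and adds labels for causal LM
batch = {k: v.to(device) for k, v in batch.items()}  # move the batch to the GPU

base_model.train()
start = time.time()
outputs = base_model(**batch)   # forward pass; loss comes from the labels the collator added
outputs.loss.backward()         # backward pass
print(f"one step took {time.time() - start:.2f}s, loss = {outputs.loss.item():.3f}")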