Transformers 保存並加載模型

本節說明如何保存和重新加載微調模型(BERT,GPT,GPT-2和Transformer-XL)。你需要保存三種文件類型才能重新加載經過微調的模型:

  • 模型本身應該是PyTorch序列化保存的模型(https://pytorch.org/docs/stable/notes/serialization.html#best-practices)
  • 模型的配置文件是保存為JSON文件
  • 詞彙表(以及基於GPT和GPT-2合併的BPE的模型)。

這些文件的默認文件名如下:

  • 模型權重文件:pytorch_model.bin
  • 配置文件:config.json
  • 詞彙文件:vocab.txt代表BERT和Transformer-XL,vocab.json代表GPT/GPT-2(BPE詞彙),
  • 代表GPT/GPT-2(BPE詞彙)額外的合併文件:merges.txt。

如果使用這些默認文件名保存模型,則可以使用from_pretrained()方法重新加載模型和tokenizer。

這是保存模型,配置和配置文件的推薦方法。詞彙到output_dir目錄,然後重新加載模型和tokenizer:

<code>from transformers import WEIGHTS_NAME, CONFIG_NAME

output_dir = "./models/"

# 步驟1:保存一個經過微調的模型、配置和詞彙表


#如果我們有一個分佈式模型,只保存封裝的模型
#它包裝在PyTorch DistributedDataParallel或DataParallel中
model_to_save = model.module if hasattr(model, 'module') else model
#如果使用預定義的名稱保存,則可以使用`from_pretrained`加載
output_model_file = os.path.join(output_dir, WEIGHTS_NAME)
output_config_file = os.path.join(output_dir, CONFIG_NAME)

torch.save(model_to_save.state_dict(), output_model_file)
model_to_save.config.to_json_file(output_config_file)
tokenizer.save_vocabulary(output_dir)

# 步驟2: 重新加載保存的模型

#Bert模型示例
model = BertForQuestionAnswering.from_pretrained(output_dir)
tokenizer = BertTokenizer.from_pretrained(output_dir, do_lower_case=args.do_lower_case) # Add specific options if needed
#GPT模型示例
model = OpenAIGPTDoubleHeadsModel.from_pretrained(output_dir)
tokenizer = OpenAIGPTTokenizer.from_pretrained(output_dir)/<code>

如果要為每種類型的文件使用特定路徑,則可以使用另一種方法保存和重新加載模型:

<code>output_model_file = "./models/my_own_model_file.bin"
output_config_file = "./models/my_own_config_file.bin"
output_vocab_file = "./models/my_own_vocab_file.bin"

# 步驟1:保存一個經過微調的模型、配置和詞彙表

#如果我們有一個分佈式模型,只保存封裝的模型
#它包裝在PyTorch DistributedDataParallel或DataParallel中
model_to_save = model.module if hasattr(model, 'module') else model

torch.save(model_to_save.state_dict(), output_model_file)
model_to_save.config.to_json_file(output_config_file)

tokenizer.save_vocabulary(output_vocab_file)

# 步驟2: 重新加載保存的模型

# 我們沒有使用預定義權重名稱、配置名稱進行保存,無法使用`from_pretrained`進行加載。
# 下面是在這種情況下的操作方法:

#Bert模型示例
config = BertConfig.from_json_file(output_config_file)
model = BertForQuestionAnswering(config)
state_dict = torch.load(output_model_file)
model.load_state_dict(state_dict)
tokenizer = BertTokenizer(output_vocab_file, do_lower_case=args.do_lower_case)

#GPT模型示例
config = OpenAIGPTConfig.from_json_file(output_config_file)
model = OpenAIGPTDoubleHeadsModel(config)
state_dict = torch.load(output_model_file)
model.load_state_dict(state_dict)
tokenizer = OpenAIGPTTokenizer(output_vocab_file)/<code>


分享到:


相關文章: