Introduction to Transformers-based Natural Language Processing (NLP), Part 5

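This post was written while taking part in a Datawhale study group. For the full details, see Datawhale's official open-source course, Introduction to Natural Language Processing (NLP) with Transformers (datawhalechina.github.io).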

Extractive question answering

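Extractive question answering: given a question and a passage of text, find the span of the passage that answers the question. With the Trainer API and the datasets library, we can easily load the data and fine-tune a transformers model.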

# squad_v2 = False uses SQuAD v1; squad_v2 = True uses SQuAD v2.
# With another dataset, True means the model may answer "unanswerable" (some questions have no answer), and False means every question must be answered.
squad_v2 = False
model_checkpoint = "distilbert-base-uncased"
batch_size = 16

Loading the dataset

from datasets import load_dataset, load_metric  # load_metric is deprecated in recent datasets releases (see the evaluate library)


# Download the data
datasets = load_dataset("squad_v2" if squad_v2 else "squad")

The structure of datasets:

DatasetDict({
    train: Dataset({
        features: ['id', 'title', 'context', 'question', 'answers'],
        num_rows: 87599
    })
    validation: Dataset({
        features: ['id', 'title', 'context', 'question', 'answers'],
        num_rows: 10570
    })
})

Every question-answering sample, whether in the training, validation, or test set, has the three keys "context", "question", and "answers".

datasets["train"][0]
# answers: the answer
# context: the passage the answer is extracted from
# question: the question
{'answers': {'answer_start': [515], 'text': ['Saint Bernadette Soubirous']},
'context': 'Architecturally, the school has a Catholic character. Atop the Main Building\'s gold dome is a golden statue of the Virgin Mary. Immediately in front of the Main Building and facing it, is a copper statue of Christ with arms upraised with the legend "Venite Ad Me Omnes". Next to the Main Building is the Basilica of the Sacred Heart. Immediately behind the basilica is the Grotto, a Marian place of prayer and reflection. It is a replica of the grotto at Lourdes, France where the Virgin Mary reputedly appeared to Saint Bernadette Soubirous in 1858. At the end of the main drive (and in a direct line that connects through 3 statues and the Gold Dome), is a simple, modern stone statue of Mary.',
'id': '5733be284776f41900661182',
'question': 'To whom did the Virgin Mary allegedly appear in 1858 in Lourdes France?',
'title': 'University_of_Notre_Dame'}

Besides the answer text, answers also gives the answer's position in the context as a character offset (515 in the example above).
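The character offset is all we need to recover the answer's end position, because end_char = answer_start + len(text). The sketch below checks this on a shortened version of the example context. The helper name answer_char_span is hypothetical, and the offset is computed with str.index here instead of being read from the dataset.

```python
def answer_char_span(context: str, answer_text: str, answer_start: int) -> tuple:
    """Return the (start, end) character span of an answer, with end exclusive."""
    end = answer_start + len(answer_text)
    # The slice must reproduce the annotated answer text exactly.
    if context[answer_start:end] != answer_text:
        raise ValueError("answer_start does not point at answer_text")
    return answer_start, end


context = ("It is a replica of the grotto at Lourdes, France where the Virgin Mary "
           "reputedly appeared to Saint Bernadette Soubirous in 1858.")
answer = "Saint Bernadette Soubirous"
start = context.index(answer)  # stand-in for answers["answer_start"][0]
span = answer_char_span(context, answer, start)
print(span, context[span[0]:span[1]])
```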

Randomly sample a few examples and display them.
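The code for this step is not shown in the post. A minimal, hypothetical helper could look like the sketch below. It only assumes that the dataset supports len() and integer indexing returning a dict, which a datasets.Dataset does, so it could be called as show_random_elements(datasets["train"]). It prints plain text instead of rendering a table.

```python
import random


def show_random_elements(dataset, num_examples=3, seed=None):
    """Print num_examples distinct, randomly chosen rows of dataset and return their indices."""
    if num_examples > len(dataset):
        raise ValueError("Can't pick more elements than there are in the dataset.")
    rng = random.Random(seed)
    # random.sample draws without replacement, so no row is shown twice.
    picks = rng.sample(range(len(dataset)), num_examples)
    for idx in picks:
        print(f"--- row {idx} ---")
        for key, value in dataset[idx].items():
            print(f"{key}: {value}")
    return picks


toy = [{"question": f"q{i}", "context": f"c{i}"} for i in range(10)]
show_random_elements(toy, num_examples=2, seed=0)
```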

Data preprocessing

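from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained(model_checkpoint)

import transformers
# The preprocessing below relies on offset mappings, which only fast (Rust-backed) tokenizers provide
assert isinstance(tokenizer, transformers.PreTrainedTokenizerFast)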

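Next we need to consider how pretrained question-answering models handle very long texts. Pretrained models usually have a maximum input length, so overlong inputs are normally truncated. But if we truncate an overlong context in a <question, context, answer> triple, we may cut off the answer, because the answer is a short span extracted from the context.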

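To solve this, the code below finds an example that exceeds the maximum length and shows how to handle it. We split the overlong input into several shorter inputs (features), each of which satisfies the model's maximum input length. Because the answer may fall exactly where the context is split, adjacent chunks are allowed to overlap. The overlap is controlled by the doc_stride parameter.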
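To make the overlap concrete, here is a pure-Python sketch that splits a token sequence into windows in which consecutive windows share `stride` items. It only illustrates the idea. The tokenizer does this internally and also re-adds the question and special tokens to every chunk, so sliding_windows is a hypothetical helper, not a transformers API.

```python
def sliding_windows(tokens, window, stride):
    """Split tokens into chunks of at most `window` items; consecutive chunks share `stride` items."""
    if not 0 <= stride < window:
        raise ValueError("need 0 <= stride < window")
    step = window - stride  # how far each new chunk advances
    chunks = []
    start = 0
    while True:
        chunks.append(tokens[start:start + window])
        if start + window >= len(tokens):  # this chunk reaches the end of the sequence
            break
        start += step
    return chunks


print(sliding_windows(list(range(10)), window=4, stride=2))
# [[0, 1, 2, 3], [2, 3, 4, 5], [4, 5, 6, 7], [6, 7, 8, 9]]
```

Every token lands in at least one chunk, and neighbouring chunks share two tokens here.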

Pretrained question-answering models usually take the question and the context concatenated as input, and then look for the answer inside the context.

max_length = 384  # maximum length of a feature (question and context concatenated)
doc_stride = 128  # number of overlapping tokens between two consecutive chunks

Loop over the dataset to find an overlong example. The model in this notebook accepts inputs of at most 384 tokens (512 is another common limit).

for i, example in enumerate(datasets["train"]):
    if len(tokenizer(example["question"], example["context"])["input_ids"]) > 384:
        break
example = datasets["train"][i]

Without truncation, the input is 396 tokens long:

len(tokenizer(example["question"], example["context"])["input_ids"])
396

If we simply truncated it to the maximum length of 384, everything beyond that length would be lost.

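Note that we generally split only the context, never the question. Because the context is concatenated after the question, it is the second text, so truncation="only_second" truncates only the context. The tokenizer's stride argument (set to doc_stride here) controls how many tokens consecutive chunks share.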

tokenized_example = tokenizer(
    example["question"],
    example["context"],
    max_length=max_length,
    truncation="only_second",
    return_overflowing_tokens=True,
    stride=doc_stride
)

Because the overlong input was split, we get several inputs. Their input_ids have these lengths:

[len(x) for x in tokenized_example["input_ids"]]

[384, 157]
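These two lengths can be checked by hand. The sketch rests on two assumptions read off the outputs in this section: the question takes 14 tokens (the offset mapping printed later shows 14 question tokens between [CLS] and the first [SEP]), and DistilBERT adds 3 special tokens ([CLS] and two [SEP]). Each chunk then has room for 384 - 14 - 3 = 367 context tokens, and consecutive chunks share doc_stride = 128 of them. expected_feature_lengths is a hypothetical helper, not a library function.

```python
def expected_feature_lengths(total_len, question_len, max_length, stride, num_special=3):
    """Token lengths of the features produced by truncation="only_second" with overflow and stride."""
    context_len = total_len - question_len - num_special  # context tokens in the untruncated input
    budget = max_length - question_len - num_special      # context tokens that fit in one feature
    step = budget - stride                                # new context tokens per additional feature
    lengths = []
    start = 0
    while True:
        chunk = min(budget, context_len - start)
        lengths.append(chunk + question_len + num_special)
        if start + budget >= context_len:
            break
        start += step
    return lengths


print(expected_feature_lengths(396, 14, 384, 128))  # [384, 157]
```

The second feature starts 367 - 128 = 239 context tokens in, leaving 379 - 239 = 140 context tokens, plus 17 question and special tokens.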

We can decode the token IDs (input_ids) of each chunk back into text:

for i, x in enumerate(tokenized_example["input_ids"][:2]):
    print("Chunk: {}".format(i))
    print(tokenizer.decode(x))

Chunk: 0
[CLS] how many wins does the notre dame men's basketball team have? [SEP] the men's basketball team has over 1, 600 wins, one of only 12 schools who have reached that ....
championship. the 201011 team concluded its regular season ranked number seven in the country, with a record of 255, brey's fifth straight 20 - win season, and a second - place finish in the big east. during the 2014 - 15 season, the team went 32 - 6 and won the acc conference tournament, later advancing to the elite 8, where the fighting irish lost on a missed buzzer - beater against then undefeated kentucky. led by nba draft picks jerian grant and pat connaughton, the fighting irish beat the eventual national champion duke blue devils twice during the season. the 32 wins were [SEP]
Chunk: 1
[CLS] how many wins does the notre dame men's basketball team have? [SEP] championship. the 201011 team concluded its regular season ranked number seven in the country, with a record of 255, brey's fifth straight 20 - win season, and a second - place finish in the big east. during the 2014 - 15 season, the team went 32 - 6 and won the acc conference tournament, later advancing to the elite 8, where the fighting irish lost on a missed buzzer - beater against then undefeated kentucky. led by nba draft picks jerian grant and pat connaughton, the fighting irish beat the eventual national champion duke blue devils twice during the season. the 32 wins were the most by the fighting irish team since 1908 - 09. [SEP]

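Because the overlong text was split into chunks, we need to locate the answer again, relative to the beginning of each chunk's context. A question-answering model is trained with the answer's position (its start and end token positions) as labels, not with the answer's token IDs. So every chunk needs a correspondence back to the original input, linking each token's position in the chunk to its position in the original overlong context. The tokenizer returns this mapping when called with return_offsets_mapping=True: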

tokenized_example = tokenizer(
    example["question"],
    example["context"],
    max_length=max_length,
    truncation="only_second",
    return_overflowing_tokens=True,
    return_offsets_mapping=True,
    stride=doc_stride
)
# Print the token-to-character offset mapping of the first 100 tokens of chunk 0
print(tokenized_example["offset_mapping"][0][:100])
[(0, 0), (0, 3), (4, 8), (9, 13), (14, 18), (19, 22), (23, 28), (29, 33), (34, 37), (37, 38), (38, 39), (40, 50), (51, 55), (56, 60), (60, 61), (0, 0), (0, 3), (4, 7), (7, 8), (8, 9), (10, 20), (21, 25), (26, 29), (30, 34), (35, 36), (36, 37), (37, 40), (41, 45), (45, 46), (47, 50), (51, 53), (54, 58), (59, 61), (62, 69), (70, 73), (74, 78), (79, 86), (87, 91), (92, 96), (96, 97), (98, 101), (102, 106), (107, 115), (116, 118), (119, 121), (122, 126), (127, 138), (138, 139), (140, 146), (147, 153), (154, 160), (161, 165), (166, 171), (172, 175), (176, 182), (183, 186), (187, 191), (192, 198), (199, 205), (206, 208), (209, 210), (211, 217), (218, 222), (223, 225), (226, 229), (230, 240), (241, 245), (246, 248), (248, 249), (250, 258), (259, 262), (263, 267), (268, 271), (272, 277), (278, 281), (282, 285), (286, 290), (291, 301), (301, 302), (303, 307), (308, 312), (313, 318), (319, 321), (322, 325), (326, 330), (330, 331), (332, 340), (341, 351), (352, 354), (355, 363), (364, 373), (374, 379), (379, 380), (381, 384), (385, 389), (390, 393), (394, 406), (407, 408), (409, 415), (416, 418)]
[0, 0]

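The output above gives, for the first 100 tokens of chunk 0, each token's start and end character positions in the original text (in the question for question tokens, in the context for context tokens). The first token, [CLS], is mapped to (0, 0) because it is not part of the question or the context. The second token spans characters 0 to 3. We can therefore convert a token ID in a chunk back to its token string, and use offset_mapping to find the matching characters in the original text. Since the question comes first, its tokens can be looked up directly in the question string: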

first_token_id = tokenized_example["input_ids"][0][1]
offsets = tokenized_example["offset_mapping"][0][1]
print(tokenizer.convert_ids_to_tokens([first_token_id])[0], example["question"][offsets[0]:offsets[1]])
how How
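Slicing the original string with each (start, end) pair recovers the exact characters behind every token. The sketch below does this for all question tokens. It uses the 14 non-special offsets printed above and a question string reconstructed from the decoded output (only "How" is confirmed verbatim; the capitalization of the rest is an assumption that does not affect the offsets).

```python
question = "How many wins does the Notre Dame men's basketball team have?"
# Offsets of the question tokens, copied from the printed offset_mapping
# (positions 1 to 14; the (0, 0) entries belong to special tokens).
question_offsets = [(0, 3), (4, 8), (9, 13), (14, 18), (19, 22), (23, 28), (29, 33),
                    (34, 37), (37, 38), (38, 39), (40, 50), (51, 55), (56, 60), (60, 61)]

pieces = [question[start:end] for start, end in question_offsets]
print(pieces)
# ['How', 'many', 'wins', 'does', 'the', 'Notre', 'Dame', 'men', "'", 's', 'basketball', 'team', 'have', '?']
```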

We now have the correspondence between positions before and after splitting. We also need sequence_ids to tell the question apart from the context:

sequence_ids = tokenized_example.sequence_ids()
print(sequence_ids)
[None, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, None, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, None]

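None marks special tokens, and 0 or 1 marks the first or the second text. Because we passed the question first and the context second, 0 corresponds to the question and 1 to the context. Finally, we can locate the annotated answer within the preprocessed feature: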
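The two while loops used below only find the first and last context positions, and list.index on sequence_ids can do the same. The sketch uses a made-up sequence_ids list, and the helper name context_token_range is hypothetical. For models that put the context first, context_id would be 0 instead of 1.

```python
def context_token_range(sequence_ids, context_id=1):
    """Return (first, last) token positions whose sequence id equals context_id."""
    first = sequence_ids.index(context_id)
    # Search the reversed list to find the last occurrence.
    last = len(sequence_ids) - 1 - sequence_ids[::-1].index(context_id)
    return first, last


# [CLS] q q [SEP] c c c [SEP] followed by two padding positions
seq_ids = [None, 0, 0, None, 1, 1, 1, None, None, None]
print(context_token_range(seq_ids))  # (4, 6)
```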

answers = example["answers"]
start_char = answers["answer_start"][0]
end_char = start_char + len(answers["text"][0])

# Token index where the context starts in this feature.
token_start_index = 0
while sequence_ids[token_start_index] != 1:
    token_start_index += 1

# Token index where the context ends in this feature.
token_end_index = len(tokenized_example["input_ids"][0]) - 1
while sequence_ids[token_end_index] != 1:
    token_end_index -= 1

# Check whether the answer lies inside this feature's context; if it does not, the feature is labeled with the CLS index.
offsets = tokenized_example["offset_mapping"][0]
if (offsets[token_start_index][0] <= start_char and offsets[token_end_index][1] >= end_char):
    # Move token_start_index and token_end_index to the two ends of the answer.
    # Note the edge case where the answer is at the very end of the context.
    while token_start_index < len(offsets) and offsets[token_start_index][0] <= start_char:
        token_start_index += 1
    start_position = token_start_index - 1
    while offsets[token_end_index][1] >= end_char:
        token_end_index -= 1
    end_position = token_end_index + 1
    print("start_position: {}, end_position: {}".format(start_position, end_position))
else:
    print("The answer is not in this feature.")
start_position: 23, end_position: 26

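We should verify the answer position: take the token IDs between the start and end positions, decode them back into text, and compare the result with the original answer: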

print(tokenizer.decode(tokenized_example["input_ids"][0][start_position: end_position+1]))
print(answers["text"][0])
over 1, 600
over 1,600
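The decoded span reads "over 1, 600" while the annotation is "over 1,600". Decoding rebuilds text from tokens and applies its own spacing, so this comparison is only approximate. Slicing the context with the offsets of the start and end tokens returns the exact original characters instead. The sketch below runs the whole round trip (characters to token labels and back). A toy regex tokenizer stands in for the real one, which is an assumption: it only splits words and punctuation, with no subwords or special tokens.

```python
import re


def toy_tokenize(text):
    """Split text into words and single punctuation marks, returning tokens and character offsets."""
    matches = list(re.finditer(r"\w+|[^\w\s]", text))
    return [m.group() for m in matches], [(m.start(), m.end()) for m in matches]


def char_span_to_token_span(offsets, start_char, end_char):
    """Return (first, last) token indices covering [start_char, end_char), or None if not covered."""
    if not offsets or offsets[0][0] > start_char or offsets[-1][1] < end_char:
        return None  # the answer is not inside this chunk
    start_token = 0
    while start_token < len(offsets) and offsets[start_token][0] <= start_char:
        start_token += 1
    end_token = len(offsets) - 1
    while end_token >= 0 and offsets[end_token][1] >= end_char:
        end_token -= 1
    return start_token - 1, end_token + 1


context = "The men's basketball team has over 1,600 wins."
answer = "over 1,600"
start_char = context.index(answer)
end_char = start_char + len(answer)

tokens, offsets = toy_tokenize(context)
start_tok, end_tok = char_span_to_token_span(offsets, start_char, end_char)
print(" ".join(tokens[start_tok:end_tok + 1]))             # over 1 , 600
print(context[offsets[start_tok][0]:offsets[end_tok][1]])  # over 1,600
```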

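Some models expect the question followed by the context, others the context followed by the question. We determine the order from the tokenizer's padding_side attribute: if the tokenizer pads on the right, the question comes first and the context second; otherwise the order is swapped.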

pad_on_right = tokenizer.padding_side == "right"  # True means the context comes second (on the right)
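A small sketch of how pad_on_right decides the argument order and the truncation side inside prepare_train_features. order_qa_inputs is a hypothetical helper that just mirrors those conditional expressions, so that only the context is ever truncated.

```python
def order_qa_inputs(question, context, padding_side="right"):
    """Return (first_text, second_text, truncation) so that only the context gets truncated."""
    pad_on_right = padding_side == "right"
    if pad_on_right:
        return question, context, "only_second"
    return context, question, "only_first"


print(order_qa_inputs("Who won?", "A long context.", padding_side="right"))
print(order_qa_inputs("Who won?", "A long context.", padding_side="left"))
```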

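Now let's put all the steps together. When the context does not contain the answer, we place both the start and end labels at the index of the CLS token. If a flag such as allow_impossible_answers were set to False, these no-answer samples would be dropped instead; note that the function below defines no such flag and always keeps them with CLS labels.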

def prepare_train_features(examples):
    # We need to truncate and pad the examples while keeping all the information, so long examples are split into chunks.
    # Each overlong example is split into several inputs (features), and adjacent features overlap.
    tokenized_examples = tokenizer(
        examples["question" if pad_on_right else "context"],
        examples["context" if pad_on_right else "question"],
        truncation="only_second" if pad_on_right else "only_first",
        max_length=max_length,
        stride=doc_stride,
        return_overflowing_tokens=True,
        return_offsets_mapping=True,
        padding="max_length",
    )

    # overflow_to_sample_mapping maps each feature (chunk) back to the index of the example it came from.
    # For example, if 2 examples are split into 4 features, the mapping is [0, 0, 1, 1]: the first two features come from the first example.
    sample_mapping = tokenized_examples.pop("overflow_to_sample_mapping")
    # offset_mapping also has one entry per feature (4 in the example above).
    # It maps tokens back to character positions in the original input; since the answers are annotated on the original input, it lets us find the answer's start and end positions.
    offset_mapping = tokenized_examples.pop("offset_mapping")

    # Relabel the data
    tokenized_examples["start_positions"] = []
    tokenized_examples["end_positions"] = []

    for i, offsets in enumerate(offset_mapping):
        # Process each feature.
        # Features without an answer are labeled with the CLS index.
        input_ids = tokenized_examples["input_ids"][i]
        cls_index = input_ids.index(tokenizer.cls_token_id)

        # Tell the question apart from the context.
        sequence_ids = tokenized_examples.sequence_ids(i)

        # Index of the original example this feature comes from.
        sample_index = sample_mapping[i]
        answers = examples["answers"][sample_index]
        # If there is no answer, use the CLS position as the answer.
        if len(answers["answer_start"]) == 0:
            tokenized_examples["start_positions"].append(cls_index)
            tokenized_examples["end_positions"].append(cls_index)
        else:
            # Character-level start/end of the answer.
            start_char = answers["answer_start"][0]
            end_char = start_char + len(answers["text"][0])

            # Token-level start index of the context.
            token_start_index = 0
            while sequence_ids[token_start_index] != (1 if pad_on_right else 0):
                token_start_index += 1

            # Token-level end index of the context.
            token_end_index = len(input_ids) - 1
            while sequence_ids[token_end_index] != (1 if pad_on_right else 0):
                token_end_index -= 1

            # If the answer is not fully inside this feature's context, also label it with the CLS index.
            if not (offsets[token_start_index][0] <= start_char and offsets[token_end_index][1] >= end_char):
                tokenized_examples["start_positions"].append(cls_index)
                tokenized_examples["end_positions"].append(cls_index)
            else:
                # Otherwise move token_start_index and token_end_index to the two ends of the answer.
                # Note: we could go after the last offset if the answer is the last word (edge case).
                while token_start_index < len(offsets) and offsets[token_start_index][0] <= start_char:
                    token_start_index += 1
                tokenized_examples["start_positions"].append(token_start_index - 1)
                while offsets[token_end_index][1] >= end_char:
                    token_end_index -= 1
                tokenized_examples["end_positions"].append(token_end_index + 1)

    return tokenized_examples


# Apply the function to every split; the original columns are removed because one example can yield several features.
tokenized_datasets = datasets.map(prepare_train_features, batched=True, remove_columns=datasets["train"].column_names)
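If you would rather drop the features labeled at [CLS] than train on them (the allow_impossible_answers = False behaviour mentioned above), one option is to filter the tokenized batch before returning it. This is a minimal sketch: filter_impossible is a hypothetical helper that works on the dict of lists built by prepare_train_features, and the token IDs in the toy batch are made up.

```python
def filter_impossible(batch, cls_token_id):
    """Keep only features whose start/end labels do not both point at the [CLS] token."""
    keep = []
    for i, input_ids in enumerate(batch["input_ids"]):
        cls_index = input_ids.index(cls_token_id)
        is_impossible = (batch["start_positions"][i] == cls_index
                         and batch["end_positions"][i] == cls_index)
        keep.append(not is_impossible)
    # Apply the same mask to every column so that the rows stay aligned.
    return {key: [value for value, k in zip(values, keep) if k] for key, values in batch.items()}


toy_batch = {
    "input_ids": [[101, 7, 102, 8, 9, 102], [101, 7, 102, 5, 6, 102]],
    "start_positions": [0, 3],
    "end_positions": [0, 4],
}
print(filter_impossible(toy_batch, cls_token_id=101))
```

Inside prepare_train_features it could be applied as `return filter_impossible(tokenized_examples, tokenizer.cls_token_id)`. Since the map call removes the original columns, the number of returned rows is free to shrink.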

Fine-tuning the model

from transformers import AutoModelForQuestionAnswering, TrainingArguments, Trainer

model = AutoModelForQuestionAnswering.from_pretrained(model_checkpoint)

Training arguments

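args = TrainingArguments(
    "test-squad",                 # output directory
    evaluation_strategy="epoch",  # evaluate after every epoch; newer transformers versions call this eval_strategy
    learning_rate=2e-5,           # learning rate
    per_device_train_batch_size=batch_size,
    per_device_eval_batch_size=batch_size,
    num_train_epochs=3,           # number of training epochs
    weight_decay=0.01,
)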
from transformers import default_data_collator

data_collator = default_data_collator
trainer = Trainer(
    model,
    args,
    train_dataset=tokenized_datasets["train"],
    eval_dataset=tokenized_datasets["validation"],
    data_collator=data_collator,
    tokenizer=tokenizer,
)

# Start training
trainer.train()

# Save the model
trainer.save_model("test-squad-trained")

Evaluation

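We need to post-process the model's outputs into answer text. The model itself predicts logits for the answer's start and end positions. If we feed it one batch of evaluation data, the output looks like this: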

import torch

# Take the first batch of the evaluation dataloader
for batch in trainer.get_eval_dataloader():
    break
batch = {k: v.to(trainer.args.device) for k, v in batch.items()}
with torch.no_grad():
    output = trainer.model(**batch)
output.keys()

The model output is a dict-like object containing the loss (present because labels were provided) and the start and end logits:

output.start_logits.shape, output.end_logits.shape

(torch.Size([16, 384]), torch.Size([16, 384]))

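Every token in every feature gets a logit. The simplest way to predict the answer is to take the index of the largest start logit as the answer's start position and the index of the largest end logit as its end position.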

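This strategy works well most of the time. But the result can be invalid: for example, the start index may be greater than the end index, or the start and end may point into the question.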

In that case, a simple fix is to fall back to the second-best prediction, then to the third-best if that also fails, and so on.

Because this fallback does not reliably find a valid answer, we use a more principled approach: score each (start, end) pair by the sum of its start and end logits, consider the n_best_size best start and end positions, derive a candidate answer from each pair, check that the candidate is valid, and finally sort the valid candidates by score and pick the highest-scoring one.
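A pure-Python sketch of both strategies on made-up logits. The independent argmax picks a start after the end. Scoring (start, end) pairs from the top n_best_size candidates by start_logit + end_logit finds the best valid span. best_span is a hypothetical helper, and its validity check only enforces start <= end and a maximum length; the post-processing code further below also rejects tokens outside the context.

```python
def top_indices(values, k):
    """Indices of the k largest values, largest first."""
    return sorted(range(len(values)), key=values.__getitem__, reverse=True)[:k]


def best_span(start_logits, end_logits, n_best_size=20, max_answer_length=30):
    """Return (score, start, end) for the best valid pair among the top-n starts and ends, or None."""
    best = None
    for s in top_indices(start_logits, n_best_size):
        for e in top_indices(end_logits, n_best_size):
            if e < s or e - s + 1 > max_answer_length:
                continue  # invalid: end before start, or span too long
            score = start_logits[s] + end_logits[e]
            if best is None or score > best[0]:
                best = (score, s, e)
    return best


start_logits = [0.1, 0.2, 3.0, 0.5]
end_logits = [0.1, 2.5, 0.3, 0.4]

# Naive strategy: independent argmax for start and end.
naive = (top_indices(start_logits, 1)[0], top_indices(end_logits, 1)[0])
print(naive)                                    # (2, 1): start after end, invalid
print(best_span(start_logits, end_logits)[1:])  # (2, 3)
```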

n_best_size = 20

import numpy as np

start_logits = output.start_logits[0].cpu().numpy()
end_logits = output.end_logits[0].cpu().numpy()
# Collect the indices of the best start and end logits:
start_indexes = np.argsort(start_logits)[-1 : -n_best_size - 1 : -1].tolist()
end_indexes = np.argsort(end_logits)[-1 : -n_best_size - 1 : -1].tolist()
valid_answers = []
for start_index in start_indexes:
    for end_index in end_indexes:
        if start_index <= end_index:  # the pair is only valid if start <= end
            valid_answers.append(
                {
                    "score": start_logits[start_index] + end_logits[end_index],
                    "text": ""  # the answer text is recovered later from the token indices
                }
            )

We then sort valid_answers by score and keep the best one. One step remains: checking that the text between the start and end positions lies in the context rather than in the question.

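To do this, we add two pieces of information to the validation features:

  • The ID of the example that generated the feature. Since one example may produce several features (chunks), each feature needs to know which example it belongs to.
  • The offset mapping, which maps the tokens of each chunk back to character positions in the original text.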
def prepare_validation_features(examples):
    # Tokenize our examples with truncation and maybe padding, but keep the overflows using a stride. This results
    # in one example possible giving several features when a context is long, each of those features having a
    # context that overlaps a bit the context of the previous feature.
    tokenized_examples = tokenizer(
        examples["question" if pad_on_right else "context"],
        examples["context" if pad_on_right else "question"],
        truncation="only_second" if pad_on_right else "only_first",
        max_length=max_length,
        stride=doc_stride,
        return_overflowing_tokens=True,
        return_offsets_mapping=True,
        padding="max_length",
    )

    # Since one example might give us several features if it has a long context, we need a map from a feature to
    # its corresponding example. This key gives us just that.
    sample_mapping = tokenized_examples.pop("overflow_to_sample_mapping")

    # We keep the example_id that gave us this feature and we will store the offset mappings.
    tokenized_examples["example_id"] = []

    for i in range(len(tokenized_examples["input_ids"])):
        # Grab the sequence corresponding to that example (to know what is the context and what is the question).
        sequence_ids = tokenized_examples.sequence_ids(i)
        context_index = 1 if pad_on_right else 0

        # One example can give several spans, this is the index of the example containing this span of text.
        sample_index = sample_mapping[i]
        tokenized_examples["example_id"].append(examples["id"][sample_index])

        # Set to None the offset_mapping that are not part of the context so it's easy to determine if a token
        # position is part of the context or not.
        tokenized_examples["offset_mapping"][i] = [
            (o if sequence_ids[k] == context_index else None)
            for k, o in enumerate(tokenized_examples["offset_mapping"][i])
        ]

    return tokenized_examples
validation_features = datasets["validation"].map(
    prepare_validation_features,
    batched=True,
    remove_columns=datasets["validation"].column_names
)


Use Trainer.predict to get predictions for all features:

raw_predictions = trainer.predict(validation_features)

The Trainer hides the columns the model does not use during training (here example_id and offset_mapping, which we need for post-processing), so we have to set them back:

validation_features.set_format(type=validation_features.format["type"], columns=list(validation_features.features.keys()))

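For token positions belonging to the question, prepare_validation_features set the offset mapping to None, so we can easily tell whether a token lies in the context. We also discard answers that are too long (more than max_answer_length tokens).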

max_answer_length = 30

start_logits = output.start_logits[0].cpu().numpy()
end_logits = output.end_logits[0].cpu().numpy()
offset_mapping = validation_features[0]["offset_mapping"]
# The first feature comes from the first example. For the more general case, we will need to match the example_id to
# an example index
context = datasets["validation"][0]["context"]

# Gather the indices of the best start/end logits:
start_indexes = np.argsort(start_logits)[-1 : -n_best_size - 1 : -1].tolist()
end_indexes = np.argsort(end_logits)[-1 : -n_best_size - 1 : -1].tolist()
valid_answers = []
for start_index in start_indexes:
    for end_index in end_indexes:
        # Don't consider out-of-scope answers, either because the indices are out of bounds or correspond
        # to part of the input_ids that are not in the context.
        if (
            start_index >= len(offset_mapping)
            or end_index >= len(offset_mapping)
            or offset_mapping[start_index] is None
            or offset_mapping[end_index] is None
        ):
            continue
        # Don't consider answers with a length that is either < 0 or > max_answer_length.
        if end_index < start_index or end_index - start_index + 1 > max_answer_length:
            continue
        if start_index <= end_index:  # We need to refine that test to check the answer is inside the context
            start_char = offset_mapping[start_index][0]
            end_char = offset_mapping[end_index][1]
            valid_answers.append(
                {
                    "score": start_logits[start_index] + end_logits[end_index],
                    "text": context[start_char: end_char]
                }
            )

valid_answers = sorted(valid_answers, key=lambda x: x["score"], reverse=True)[:n_best_size]
valid_answers

Compare the predictions with the ground-truth answers:

datasets["validation"][0]["answers"]

{'answer_start': [177, 177, 177],
'text': ['Denver Broncos', 'Denver Broncos', 'Denver Broncos']}

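The first feature necessarily comes from the first example, so that case is easy. For the other features we need a mapping between features and examples. And since one example may be split into several features, we need to gather the candidate answers from all of its features. The following code maps each example index to the indices of its features.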

import collections

examples = datasets["validation"]
features = validation_features

example_id_to_index = {k: i for i, k in enumerate(examples["id"])}
features_per_example = collections.defaultdict(list)
for i, feature in enumerate(features):
    features_per_example[example_id_to_index[feature["example_id"]]].append(i)
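A toy run of the same mapping, with made-up IDs, shows the shape of features_per_example: each example index maps to the list of feature indices cut from it. The plain lists below stand in for the datasets columns.

```python
import collections

example_ids = ["a", "b", "c"]                    # stand-in for examples["id"]
feature_example_ids = ["a", "b", "b", "b", "c"]  # stand-in for the features' example_id column

example_id_to_index = {k: i for i, k in enumerate(example_ids)}
features_per_example = collections.defaultdict(list)
for i, example_id in enumerate(feature_example_ids):
    features_per_example[example_id_to_index[example_id]].append(i)

print(dict(features_per_example))  # {0: [0], 1: [1, 2, 3], 2: [4]}
```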

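The last thing to handle is unanswerable questions (when squad_v2 = True). The code above only considers answers inside the context, so we also need to collect the score of the no-answer prediction (whose start and end both point at the CLS token). When an example yields several features, we should predict "no answer" only if all of its features favor it, so the final no-answer score of an example is the minimum of its features' no-answer scores.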

If this final no-answer score is higher than the scores of all other candidate answers, the question is predicted as unanswerable.
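Here is a sketch of that decision rule with made-up numbers. It assumes the per-feature null scores (start logit plus end logit at [CLS]) and the best non-null span score are already computed, and choose_answer is a hypothetical helper. A single confident "no answer" feature is not enough to suppress the answer, because only the smallest null score counts.

```python
def choose_answer(feature_null_scores, best_text, best_score):
    """Return "" (no answer) unless the best span beats the smallest null score over all features."""
    min_null_score = min(feature_null_scores)
    return best_text if best_score > min_null_score else ""


# Two features: the first strongly predicts "no answer", the second does not.
print(repr(choose_answer([9.0, 1.5], "Denver Broncos", 4.2)))  # 'Denver Broncos'
# Both features predict "no answer" more strongly than the best span.
print(repr(choose_answer([9.0, 6.0], "Denver Broncos", 4.2)))  # ''
```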

import collections

import numpy as np
from tqdm.auto import tqdm

def postprocess_qa_predictions(examples, features, raw_predictions, n_best_size = 20, max_answer_length = 30):
    all_start_logits, all_end_logits = raw_predictions
    # Build a map example to its corresponding features.
    example_id_to_index = {k: i for i, k in enumerate(examples["id"])}
    features_per_example = collections.defaultdict(list)
    for i, feature in enumerate(features):
        features_per_example[example_id_to_index[feature["example_id"]]].append(i)

    # The dictionaries we have to fill.
    predictions = collections.OrderedDict()

    # Logging.
    print(f"Post-processing {len(examples)} example predictions split into {len(features)} features.")

    # Let's loop over all the examples!
    for example_index, example in enumerate(tqdm(examples)):
        # Those are the indices of the features associated to the current example.
        feature_indices = features_per_example[example_index]

        min_null_score = None # Only used if squad_v2 is True.
        valid_answers = []

        context = example["context"]
        # Looping through all the features associated to the current example.
        for feature_index in feature_indices:
            # We grab the predictions of the model for this feature.
            start_logits = all_start_logits[feature_index]
            end_logits = all_end_logits[feature_index]
            # This is what will allow us to map the positions in our logits to spans of text in the original
            # context.
            offset_mapping = features[feature_index]["offset_mapping"]

            # Update the minimum null prediction: keep the smallest null score over all features of this example.
            cls_index = features[feature_index]["input_ids"].index(tokenizer.cls_token_id)
            feature_null_score = start_logits[cls_index] + end_logits[cls_index]
            if min_null_score is None or min_null_score > feature_null_score:
                min_null_score = feature_null_score

            # Go through all possibilities for the `n_best_size` greater start and end logits.
            start_indexes = np.argsort(start_logits)[-1 : -n_best_size - 1 : -1].tolist()
            end_indexes = np.argsort(end_logits)[-1 : -n_best_size - 1 : -1].tolist()
            for start_index in start_indexes:
                for end_index in end_indexes:
                    # Don't consider out-of-scope answers, either because the indices are out of bounds or correspond
                    # to part of the input_ids that are not in the context.
                    if (
                        start_index >= len(offset_mapping)
                        or end_index >= len(offset_mapping)
                        or offset_mapping[start_index] is None
                        or offset_mapping[end_index] is None
                    ):
                        continue
                    # Don't consider answers with a length that is either < 0 or > max_answer_length.
                    if end_index < start_index or end_index - start_index + 1 > max_answer_length:
                        continue

                    start_char = offset_mapping[start_index][0]
                    end_char = offset_mapping[end_index][1]
                    valid_answers.append(
                        {
                            "score": start_logits[start_index] + end_logits[end_index],
                            "text": context[start_char: end_char]
                        }
                    )

        if len(valid_answers) > 0:
            best_answer = sorted(valid_answers, key=lambda x: x["score"], reverse=True)[0]
        else:
            # In the very rare edge case we have not a single non-null prediction, we create a fake prediction to avoid
            # failure.
            best_answer = {"text": "", "score": 0.0}

        # Let's pick our final answer: the best one or the null answer (only for squad_v2)
        if not squad_v2:
            predictions[example["id"]] = best_answer["text"]
        else:
            answer = best_answer["text"] if best_answer["score"] > min_null_score else ""
            predictions[example["id"]] = answer

    return predictions

Apply the post-processing function to the raw predictions:

final_predictions = postprocess_qa_predictions(datasets["validation"], validation_features, raw_predictions.predictions)

Load the evaluation metric:

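# load_metric is deprecated in recent datasets releases; the evaluate library offers evaluate.load("squad_v2" if squad_v2 else "squad") instead
metric = load_metric("squad_v2" if squad_v2 else "squad")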
if squad_v2:
    formatted_predictions = [{"id": k, "prediction_text": v, "no_answer_probability": 0.0} for k, v in final_predictions.items()]
else:
    formatted_predictions = [{"id": k, "prediction_text": v} for k, v in final_predictions.items()]
references = [{"id": ex["id"], "answers": ex["answers"]} for ex in datasets["validation"]]
metric.compute(predictions=formatted_predictions, references=references)
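For intuition about the two numbers the squad metric reports (exact_match and f1, averaged over questions and scaled to 0 to 100), here is a simplified sketch of the per-answer computation in the spirit of the official SQuAD v1.1 evaluation script. Answers are normalized (lowercased, punctuation and the articles a/an/the removed, whitespace collapsed), then compared either exactly or by token overlap, taking the maximum over all reference answers. This is not the library's code, and the special handling of empty answers in SQuAD v2 is omitted.

```python
import re
import string
from collections import Counter

PUNCTUATION = set(string.punctuation)


def normalize_answer(text):
    """Lowercase, drop punctuation and English articles, and collapse whitespace."""
    text = text.lower()
    text = "".join(ch for ch in text if ch not in PUNCTUATION)
    text = re.sub(r"\b(a|an|the)\b", " ", text)
    return " ".join(text.split())


def exact_match(prediction, references):
    return max(float(normalize_answer(prediction) == normalize_answer(ref)) for ref in references)


def single_f1(prediction, reference):
    pred_tokens = normalize_answer(prediction).split()
    ref_tokens = normalize_answer(reference).split()
    common = Counter(pred_tokens) & Counter(ref_tokens)  # multiset intersection
    num_same = sum(common.values())
    if num_same == 0:
        return 0.0
    precision = num_same / len(pred_tokens)
    recall = num_same / len(ref_tokens)
    return 2 * precision * recall / (precision + recall)


def f1(prediction, references):
    return max(single_f1(prediction, ref) for ref in references)


refs = ["Denver Broncos", "Denver Broncos", "Denver Broncos"]
print(exact_match("the Denver Broncos", refs))  # 1.0
print(f1("Denver Broncos won", refs))           # about 0.8
```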

https://www.spacezxy.top/2021/09/28/nlp-transformer/nlp-transformer-5/
Author: Xavier ZXY
Published: September 28, 2021