使用 OCR 识别手写文本

本文实现了基于微调TrOCR模型进行手写文本识别。 1.GNHK手写笔记数据集GNHK(GoodNotes Handwriting Kollection)手写笔记数据集由GoodNotes提供,包含来自世界各地学生的数百份英文手写笔记。 下载数据集访问GNHK数据集官方网站:(),滚动到底部,同意使用条款和条件;点击第二个链接下载数据集。

本文实现了基于微调TrOCR模型进行手写文本识别。

使用 OCR 识别手写文本

1.GNHK手写笔记数据集

GNHK(GoodNotes Handwriting Kollection)手写笔记数据集由GoodNotes提供,包含来自世界各地学生的数百份英文手写笔记。

下载数据集

访问GNHK数据集官方网站:

(https://www.goodnotes.com/gnhk),滚动到底部,同意使用条款和条件;点击第二个链接下载数据集。

使用 OCR 识别手写文本

下载后会得到两个文件:train_data.zip 和 test_data.zip。解压这两个文件后,数据集的目录结构如下:

复制
├── test_data
│   └── test
│       ├── eng_AF_004.jpg
│       ├── eng_AF_004.json
│       ├── eng_AF_007.jpg
│       ├── eng_AF_007.json
│       ...
│       ├── eng_NA_142.jpg
│       └── eng_NA_142.json
├── train_data
    └── train
        ├── eng_AF_001.jpg
        ├── eng_AF_001.json
        ├── eng_AF_002.jpg
        ├── eng_AF_002.json
        ...
        ├── eng_NA_146.jpg
        └── eng_NA_146.json
4 directories, 1375 files
  • 训练集:包含515个样本
  • 测试集:包含172个样本
  • 图像文件:从1080p到4K的高分辨率图像
  • 标注文件:每个图像文件对应一个JSON文件,包含图像中每个单词的标注信息

以下是数据集中的一些手写笔记图像样本。

使用 OCR 识别手写文本

每个图像文件对应一个JSON文件,文件内容格式如下:

复制
[
    {
        "text": "%math%",
        "polygon": {
            "x0": 112, "y0": 556,
            "x1": 285, "y1": 563,
            "x2": 245, "y2": 776,
            "x3": 112, "y3": 783
        },
        "line_idx": 1,
        "type": "H"
    },
    {
        "text": "%math%",
        "polygon": {
            "x0": 2365, "y0": 202,
            "x1": 2350, "y1": 509,
            "x2": 2588, "y2": 527,
            "x3": 2632, "y3": 195
        },
        "line_idx": 0,
        "type": "H"
    },
    ...
    {
        "text": "ownership",
        "polygon": {
            "x0": 1347, "y0": 1606,
            "x1": 2238, "y1": 1574,
            "x2": 2170, "y2": 1884,
            "x3": 1300, "y3": 1747
        },
        "line_idx": 4,
        "type": "H"
    }
]

其中:

  • text:表示单词的内容。如果单词是数学符号、特殊字符或不可理解的内容(例如划线),则用%%符号包裹的特殊词表示。否则,text键包含实际的单词。
  • polygon:表示单词的多边形坐标,用于精确标注单词的位置。
  • line_idx:表示单词所在的行索引。
  • type:表示单词的类型,通常为"H"(手写)。

2.项目目录结构

复制
├── input
│   └── gnhk_dataset
│       ├── test_data
│       ├── test_processed
│       ├── train_data
│       ├── train_processed
│       ├── test_processed.csv
│       └── train_processed.csv
├── pretrained_model_inference  [10066 entries exceeds filelimit, not opening dir]
├── trocr_handwritten
│   ├── checkpoint-6093
│   │   ├── config.json
│   │   ├── generation_config.json
│   │   ├── model.safetensors
│   │   ├── optimizer.pt
│   │   ├── preprocessor_config.json
│   │   ├── rng_state.pth
│   │   ├── scheduler.pt
│   │   ├── trainer_state.json
│   │   └── training_args.bin
│   ├── checkpoint-6770
│   │   ├── config.json
│   │   ├── generation_config.json
│   │   ├── model.safetensors
│   │   ├── optimizer.pt
│   │   ├── preprocessor_config.json
│   │   ├── rng_state.pth
│   │   ├── scheduler.pt
│   │   ├── trainer_state.json
│   │   └── training_args.bin
│   └── runs
│       └── Aug27_11-30-05_f57a2dab37c7
├── Fine_Tune_TrOCR_Handwritten.ipynb
├── preprocess_gnhk_dataset.py
└── Pretrained_Model_Inference.ipynb

目录说明:

  • input/gnhk_dataset:包含下载并解压的数据集
  • pretrained_model_inference:包含使用预训练的TrOCR手写模型对验证数据集进行推理的结果。
  • trocr_handwritten:包含微调TrOCR模型后的结果。
  • Fine_Tune_TrOCR_Handwritten.ipynb:用于微调TrOCR模型的Jupyter Notebook
  • preprocess_gnhk_dataset.py:包含预处理GNHK数据集的Python脚本
  • Pretrained_Model_Inference.ipynb:用于使用预训练模型进行推理的Jupyter Notebook

3.安装依赖项

在继续进行数据预处理、推理和训练之前,我们需要安装以下依赖项。

复制
pip install transformers
pip install sentencepiece
pip install jiwer
pip install datasets
pip install evaluate
pip install -U accelerate

pip install matplotlib
pip install protobuf==3.20.1
pip install tensorboard

4.GNHK数据集预处理

预训练的TrOCR模型只能识别单个单词或单行句子,而GNHK数据集中的图像是整个文档的图像。因此需要对数据集进行预处理,以便模型能够更好地处理这些图像。

数据集预处理的关键步骤如下:

  • 转换多边形坐标为四点边界框坐标。
  • 裁剪每个单词并存储在单独的目录中。
  • 创建两个 CSV 文件,一个用于训练集,一个用于测试集。这些文件将包含裁剪后的图像名称和标签文本。

代码实现:

复制
import os
import json
import csv
import cv2
import numpy as np
from tqdm import tqdm
 
def create_directories():
   """
   创建必要的目录
   """
   dirs = [
       'input/gnhk_dataset/train_processed/images',
       'input/gnhk_dataset/test_processed/images',
   ]
   for dir_path in dirs:
       os.makedirs(dir_path, exist_ok=True)
 
def polygon_to_bbox(polygon):
    """
    将多边形坐标转换为四点边界框坐标
    """
   points = np.array([(polygon[f'x{i}'], polygon[f'y{i}']) for i in range(4)], dtype=np.int32)
   x, y, w, h = cv2.boundingRect(points)
   return x, y, w, h
 
def process_dataset(input_folder, output_folder, csv_path):
    """
    处理数据集,裁剪图像并生成 CSV 文件
    """
   with open(csv_path, 'w', newline='') as csvfile:
       csv_writer = csv.writer(csvfile)
       csv_writer.writerow(['image_filename', 'text'])
       
       for filename in tqdm(os.listdir(input_folder), desc=f"Processing {os.path.basename(input_folder)}"):
           if filename.endswith('.json'):
               json_path = os.path.join(input_folder, filename)
               img_path = os.path.join(input_folder, filename.replace('.json', '.jpg'))
               
               with open(json_path, 'r') as f:
                   data = json.load(f)
               
               img = cv2.imread(img_path)
               
               for idx, item in enumerate(data):
                   text = item['text']
                   if text.startswith('%') and text.endswith('%'):
                       text = 'SPECIAL_CHARACTER'
                   
                   x, y, w, h = polygon_to_bbox(item['polygon'])
                   
                   cropped_img = img[y:y+h, x:x+w]
                   
                   output_filename = f"{filename.replace('.json', '')}_{idx}.jpg"
                   output_path = os.path.join(output_folder, output_filename)
                   cv2.imwrite(output_path, cropped_img)
                   
                   csv_writer.writerow([output_filename, text])
                   
def main():
    """
    主函数,创建目录并处理数据集
    """
   create_directories()

   process_dataset(
       'input/gnhk_dataset/train_data/train',
       'input/gnhk_dataset/train_processed/images',
       'input/gnhk_dataset/train_processed.csv'
   )
   process_dataset(
       'input/gnhk_dataset/test_data/test',
       'input/gnhk_dataset/test_processed/images',
       'input/gnhk_dataset/test_processed.csv'
   )

if __name__ == '__main__':
   main()

将上述代码保存为preprocess_gnhk_dataset.py文件。在终端中运行脚本。

复制
python preprocess_gnhk_dataset.py

运行脚本后,将会在 input/gnhk_dataset 目录下创建以下子目录和文件:

子目录:

  • train_processed/images:存储训练集的裁剪图像。
  • test_processed/images:存储测试集的裁剪图像。

CSV 文件:

  • train_processed.csv:包含训练集的图像文件名和对应的标签文本。
  • test_processed.csv:包含测试集的图像文件名和对应的标签文本。

以下是一些经过处理后的裁剪图像示例:

使用 OCR 识别手写文本

csv文件示例如下图所示:

使用 OCR 识别手写文本

每个csv文件包括裁剪后的图像文件名和对应图像的标签文本。每一行表示一个裁剪后的图像及其对应的标签文本。

处理后的数据集包括:

  • 训练集:32495张裁剪图像
  • 测试集:10066张裁剪图像

5.微调TrOCR模型

首先,导入必要的库,并定义一些全局设置。

复制
import os
import torch
import evaluate
import numpy as np
import pandas as pd
import glob as glob
import torch.optim as optim
import matplotlib.pyplot as plt
import torchvision.transforms as transforms
 
from PIL import Image
from tqdm.notebook import tqdm
from dataclasses import dataclass
from torch.utils.data import Dataset
from transformers import (
   VisionEncoderDecoderModel,
   TrOCRProcessor,
   Seq2SeqTrainer,
   Seq2SeqTrainingArguments,
   default_data_collator
)
 
block_plot = False
plt.rcParams['figure.figsize'] = (12, 9)
 
os.environ["TOKENIZERS_PARALLELISM"] = 'false'

接着,为确保实验的可重复性,设置随机种子,并初始化计算设备。

复制
def seed_everything(seed_value):
   np.random.seed(seed_value)
   torch.manual_seed(seed_value)
   torch.cuda.manual_seed_all(seed_value)
   torch.backends.cudnn.deterministic = True
   torch.backends.cudnn.benchmark = False

seed_everything(42)

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

定义一些重要的配置项,包括训练和数据集的路径。这里设置批次大小batch size为48,训练轮数10,基础学习率0.00005。

复制
@dataclass(frozen=True)
class TrainingConfig:
   BATCH_SIZE:    int = 48
   EPOCHS:        int = 10
   LEARNING_RATE: float = 0.00005
 
@dataclass(frozen=True)
class DatasetConfig:
   DATA_ROOT:     str = 'input/gnhk_dataset'
 
@dataclass(frozen=True)
class ModelConfig:
   MODEL_NAME: str = 'microsoft/trocr-small-handwritten'

可视化训练样本,以帮助我们验证路径、CSV文件准备和标签是否正确。

复制
def visualize(dataset_path, df):
   all_images = df.image_filename
   all_labels = df.text
   
   plt.figure(figsize=(15, 3))
   for i in range(15):
       plt.subplot(3, 5, i+1)
       image = plt.imread(f"{dataset_path}/test_processed/images/{all_images[i]}")
       label = all_labels[i]
       plt.imshow(image)
       plt.axis('off')
       plt.title(label)
   plt.show()
sample_df = pd.read_csv(
   os.path.join(DatasetConfig.DATA_ROOT, 'test_processed.csv'),
   header=None,
   skiprows=1,
   names=['image_filename', 'text'],
   nrows=50
)
 
visualize(DatasetConfig.DATA_ROOT, sample_df)

使用 OCR 识别手写文本

GNHK手写文本识别数据集具有自定义的目录结构和CSV文件,我们需要编写自定义的数据集准备代码。

  • 读取csv文件
复制
train_df = pd.read_csv(
   os.path.join(DatasetConfig.DATA_ROOT, 'train_processed.csv'),
   header=None,
   skiprows=1,
   names=['image_filename', 'text']
)
 
 
test_df = pd.read_csv(
   os.path.join(DatasetConfig.DATA_ROOT, 'test_processed.csv'),
   header=None,
   skiprows=1,
   names=['image_filename', 'text']
)
  • 为了减少过拟合,应用一些轻微的数据增强,主要包括颜色抖动和高斯模糊。
复制
# 定义数据增强
train_transforms = transforms.Compose([
    transforms.ColorJitter(brightness=0.5, hue=0.3),
    transforms.GaussianBlur(kernel_size=(3, 3), sigma=(0.1, 5)),
])
  • 需要创建一个自定义的PyTorch数据集类。
复制
class CustomOCRDataset(Dataset):
    def __init__(self, root_dir, df, processor, max_target_length=128):
        self.root_dir = root_dir
        self.df = df
        self.processor = processor
        self.max_target_length = max_target_length

        # 填充空值
        self.df['text'] = self.df['text'].fillna('')

    def __len__(self):
        return len(self.df)

    def __getitem__(self, idx):
        # 图像文件名
        file_name = self.df['image_filename'][idx]
        # 文本(标签)
        text = self.df['text'][idx]

        # 读取图像,应用数据增强,并获取转换后的像素值
        image = Image.open(os.path.join(self.root_dir, file_name)).convert('RGB')
        image = train_transforms(image)
        pixel_values = self.processor(image, return_tensors='pt').pixel_values

        # 通过分词器对文本进行分词,并获取标签
        labels = self.processor.tokenizer(
            text,
            padding='max_length',
            max_length=self.max_target_length
        ).input_ids

        # 使用 -100 作为填充标记
        labels = [label if label != self.processor.tokenizer.pad_token_id else -100 for label in labels]

        encoding = {
            "pixel_values": pixel_values.squeeze(),
            "labels": torch.tensor(labels)
        }
        return encoding
  • 初始化TrOCR处理器,并准备训练和验证数据集。
复制
# 初始化处理器
processor = TrOCRProcessor.from_pretrained(ModelConfig['MODEL_NAME'])

# 准备训练数据集
train_dataset = CustomOCRDataset(
    root_dir=os.path.join(DatasetConfig['DATA_ROOT'], 'train_processed/images/'),
    df=train_df,
    processor=processor
)

# 准备验证数据集
valid_dataset = CustomOCRDataset(
    root_dir=os.path.join(DatasetConfig['DATA_ROOT'], 'test_processed/images/'),
    df=test_df,
    processor=processor
)

初始化和配置模型,并统计模型的参数数量。

  • 加载模型
复制
# 初始化模型
model = VisionEncoderDecoderModel.from_pretrained(ModelConfig['MODEL_NAME'])
model.to(device)
print(model)

# 统计总参数和可训练参数
total_params = sum(p.numel() for p in model.parameters())
print(f"{total_params:,} total parameters.")

total_trainable_params = sum(
    p.numel() for p in model.parameters() if p.requires_grad
)
print(f"{total_trainable_params:,} training parameters.")
  • 手动设置一些配置。
复制
# 设置特殊 token 用于从标签创建 decoder_input_ids
model.config.decoder_start_token_id = processor.tokenizer.cls_token_id
model.config.pad_token_id = processor.tokenizer.pad_token_id
# 设置正确的词汇表大小
model.config.vocab_size = model.config.decoder.vocab_size
model.config.eos_token_id = processor.tokenizer.sep_token_id

# 设置最大输出长度
model.config.max_length = 64
# 启用提前停止
model.config.early_stopping = True
# 设置不重复 n-gram 的大小
model.config.no_repeat_ngram_size = 3
# 设置长度惩罚
model.config.length_penalty = 2.0
# 设置 beam search 的束宽
model.config.num_beams = 4

# 打印模型配置
print(model.config)
  • 定义AdamW优化器,并配置学习率和权重衰减。
复制
# 定义 AdamW 优化器
optimizer = optim.AdamW(
    model.parameters(), lr=TrainingConfig['LEARNING_RATE'], weight_decay=0.0005
)
  • 使用字符错误率CER对模型进行评估。
复制
cer_metric = evaluate.load('cer')


def compute_cer(pred):
   # 提取标签的 ID
   labels_ids = pred.label_ids
   # 提取预测的 ID
   pred_ids = pred.predictions

   # 将预测的 ID 解码为字符串,跳过特殊 token
   pred_str = processor.batch_decode(pred_ids, skip_special_tokens=True)
   # 将标签中的 -100 转换为 pad_token_id,以避免影响评估结果
   labels_ids[labels_ids == -100] = processor.tokenizer.
   # 将标签的 ID 解码为字符串,跳过特殊 token
   label_str = processor.batch_decode(labels_ids, skip_special_tokens=True)

   # 使用 cer_metric 计算 CER
   cer = cer_metric.compute(predictions=pred_str, references=label_str)

   return {"cer": cer}

训练和验证模型。在开始训练之前,需要初始化训练参数和 Trainer API。

  • 定义 Seq2SeqTrainingArguments 对象,设置训练和验证的相关参数。
复制
# 初始化训练参数
training_args = Seq2SeqTrainingArguments(
    predict_with_generate=True,
    evaluation_strategy='epoch',
    per_device_train_batch_size=TrainingConfig['BATCH_SIZE'],
    per_device_eval_batch_size=TrainingConfig['BATCH_SIZE'],
    fp16=True,
    output_dir='trocr_handwritten/',
    logging_strategy='epoch',
    save_strategy='epoch',
    save_total_limit=2,
    report_to='tensorboard',
    num_train_epochs=TrainingConfig['EPOCHS'],
    dataloader_num_workers=8
)
  • 使用 Seq2SeqTrainer API 初始化训练器。Seq2SeqTrainer 接受模型、处理器、训练参数、数据集和数据收集器作为参数。
复制
# 初始化训练器
trainer = Seq2SeqTrainer(
    model=model,
    tokenizer=processor.feature_extractor,
    args=training_args,
    compute_metrics=compute_cer,
    train_dataset=train_dataset,
    eval_dataset=valid_dataset,
    data_collator=default_data_collator
)
  • 开始微调模型。
复制
# 开始训练
trainer.train()

以下是训练10个epoch后的日志示例:

使用 OCR 识别手写文本

在训练完成后,我们得到了最佳的验证 CER 值。接下来,我们将使用最后一个epoch的检查点对验证集进行推理。

使用 OCR 识别手写文本

如图所示,验证CER图表在整个训练过程中持续下降,直到最后一个 epoch。这表明模型仍在学习,并且可能通过适当的学习率调度进一步训练几个 epoch 以获得更好的性能。

6.使用训练好的TrOCR模型推理

接下来,将使用训练好的trOCR模型对一组图像进行推理。

  • 加载处理器和训练好的模型检查点。
复制
# 定义模型和处理器
processor = TrOCRProcessor.from_pretrained(ModelConfig.MODEL_NAME)
trained_model = VisionEncoderDecoderModel.from_pretrained('trocr_handwritten/checkpoint-'+str(res.global_step)).to(device)
  • 定义一些辅助函数,用于读取图像、通过模型进行前向传播以及绘制结果。
复制
def read_and_show(image_path):
    """
    :param image_path: String, path to the input image.
    
    Returns:
        image: PIL Image.
    """
    image = Image.open(image_path).convert('RGB')
    return image
def ocr(image, processor, model):
    """
    :param image: PIL Image.
    :param processor: Huggingface OCR processor.
    :param model: Huggingface OCR model.
    
    Returns:
        generated_text: the OCR'd text string.
    """
    pixel_values = processor(image, return_tensors='pt').pixel_values.to(device)
    generated_ids = model.generate(pixel_values)
    generated_text = processor.batch_decode(generated_ids, skip_special_tokens=True)[0]
    return generated_text
def eval_new_data(data_path=None, num_samples=50, df=None):
    all_images = df.image_filename
    all_labels = df.text
    
    plt.figure(figsize=(15, 3))
    for i in range(num_samples):
        plt.subplot(3, 5, i+1)
        image = read_and_show(os.path.join(data_path, all_images[i]))
        text = ocr(image, processor, trained_model)
        plt.imshow(image)
        plt.title(text)
        plt.axis('off')
    plt.show()
  • 运行推理并可视化结果
复制
# 运行推理并可视化结果
eval_new_data(
    data_path=data_path,
    num_samples=num_samples,
    df=sample_df
)

推理结果如下图所示。

使用 OCR 识别手写文本

由此可以看出,模型成功地正确预测了所有单词。这表明经过微调后,模型在验证集上的表现非常出色。

附录(完整代码)

链接:https://pan.baidu.com/s/1R5-JB7zKTeb1pJ0kS2Tmnw

提取码:d388

相关资讯

腾讯OCR团队斩获ICDAR大赛四项冠军

在全球文字识别(OCR)领域顶级盛会ICDAR 2023上,腾讯OCR团队基于自研算法,斩获四项冠军,这是继2017年、2019年、2021年以来,连续四届参会同时创造佳绩,共获得18项官方认证冠军,展示了腾讯OCR技术在全球的一流水平。ICDAR大会是全球文档图像分析识别领域公认的权威学术会议,每两年举办一次,赛事举办至今已经吸引了超过100多个国家的近8000支队伍参与其中。ICDAR竞赛因其极高的技术难度和强大的实用性享誉国内外,与赛后非正式刷榜不同,ICDAR官方认证的正式竞赛采用全新的数据集,并且在比赛期

亮相CCIG2024,合合信息文档解析技术破解大模型语料“饥荒”难题

近日,2024中国图象图形大会在古都西安盛大开幕。本届大会由中国图象图形学学会主办,空军军医大学、西安交通大学、西北工业大学承办,通过二十多场论坛、百余项成果,集中展示了生成式人工智能、大模型、机器学习、类脑计算等多个图像图形领域的进展。大模型技术正随着科技革新实现广泛应用,满足多行业图像处理需求。大会期间,由CSIG文档图像分析与识别专委会与上海合合信息科技股份有限公司(简称“合合信息”)联合主办了《大模型技术及其前沿应用》论坛,来自华南理工大学、上海交通大学、清华大学、复旦大学、上海人工智能实验室、合合信息等高

【多模态&文档智能】OCR-free感知多模态大模型技术链路及训练数据细节

目前的一些多模态大模型的工作倾向于使用MLLM进行推理任务,然而,纯OCR任务偏向于模型的感知能力,对于文档场景,由于文字密度较高,现有方法往往通过增加图像token的数量来提升性能。 这种策略在增加新的语言时,需要重新进行训练,计算开销较大,成本较高。 因此,本文再来看看vary和got这两个衔接工作,看看其完整的技术链路。