从零开始,用英伟达T4、A10训练小型文生视频模型,几小时搞定

很翔实的一篇教程。OpenAI 的 Sora、Stability AI 的 Stable Video Diffusion 以及许多其他已经发布或未来将出现的文本生成视频模型,是继大语言模型 (LLM) 之后 2024 年最流行的 AI 趋势之一。在这篇博客中,作者将展示如何将从头开始构建一个小规模的文本生成视频模型,涵盖了从理解理论概念、到编辑整个架构再到生成最终结果的所有内容。由于作者没有大算力的 GPU,所以仅编辑了小规模架构。以下是在不同处理器上训练模型所需时间的比较。作者表示,在 CPU 上运转显然须要更长

很翔实的一篇教程。

OpenAI 的 Sora、Stability AI 的 Stable Video Diffusion 以及许多其他已经发布或未来将出现的文本生成视频模型,是继大语言模型 (LLM) 之后 2024 年最流行的 AI 趋势之一。

在这篇博客中,作者将展示如何将从头开始构建一个小规模的文本生成视频模型,涵盖了从理解理论概念、到编辑整个架构再到生成最终结果的所有内容。

由于作者没有大算力的 GPU,所以仅编辑了小规模架构。以下是在不同处理器上训练模型所需时间的比较。

从零开始,用英伟达T4、A10训练小型文生视频模型,几小时搞定

作者表示,在 CPU 上运转显然须要更长的时间来训练模型。如果你须要快速尝试代码中的更改并查看结果,CPU 不是最佳选择。因此建议应用 Colab 或 Kaggle 的 T4 GPU 举行更高效、更快速的训练。

构建目标

我们采用了与传统机器学习或深度学习模型类似的方法,即在数据集上举行训练,然后在未见过数据上举行尝试。在文本转视频的背景下,假设有一个蕴涵 10 万个狗捡球和猫追老鼠视频的训练数据集,然后训练模型来生成猫捡球或狗追老鼠的视频。

从零开始,用英伟达T4、A10训练小型文生视频模型,几小时搞定

                              图源:iStock, GettyImages

虽然此类训练数据集在互联网上很容易获得,但所需的算力极高。因此,我们将应用由 Python 代码生成的静止对象视频数据集。同时应用 GAN(生成匹敌收集)架构来创设模型,而不是 OpenAI Sora 应用的扩散模型。

我们也尝试应用扩散模型,但内存要求超出了自己的能力。另一方面,GAN 可以更容易、更快地举行训练和尝试。

准备条件

我们将应用 OOP(面向对象编程),因此必须对它以及神经收集有基本的了解。此外 GAN(生成匹敌收集)的知识不是必需的,因为这里简单介绍它们的架构。

OOP:https://www.youtube.com/watch?v=q2SGW2VgwAM

神经收集理论:https://www.youtube.com/watch?v=Jy4wM2X21u0

GAN 架构:https://www.youtube.com/watch?v=TpMIssRdhco

Python 基础:https://www.youtube.com/watch?v=eWRfhZUzrAc

了解 GAN 架构

什么是 GAN?

生成匹敌收集是一种深度学习模型,其中两个神经收集相互竞争:一个从给定的数据集创设新数据(如图象或音乐),另一个则判断数据是真正的还是虚假的。这个过程一直持续到生成的数据与原始数据无法区分。

真正世界应用

生成图象:GAN 根据文本 prompt 创设真切的图象或修改现有图象,例如增强分辨率或为黑白照片添加颜色。

数据增强:GAN 生成合成数据来训练其他机器学习模型,例如为欺诈检测系统创设欺诈交易数据。

补充缺失信息:GAN 可以填充缺失数据,例如根据地形图生成地下图象以用于能源应用。

生成 3D 模型:GAN 将 2D 图象变换为 3D 模型,在医疗保健等领域非常有用,可用于为手术规划创设真切的器官图象。

GAN 工作原理

GAN 由两个深度神经收集组成:生成器和辨别器。这两个收集在匹敌设置中一起训练,其中一个收集生成新数据,另一个收集评估数据是真是假。

从零开始,用英伟达T4、A10训练小型文生视频模型,几小时搞定

GAN 训练示例

让我们以图象到图象的变换为例,解释一下 GAN 模型,重点是修改人脸。

1. 输出图象:输出图象是一张真正的人脸图象。

2. 属性修改:生成器会修改人脸的属性,比如给眼睛加上墨镜。

3. 生成图象:生成器会创设一组添加了太阳镜的图象。

4. 辨别器的任务:辨别器接收到混合的真正图象(带有太阳镜的人)和生成的图象(添加了太阳镜的人脸)。 

5. 评估:辨别器尝试区分真正图象和生成图象。 

6. 反馈回路:如果辨别器正确识别出假图象,生成器会调整其参数以生成更真切的图象。如果生成器成功欺骗了辨别器,辨别器会更新其参数以提高检测能力。 

通过这一匹敌过程,两个收集都在不断改进。生成器越来越善于生成真切的图象,而辨别器则越来越善于识别假图象,直到达到平衡,辨别器再也无法区分真正图象和生成的图象。此时,GAN 已成功学会生成真切的修改图象。

设置背景

我们将应用一系列 Python 库,让我们导入它们。

# Operating System module for interacting with the operating system

import os

# Module for generating random numbers

import random

# Module for numerical operations

import numpy as np

# OpenCV library for image processing

import cv2

# Python Imaging Library for image processing

from PIL import Image, ImageDraw, ImageFont

# PyTorch library for deep learning

import torch

# Dataset class for creating custom datasets in PyTorch

from torch.utils.data import Dataset

# Module for image transformations

import torchvision.transforms as transforms

# Neural network module in PyTorch

import torch.nn as nn

# Optimization algorithms in PyTorch

import torch.optim as optim

# Function for padding sequences in PyTorch

from torch.nn.utils.rnn import pad_sequence

# Function for saving images in PyTorch

from torchvision.utils import save_image

# Module for plotting graphs and images

import matplotlib.pyplot as plt

# Module for displaying rich content in IPython environments

from IPython.display import clear_output, display, HTML

# Module for encoding and decoding binary data to text

import base64

现在我们已经导入了所有的库,下一步就是定义我们的训练数据,用于训练 GAN 架构。

对训练数据举行编码

我们须要至少 10000 个视频作为训练数据。为什么呢?因为我尝试了较小数量的视频,结果非常糟糕,几乎没有任何效果。下一个重要问题是:这些视频内容是什么?  我们的训练视频数据集包括一个圆圈以不同方向和不同运动方式静止的视频。让我们来编辑代码并生成 10,000 个视频,看看它的效果如何。

# Create a directory named 'training_dataset'

os.makedirs('training_dataset', exist_ok=True)

# Define the number of videos to generate for the dataset

num_videos = 10000

# Define the number of frames per video (1 Second Video)

frames_per_video = 10

# Define the size of each image in the dataset

img_size = (64, 64)

# Define the size of the shapes (Circle)

shape_size = 10

设置一些基本参数后,接下来我们须要定义训练数据集的文本 prompt,并据此生成训练视频。

# Define text prompts and corresponding movements for circles

prompts_and_movements = [

("circle moving down", "circle", "down"), # Move circle downward

("circle moving left", "circle", "left"), # Move circle leftward

("circle moving right", "circle", "right"), # Move circle rightward

("circle moving diagonally up-right", "circle", "diagonal_up_right"), # Move circle diagonally up-right

("circle moving diagonally down-left", "circle", "diagonal_down_left"), # Move circle diagonally down-left

("circle moving diagonally up-left", "circle", "diagonal_up_left"), # Move circle diagonally up-left

("circle moving diagonally down-right", "circle", "diagonal_down_right"), # Move circle diagonally down-right

("circle rotating clockwise", "circle", "rotate_clockwise"), # Rotate circle clockwise

("circle rotating counter-clockwise", "circle", "rotate_counter_clockwise"), # Rotate circle counter-clockwise

("circle shrinking", "circle", "shrink"), # Shrink circle

("circle expanding", "circle", "expand"), # Expand circle

("circle bouncing vertically", "circle", "bounce_vertical"), # Bounce circle vertically

("circle bouncing horizontally", "circle", "bounce_horizontal"), # Bounce circle horizontally

("circle zigzagging vertically", "circle", "zigzag_vertical"), # Zigzag circle vertically

("circle zigzagging horizontally", "circle", "zigzag_horizontal"), # Zigzag circle horizontally

("circle moving up-left", "circle", "up_left"), # Move circle up-left

("circle moving down-right", "circle", "down_right"), # Move circle down-right

("circle moving down-left", "circle", "down_left"), # Move circle down-left

]

我们已经利用这些 prompt 定义了圆的几个运动轨迹。现在,我们须要编辑一些数学公式,以便根据 prompt 静止圆。

# Define function with parameters

def create_image_with_moving_shape(size, frame_num, shape, direction):  

# Create a new RGB image with specified size and white background

img = Image.new('RGB', size, color=(255, 255, 255)) 

# Create a drawing context for the image

draw = ImageDraw.Draw(img)

# Calculate the center coordinates of the image

center_x, center_y = size[0] // 2, size[1] // 2

# Initialize position with center for all movements

position = (center_x, center_y)

# Define a dictionary mapping directions to their respective position adjustments or image transformations

direction_map = {

# Adjust position downwards based on frame number

"down": (0, frame_num * 5 % size[1]), 

# Adjust position to the left based on frame number

"left": (-frame_num * 5 % size[0], 0), 

# Adjust position to the right based on frame number

"right": (frame_num * 5 % size[0], 0), 

# Adjust position diagonally up and to the right

"diagonal_up_right": (frame_num * 5 % size[0], -frame_num * 5 % size[1]), 

# Adjust position diagonally down and to the left

"diagonal_down_left": (-frame_num * 5 % size[0], frame_num * 5 % size[1]), 

# Adjust position diagonally up and to the left

"diagonal_up_left": (-frame_num * 5 % size[0], -frame_num * 5 % size[1]), 

# Adjust position diagonally down and to the right

"diagonal_down_right": (frame_num * 5 % size[0], frame_num * 5 % size[1]), 

# Rotate the image clockwise based on frame number

"rotate_clockwise": img.rotate(frame_num * 10 % 360, center=(center_x, center_y), fillcolor=(255, 255, 255)),  # Rotate the image counter-clockwise based on frame number

"rotate_counter_clockwise": img.rotate(-frame_num * 10 % 360, center=(center_x, center_y), fillcolor=(255, 255, 255)), 

# Adjust position for a bouncing effect vertically

"bounce_vertical": (0, center_y - abs(frame_num * 5 % size[1] - center_y)), 

# Adjust position for a bouncing effect horizontally

"bounce_horizontal": (center_x - abs(frame_num * 5 % size[0] - center_x), 0), 

# Adjust position for a zigzag effect vertically

"zigzag_vertical": (0, center_y - frame_num * 5 % size[1]) if frame_num % 2 == 0 else (0, center_y + frame_num * 5 % size[1]), 

# Adjust position for a zigzag effect horizontally

"zigzag_horizontal": (center_x - frame_num * 5 % size[0], center_y) if frame_num % 2 == 0 else (center_x + frame_num * 5 % size[0], center_y), 

# Adjust position upwards and to the right based on frame number

"up_right": (frame_num * 5 % size[0], -frame_num * 5 % size[1]), 

# Adjust position upwards and to the left based on frame number

"up_left": (-frame_num * 5 % size[0], -frame_num * 5 % size[1]), 

# Adjust position downwards and to the right based on frame number

"down_right": (frame_num * 5 % size[0], frame_num * 5 % size[1]), 

# Adjust position downwards and to the left based on frame number

"down_left": (-frame_num * 5 % size[0], frame_num * 5 % size[1]) 

}

# Check if direction is in the direction map

if direction in direction_map: 

# Check if the direction maps to a position adjustment

if isinstance(direction_map[direction], tuple): 

# Update position based on the adjustment

position = tuple(np.add(position, direction_map[direction])) 

else: # If the direction maps to an image transformation

# Update the image based on the transformation

img = direction_map[direction] 

# Return the image as a numpy array

return np.array(img)

上述函数用于根据所选方向在每一帧中静止我们的圆。我们只需在其上运转一个循环,直至生成所有视频的次数。

# Iterate over the number of videos to generate

for i in range(num_videos):

# Randomly choose a prompt and movement from the predefined list

prompt, shape, direction = random.choice(prompts_and_movements)

   # Create a directory for the current video

   video_dir = f'training_dataset/video_{i}'

   os.makedirs(video_dir, exist_ok=True)

   # Write the chosen prompt to a text file in the video directory

   with open(f'{video_dir}/prompt.txt', 'w') as f:

f.write(prompt)

   

  # Generate frames for the current video

   for frame_num in range(frames_per_video):

# Create an image with a moving shape based on the current frame number, shape, and direction

img = create_image_with_moving_shape(img_size, frame_num, shape, direction)

# Save the generated image as a PNG file in the video directory

cv2.imwrite(f'{video_dir}/frame_{frame_num}.png', img)

运转上述代码后,就会生成整个训练数据集。以下是训练数据集文件的结构。

从零开始,用英伟达T4、A10训练小型文生视频模型,几小时搞定

每个训练视频文件夹蕴涵其帧以及对应的文本 prompt。让我们看一下我们的训练数据集样本。

在我们的训练数据集中,我们没有蕴涵圆圈先向上静止然后向右静止的运动。我们将应用这个作为尝试 prompt,来评估我们训练的模型在未见过的数据上的表现。

从零开始,用英伟达T4、A10训练小型文生视频模型,几小时搞定

还有一个重要的要点须要注意,我们的训练数据蕴涵许多物体从场景中移出或部分出现在摄像机前方的样本,类似于我们在 OpenAI Sora 演示视频中观察到的情况。 

从零开始,用英伟达T4、A10训练小型文生视频模型,几小时搞定

在我们的训练数据中蕴涵此类样本的原因是为了尝试当圆圈从角落进入场景时,模型是否能够保持一致性而不会破坏其形状。

现在我们的训练数据已经生成,须要将训练视频变换为张量,这是 PyTorch 等深度学习框架中应用的主要数据类型。此外,通过将数据缩放到较小的范围,执行归一化等变换有助于提高训练架构的收敛性和稳定性。

预处理训练数据

我们必须为文本转视频任务编辑一个数据集类,它可以从训练数据集目录中读取视频帧及其相应的文本 prompt,使其可以在 PyTorch 中应用。

# Define a dataset class inheriting from torch.utils.data.Dataset

class TextToVideoDataset(Dataset):

def __init__(self, root_dir, transform=None):

# Initialize the dataset with root directory and optional transform

self.root_dir = root_dir

self.transform = transform

# List all subdirectories in the root directory

self.video_dirs = [os.path.join(root_dir, d) for d in os.listdir(root_dir) if os.path.isdir(os.path.join(root_dir, d))]

# Initialize lists to store frame paths and corresponding prompts

self.frame_paths = []

self.prompts = []

# Loop through each video directory

for video_dir in self.video_dirs:

# List all PNG files in the video directory and store their paths

frames = [os.path.join(video_dir, f) for f in os.listdir(video_dir) if f.endswith('.png')]

self.frame_paths.extend(frames)

# Read the prompt text file in the video directory and store its content

with open(os.path.join(video_dir, 'prompt.txt'), 'r') as f:

prompt = f.read().strip()

# Repeat the prompt for each frame in the video and store in prompts list

self.prompts.extend([prompt] * len(frames))

# Return the total number of samples in the dataset

def __len__(self):

return len(self.frame_paths)

# Retrieve a sample from the dataset given an index

def __getitem__(self, idx):

# Get the path of the frame corresponding to the given index

frame_path = self.frame_paths[idx]

# Open the image using PIL (Python Imaging Library)

image = Image.open(frame_path)

# Get the prompt corresponding to the given index

prompt = self.prompts[idx]

# Apply transformation if specified

if self.transform:

image = self.transform(image)

# Return the transformed image and the prompt

return image, prompt

在继续编辑架构代码之前,我们须要对训练数据举行归一化处理。我们应用 16 的 batch 大小并对数据举行混洗以引入更多随机性。

实现文本嵌入层

你可能已经看到,在 Transformer 架构中,起点是将文本输出变换为嵌入,从而在多头注意力中举行进一步处理。类似地,我们在这里必须编辑一个文本嵌入层。基于该层,GAN 架构训练在我们的嵌入数据和图象张量上举行。

# Define a class for text embedding

class TextEmbedding(nn.Module):

# Constructor method with vocab_size and embed_size parameters

def __init__(self, vocab_size, embed_size):

# Call the superclass constructor

super(TextEmbedding, self).__init__()

# Initialize embedding layer

self.embedding = nn.Embedding(vocab_size, embed_size)

# Define the forward pass method

def forward(self, x):

# Return embedded representation of input

return self.embedding(x)

词汇量将基于我们的训练数据,在稍后举行计算。嵌入大小将为 10。如果应用更大的数据集,你还可以应用 Hugging Face 上已有的嵌入模型。

实现生成器层

现在我们已经知道生成器在 GAN 中的作用,接下来让我们对这一层举行编码,然后了解其内容。

class Generator(nn.Module):

def __init__(self, text_embed_size):

 super(Generator, self).__init__()

# Fully connected layer that takes noise and text embedding as input

self.fc1 = nn.Linear(100 + text_embed_size, 256 * 8 * 8)

# Transposed convolutional layers to upsample the input

self.deconv1 = nn.ConvTranspose2d(256, 128, 4, 2, 1)

self.deconv2 = nn.ConvTranspose2d(128, 64, 4, 2, 1)

self.deconv3 = nn.ConvTranspose2d(64, 3, 4, 2, 1) # Output has 3 channels for RGB images

# Activation functions

self.relu = nn.ReLU(True) # ReLU activation function

self.tanh = nn.Tanh() # Tanh activation function for final output

def forward(self, noise, text_embed):

# Concatenate noise and text embedding along the channel dimension

x = torch.cat((noise, text_embed), dim=1)

# Fully connected layer followed by reshaping to 4D tensor

x = self.fc1(x).view(-1, 256, 8, 8)

# Upsampling through transposed convolution layers with ReLU activation

x = self.relu(self.deconv1(x))

x = self.relu(self.deconv2(x))

# Final layer with Tanh activation to ensure output values are between -1 and 1 (for images)

x = self.tanh(self.deconv3(x))

return x

该 Generator 类负责根据随机噪声和文本嵌入的组合创设视频帧,旨在根据给定的文本描述生成真切的视频帧。该收集从完全连接层 (nn.Linear) 开始,将噪声向量和文本嵌入组合成单个特征向量。然后,该向量被重新整形并经过一系列的转置卷积层 (nn.ConvTranspose2d),这些层将特征图逐步上采样到所需的视频帧大小。

这些层应用 ReLU 激活 (nn.ReLU) 实现非线性,最后一层应用 Tanh 激活 (nn.Tanh) 将输出缩放到 [-1, 1] 的范围。因此,生成器将抽象的高维输出变换为以视觉方式表示输出文本的连贯视频帧。

实现辨别器层

在编辑完生成器层之后,我们须要实现另一半,即辨别器部分。

class Discriminator(nn.Module):

def __init__(self):

super(Discriminator, self).__init__()

    

       # Convolutional layers to process input images

        self.conv1 = nn.Conv2d(3, 64, 4, 2, 1)   # 3 input channels (RGB), 64 output channels, kernel size 4x4, stride 2, padding 1

self.conv2 = nn.Conv2d(64, 128, 4, 2, 1) # 64 input channels, 128 output channels, kernel size 4x4, stride 2, padding 1

self.conv3 = nn.Conv2d(128, 256, 4, 2, 1) # 128 input channels, 256 output channels, kernel size 4x4, stride 2, padding 1

       

         # Fully connected layer for classification

self.fc1 = nn.Linear(256 * 8 * 8, 1) # Input size 256x8x8 (output size of last convolution), output size 1 (binary classification)

  # Activation functions

self.leaky_relu = nn.LeakyReLU(0.2, inplace=True) # Leaky ReLU activation with negative slope 0.2

self.sigmoid = nn.Sigmoid() # Sigmoid activation for final output (probability)

def forward(self, input):

# Pass input through convolutional layers with LeakyReLU activation

x = self.leaky_relu(self.conv1(input))

x = self.leaky_relu(self.conv2(x))

x = self.leaky_relu(self.conv3(x))

# Flatten the output of convolutional layers

x = x.view(-1, 256 * 8 * 8)

# Pass through fully connected layer with Sigmoid activation for binary classification

x = self.sigmoid(self.fc1(x))

        return x

辨别器类用作二元分类器,区分真正视频帧和生成的视频帧。目的是评估视频帧的真正性,从而指导生成器产生更真正的输出。该收集由卷积层 (nn.Conv2d) 组成,这些卷积层从输出视频帧中提取分层特征, Leaky ReLU 激活 (nn.LeakyReLU) 增加非线性,同时允许负值的小梯度。

然后,特征图被展平并通过完全连接层 (nn.Linear),最终以 S 形激活 (nn.Sigmoid) 输出指示帧是真正还是假的概率分数。

通过训练辨别器准确地对帧举行分类,生成器同时接受训练以创设更令人信服的视频帧,从而骗过辨别器。

编辑训练参数

我们必须设置用于训练 GAN 的基础组件,例如损失函数、优化器等。

# Check for GPU

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Create a simple vocabulary for text prompts

all_prompts = [prompt for prompt, _, _ in prompts_and_movements] # Extract all prompts from prompts_and_movements list

vocab = {word: idx for idx, word in enumerate(set(" ".join(all_prompts).split()))} # Create a vocabulary dictionary where each unique word is assigned an index

vocab_size = len(vocab) # Size of the vocabulary

embed_size = 10 # Size of the text embedding vector

def encode_text(prompt):

# Encode a given prompt into a tensor of indices using the vocabulary

return torch.tensor([vocab[word] for word in prompt.split()])

# Initialize models, loss function, and optimizers

text_embedding = TextEmbedding(vocab_size, embed_size).to(device) # Initialize TextEmbedding model with vocab_size and embed_size

netG = Generator(embed_size).to(device) # Initialize Generator model with embed_size

netD = Discriminator().to(device) # Initialize Discriminator model

criterion = nn.BCELoss().to(device) # Binary Cross Entropy loss function

optimizerD = optim.Adam(netD.parameters(), lr=0.0002, betas=(0.5, 0.999)) # Adam optimizer for Discriminator

optimizerG = optim.Adam(netG.parameters(), lr=0.0002, betas=(0.5, 0.999)) # Adam optimizer for Generator

这是我们必须变换代码以在 GPU 上运转的部分(如果可用)。我们已经编辑了代码来查找 vocab_size,并且我们正在为生成器和辨别器应用 ADAM 优化器。你可以选择自己的优化器。在这里,我们将学习率设置为较小的值 0.0002,嵌入大小为 10,这比其他可供公众应用的 Hugging Face 模型要小得多。

编辑训练 loop

就像其他神经收集一样,我们将以类似的方式对 GAN 架构训练举行编码。

# Number of epochs

num_epochs = 13

# Iterate over each epoch

for epoch in range(num_epochs):

# Iterate over each batch of data

for i, (data, prompts) in enumerate(dataloader):

# Move real data to device

real_data = data.to(device)

# Convert prompts to list

prompts = [prompt for prompt in prompts]

# Update Discriminator

netD.zero_grad() # Zero the gradients of the Discriminator

batch_size = real_data.size(0) # Get the batch size

labels = torch.ones(batch_size, 1).to(device) # Create labels for real data (ones)

output = netD(real_data) # Forward pass real data through Discriminator

lossD_real = criterion(output, labels) # Calculate loss on real data

lossD_real.backward() # Backward pass to calculate gradients

        # Generate fake data

noise = torch.randn(batch_size, 100).to(device) # Generate random noise

text_embeds = torch.stack([text_embedding(encode_text(prompt).to(device)).mean(dim=0) for prompt in prompts]) # Encode prompts into text embeddings

fake_data = netG(noise, text_embeds) # Generate fake data from noise and text embeddings

labels = torch.zeros(batch_size, 1).to(device) # Create labels for fake data (zeros)

output = netD(fake_data.detach()) # Forward pass fake data through Discriminator (detach to avoid gradients flowing back to Generator)

lossD_fake = criterion(output, labels) # Calculate loss on fake data

lossD_fake.backward() # Backward pass to calculate gradients

optimizerD.step() # Update Discriminator parameters

# Update Generator

netG.zero_grad() # Zero the gradients of the Generator

labels = torch.ones(batch_size, 1).to(device) # Create labels for fake data (ones) to fool Discriminator

output = netD(fake_data) # Forward pass fake data (now updated) through Discriminator

lossG = criterion(output, labels) # Calculate loss for Generator based on Discriminator's response

lossG.backward() # Backward pass to calculate gradients

optimizerG.step() # Update Generator parameters

# Print epoch information

print(f"Epoch [{epoch + 1}/{num_epochs}] Loss D: {lossD_real + lossD_fake}, Loss G: {lossG}")

通过反向传播,我们的损失将针对生成器和辨别器举行调整。我们在训练 loop 中应用了 13 个 epoch。我们尝试了不同的值,但如果 epoch 高于这个值,结果并没有太大差异。此外,过度拟合的风险很高。如果我们的数据集更加多样化,蕴涵更多动作和形状,则可以考虑应用更高的 epoch,但在这里没有这样做。

当我们运转此代码时,它会开始训练,并在每个 epoch 之后 print 生成器和辨别器的损失。

## OUTPUT ##

Epoch [1/13] Loss D: 0.8798642754554749, Loss G: 1.300612449645996

Epoch [2/13] Loss D: 0.8235711455345154, Loss G: 1.3729925155639648

Epoch [3/13] Loss D: 0.6098687052726746, Loss G: 1.3266581296920776

...

保存训练的模型

训练完成后,我们须要保存训练好的 GAN 架构的辨别器和生成器,这只需两行代码即可实现。

# Save the Generator model's state dictionary to a file named 'generator.pth'

torch.save(netG.state_dict(), 'generator.pth')

# Save the Discriminator model's state dictionary to a file named 'discriminator.pth'

torch.save(netD.state_dict(), 'discriminator.pth')

生成 AI 视频

正如我们所讨论的,我们在未见过的数据上尝试模型的方法与我们训练数据中涉及狗取球和猫追老鼠的示例类似。因此,我们的尝试 prompt 可能涉及猫取球或狗追老鼠等场景。

在我们的特定情况下,圆圈向上静止然后向右静止的运动在训练数据中不存在,因此模型不熟悉这种特定运动。但是,模型已经在其他动作上举行了训练。我们可以应用此动作作为 prompt 来尝试我们训练过的模型并观察其性能。

# Inference function to generate a video based on a given text promptdef generate_video(text_prompt, num_frames=10):    # Create a directory for the generated video frames based on the text prompt    os.makedirs(f'generated_video_{text_prompt.replace(" ", "_")}', exist_ok=True)        # Encode the text prompt into a text embedding tensor    text_embed = text_embedding(encode_text(text_prompt).to(device)).mean(dim=0).unsqueeze(0)        # Generate frames for the video    for frame_num in range(num_frames):        # Generate random noise        noise = torch.randn(1, 100).to(device)                # Generate a fake frame using the Generator network        with torch.no_grad():            fake_frame = netG(noise, text_embed)                # Save the generated fake frame as an image file        save_image(fake_frame, f'generated_video_{text_prompt.replace(" ", "_")}/frame_{frame_num}.png')# usage of the generate_video function with a specific text promptgenerate_video('circle moving up-right')

当我们运转上述代码时,它将生成一个目录,其中蕴涵我们生成视频的所有帧。我们须要应用一些代码将所有这些帧合并为一个短视频。

# Define the path to your folder containing the PNG frames

folder_path = 'generated_video_circle_moving_up-right' # Get the list of all PNG files in the folder

image_files = [f for f in os.listdir(folder_path) if f.endswith('.png')]

# Sort the images by name (assuming they are numbered sequentially)

image_files.sort()

# Create a list to store the frames

frames = []

# Read each image and append it to the frames list

for image_file in image_files:

image_path = os.path.join(folder_path, image_file)

frame = cv2.imread(image_path)

frames.append(frame)

# Convert the frames list to a numpy array for easier processing

frames = np.array(frames)

# Define the frame rate (frames per second)

fps = 10

# Create a video writer object

fourcc = cv2.VideoWriter_fourcc(*'XVID')

out = cv2.VideoWriter('generated_video.avi', fourcc, fps, (frames[0].shape[1], frames[0].shape[0]))

# Write each frame to the video

for frame in frames:

out.write(frame)

# Release the video writer

out.release()

确保文件夹路径指向你新生成的视频所在的位置。运转此代码后,你将成功创设 AI 视频。让我们看看它是什么样子。

从零开始,用英伟达T4、A10训练小型文生视频模型,几小时搞定

我们举行了多次训练,训练次数相同。在两种情况下,圆圈都是从底部开始,出现一半。好消息是,我们的模型在两种情况下都尝试执行直立运动。

例如,在尝试 1 中,圆圈沿对角线向上静止,然后执行向上运动,而在尝试 2 中,圆圈沿对角线静止,同时尺寸缩小。在两种情况下,圆圈都没有向左静止或完全消失,这是一个好兆头。

最后,作者表示已经尝试了该架构的各个方面,发现训练数据是关键。通过在数据集中蕴涵更多动作和形状,你可以增加可变性并提高模型的性能。由于数据是通过代码生成的,因此生成更多样的数据不会花费太多时间;相反,你可以专注于完善逻辑。

此外,文章中讨论的 GAN 架构相对简单。你可以通过集成高级技术或应用语言模型嵌入 (LLM) 而不是基本神经收集嵌入来使其更复杂。此外,调整嵌入大小等参数会显著影响模型的有效性。

原文链接:https://levelup.gitconnected.com/building-an-ai-text-to-video-model-from-scratch-using-python-35b4eb4002de

给TA打赏
共{{data.count}}人
人已打赏
AI

AI“恐怖”体操视频腿脚乱飞、大变活人,LeCun:视频生成模型根本不懂物理

2024-7-1 15:40:00

AI

腾讯搜狗输入法上线 AI 对话、AI 宠物、快捷问答等功能

2024-7-1 18:10:56

0 条回复 A文章作者 M管理员
    暂无讨论,说说你的看法吧
个人中心
今日签到
搜索