Skip to content Skip to footer

从头开始创建帧插值模型 — 使用卷积融合上采样

研究鸟类的翅膀运动以研究材料的变形、断裂和失效可能需要昂贵的高速摄像机。

令人惊讶的是(或者可能不是),有一种技术可以使用 AI 免费自动生成更多帧。

帧插值,也称为“中间”,允许您在 2 个中间帧对之间生成帧。

虽然乍一看很简单,但应用程序很有前途。

运动分析:

  • 运动科学:帧插值可用于分析运动员的动作,以确定需要改进的领域。例如,它可用于跟踪跑步者脚的运动,看看他们是否正确着陆,或者跟踪高尔夫球手挥杆的旋转,看看他们是否使用了正确的技术。
  • 安全和监控:帧插值可用于分析安全摄像头的视频片段,以识别潜在威胁或可疑活动。例如,它可用于跟踪人群中一个人的运动,看看他们的行为是否可疑,或者跟踪车辆的运动,看看它是否在跟随某人。

流体力学:

  • 空气动力学:帧插值可用于研究飞机和其他物体周围的空气流动。这可以帮助工程师设计更高效、更符合空气动力学的车辆。
  • 流体力学:坐标系插值可用于研究水和其他流体的流动。这可以帮助工程师设计更高效和有效的泵、涡轮机和其他流体机械。

化学反应:

  • 燃烧: 帧插值可用于研究燃料的燃烧。这可以帮助科学家开发更高效、更环保的发动机。
  • 聚合:框架插值可用于研究塑料和其他材料的聚合。这可以帮助科学家开发具有改进性能的新材料。

材料科学:

  • 断裂力学: 框架插补可用于研究材料的断裂。这可以帮助工程师设计出更耐开裂和失效的材料。
  • 疲劳:框架插值可用于研究材料的疲劳。这可以帮助工程师设计更耐磨的材料。

生物学研究:

  • 细胞生物学:框架插值可用于研究细胞和其他生物结构的运动。这可以帮助科学家了解细胞分裂、迁移和分化的过程。
  • 神经: 帧插值可用于研究大脑中神经元的活动。这可以帮助科学家了解学习、记忆和感知的过程。

在娱乐和制作中还有各种其他应用。本文将介绍如何从头开始创建整个帧插值模型,以及我是如何做到的。

image-20231224224635509

收集数据集

T对帧插值模型进行下雨需要来自各种视频的各种帧数据集。当从这样的数据集中学习时,模型会尝试学习连续帧之间的复杂关系,以准确合成缺失的中间帧。

具体来说,该模型学习...

  • 运动模式,
  • 外观变化,
  • 不同光照条件,
  • 以及各种其他功能。

帧插值模型可以从各种流行的数据集中学习。

  • Vimeo-90K: 这是一个大型数据集,其中包含具有不同帧速率和运动模式的高质量视频。它通常用于评估帧插值模型。
  • 米德尔伯里: 该数据集包含一组视频序列,这些序列在连续帧之间具有地面实况光流。它可用于训练和评估基于光流的帧插值方法。
  • Adobe 2K 培训套装: 该数据集由具有 GoPro 风格相机运动的 2K 分辨率视频组成。它适用于训练处理动态场景的帧插值模型。
  • 戴维斯: 该数据集包括具有复杂运动和遮挡的高质量视频。这是评估帧插值方法的一个具有挑战性的基准。
  • 洛桑联邦理工学院: 该数据集提供了具有极限运动的视频集合,例如运动和动作场景。它对于训练可以处理帧之间大位移的模型非常有用。

image-20231224224648655

Vimeo-90k 数据集

在下面的代码中,所有视频都保存到“数据集/视频”中。视频必须特别为 720p 才能具有适当的比例,尽管这可以很容易地更改。

import os
import cv2

# Function to process a video by resizing its frames and saving them as images
def process_video(input_path, output_path, max_frames=200, target_resolution=(160, 90)):
    # Open the video file
    cap = cv2.VideoCapture(input_path)

    # Get the width and height of the video frames
    width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
    height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))

    # Check if the video has the expected resolution (1280x720)
    if width == 1280 and height == 720:
        # Create the output directory if it doesn't exist
        os.makedirs(output_path, exist_ok=True)
        frame_count = 0
        success, frame = cap.read()

        # Process each frame, resize, and save to the output directory
        while success and frame_count < max_frames:
            frame_count += 1
            resized_frame = cv2.resize(frame, target_resolution)
            frame_filename = os.path.join(output_path, f"frame{frame_count:04d}.jpg")
            cv2.imwrite(frame_filename, resized_frame)
            success, frame = cap.read()

        # Release the video capture object
        cap.release()

if __name__ == "__main__":
    # Set input and output directories
    input_directory = 'dataset/videos'
    output_directory = 'dataset/frames'

    # Process each video file in the input directory
    for filename in os.listdir(input_directory):
        if filename.endswith(".mp4"):
            input_path = os.path.join(input_directory, filename)
            video_name = os.path.splitext(filename)[0]
            output_path = os.path.join(output_directory, video_name)

            # Call the process_video function for each video
            process_video(input_path, output_path)

    # Print a message indicating that processing is complete
    print("Processing complete.")

第 1 步:数据采集
代码首先定义输入和输出目录。假设输入目录包含下载的视频库存视频,而输出目录将存储处理后的帧。

第 2 步:视频处理
对于输入目录中的每个视频文件,代码提取帧并将其保存为 JPEG 图像。该函数将输入视频路径、输出路径、最大帧数和目标分辨率作为参数。process_video

  • 它使用 OpenCV 的功能打开视频。cv2.VideoCapture
  • 它使用 和 提取视频属性,包括宽度和高度。cv2.CAP_PROP_FRAME_WIDTHcv2.CAP_PROP_FRAME_HEIGHT
  • 它检查视频的分辨率是否为 1280x720。如果是这样,它会创建相应的输出目录(如果该目录不存在)。
  • 它使用 .cv2.VideoCapture.read()
  • 它使用 将每个帧的大小调整为目标分辨率 (160x90)。cv2.resize()
  • 它使用 将调整大小的帧保存为 JPEG 图像,使用格式“frameXXXX.jpg”命名文件,其中 XXXX 是零填充帧数。cv2.imwrite()

第 3 步:数据集生成
代码遍历输入目录,标识以 .对于每个视频,它提取其文件名,构造输出路径,并调用函数来处理视频并保存帧。最后,它打印完成消息。.mp4process_video

此过程将每个视频转换为一系列 JPEG 图像,从而创建视频帧数据集来训练帧插值模型。

训练模型

T他的帧插值模型方法利用卷积神经网络 (CNN) 架构来学习输入帧与相应的缺失中间帧之间的复杂关系。CNN从输入帧中提取特征,并以分层方式进行组合,以有效地捕捉视频序列的时间动态。

模型架构:
该类代表核心神经网络架构。它包括三个主要组成部分:FrameInterpolationModel

  1. 特征提取器: 由三个特征提取器模块组成的系列,每个模块由两个卷积层组成,后跟一个 ReLU 激活函数,负责从输入帧中提取高级特征。
  2. 上采样: 双线性上采样层,用于缩放提取的特征以匹配目标中间帧的分辨率。
  3. 融合和上采样卷积层: 融合卷积层,将上采样的特征图和生成预测的中间帧的最终上采样卷积层组合在一起。

数据准备
该类处理从一组视频序列中准备训练数据。FrameInterpolationDataset

  1. 构造函数使用包含视频帧的根目录初始化数据集,并定义一个转换函数以在必要时应用预处理步骤。
  2. 该方法返回数据样本的总数,而该方法检索由两个输入帧(帧 1 和帧 2)和相应的目标中间帧组成的特定数据样本。__len____getitem__

模型训练:主要训练循环包括几个步骤:

  1. 模型和损失定义: 初始化模型实例、损失函数(均方误差)和优化器(Adam)。
  2. 数据集加载: 训练数据集是使用数据加载器加载的,该加载器可以有效地批处理数据。
  3. 培训周期: 循环遍历指定数量的 epoch(训练周期)。
  4. 小批量处理: 在每个纪元中,循环遍历小批量数据。
  5. 输入准备: 输入帧被连接并发送到模型。
  6. 输出生成: 该模型生成预测的中间帧。
  7. 目标上采样: 对目标中间帧进行上采样,以匹配预测帧的分辨率。
  8. 损失计算: 计算预测帧和目标帧之间的均方误差。
  9. 反向传播和优化: 计算梯度,优化器更新模型的参数以最小化损失。
  10. 挂失: 定期报告损失值以跟踪训练进度。
import os
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, Dataset
from torchvision import transforms
from torch.nn.functional import interpolate as F_interpolate
from PIL import Image

# Definition of the Frame Interpolation Model
class FrameInterpolationModel(nn.Module):
    def __init__(self):
        super().__init__()
        # Feature extractor module with three sets of convolutional layers
        self.feature_extractor = nn.ModuleList([
            nn.Sequential(
                nn.Conv2d(6, 64, 3, 1, 1),
                nn.ReLU(inplace=True),
                nn.Conv2d(64, 128, 3, 1, 1),
                nn.ReLU(inplace=True)
            ) for _ in range(3)
        ])
        # Upsampling layer
        self.resize = nn.Upsample(size=(90, 160), mode='bilinear', align_corners=False)
        # Fusion convolution layer
        self.fusion_conv = nn.Conv2d(384, 128, 3, 1, 1)
        # Upsample convolution layer
        self.upsample_conv = nn.ConvTranspose2d(128, 3, 3, 2, 1, 1, 1)

    def forward(self, x):
        # Extract features from each input frame
        feature_maps = [extractor(x) for extractor in self.feature_extractor]
        # Resize the feature maps
        feature_maps_resized = [self.resize(fm) for fm in feature_maps]
        # Concatenate the resized feature maps
        x = torch.cat(feature_maps_resized, 1)
        # Apply fusion convolution and upsample convolution
        x = F.relu(self.fusion_conv(x))
        x = self.upsample_conv(x)
        return x

# Definition of the Frame Interpolation Dataset
class FrameInterpolationDataset(Dataset):
    def __init__(self, root_dir, transform=None):
        self.root_dir = root_dir
        self.transform = transform
        self.video_list = [f"video{i}" for i in range(1, 26)]
        self.num_frames_per_video = 190

    def __len__(self):
        return len(self.video_list) * self.num_frames_per_video

    def __getitem__(self, idx):
        # Calculate video index and frame index
        video_idx, frame_idx = divmod(idx, self.num_frames_per_video)
        frame_idx += 1
        # Create paths for the frames
        video_folder = os.path.join(self.root_dir, f"video{video_idx + 1}")
        frame_idx = min(frame_idx, self.num_frames_per_video)
        frame1_path = os.path.join(video_folder, f"frame{frame_idx:04d}.jpg")
        frame2_path = os.path.join(video_folder, f"frame{frame_idx + 1:04d}.jpg")
        target_path = os.path.join(video_folder, f"frame{frame_idx + 2:04d}.jpg")
        # Open and transform the frames
        frame1, frame2, target = map(Image.open, (frame1_path, frame2_path, target_path))
        if self.transform:
            frame1, frame2, target = map(self.transform, (frame1, frame2, target))
        return frame1, frame2, target

# Instantiate the Frame Interpolation Model
model = FrameInterpolationModel()
# Define the Mean Squared Error loss criterion
criterion = nn.MSELoss()
# Set up the Adam optimizer for model parameters
optimizer = optim.Adam(model.parameters(), lr=0.001)

# Load the Frame Interpolation Dataset
print("Loading dataset...")
dataset = FrameInterpolationDataset(root_dir='dataset/frames', transform=transforms.ToTensor())
# Create a DataLoader for batching and shuffling the dataset
dataloader = DataLoader(dataset, batch_size=32, shuffle=True)
print("Dataset loaded.")

# Train the Frame Interpolation Model
print("Training model...")
num_epochs = 15
for epoch in range(num_epochs):
    print(f"\nEpoch {epoch+1}/{num_epochs}")
    for frame1, frame2, target in dataloader:
        # Concatenate input frames
        inputs = torch.cat((frame1, frame2), 1)
        # Zero the gradients, forward pass, calculate loss, backward pass, and optimizer step
        optimizer.zero_grad()
        outputs = model(inputs)
        target_resized = F_interpolate(target, size=outputs.shape[2:], mode='bilinear', align_corners=False)
        loss = criterion(outputs, target_resized)
        loss.backward()
        optimizer.step()
    print(f'\nEpoch {epoch+1}/{num_epochs}, Loss: {loss.item()}', flush=True)

# Save the trained model's state dictionary
torch.save(model.state_dict(), 'models/frame_interpolation_model.pth')

纪元数取决于数据集和 GPU 的计算能力。最后,将训练好的模型保存到命名的文件中以供将来使用。frame_interpolation_model.pth

模型性能

image-20231224224717917

一个经过各种努力在更多时期训练模型,我的计算机在合理时间(~5 小时)内能够管理的最大值为 10 个时期。

尽管模型和结果还为时过早,但如果在至少 20 个 epoch 的专用 GPU 上进行训练,该模型将能够创建无缝的中间帧。

“从长远来看,我相信人工智能面临的最大挑战将是保持谦虚,并记住它只是我们可以用来改善生活的工具。”