图像相似度估计 | 结合三元组损失的暹罗网络

在机器学习领域,确定图像之间的相似度在各种应用中至关重要,从检测重复项到面部识别。 解决这个问题的一个强大方法是使用暹罗网络结合三元组损失函数。 在本文中,我们将探索如何构建和训练暹罗网络以估计图像相似度,并通过一个来自GitHub仓库的实际示例进行说明。

在机器学习领域,确定图像之间的相似度在各种应用中至关重要,从检测重复项到面部识别。解决这个问题的一个强大方法是使用暹罗网络结合三元组损失函数。在本文中,我们将探索如何构建和训练暹罗网络以估计图像相似度,并通过一个来自GitHub仓库的实际示例进行说明。

图像相似度估计 | 结合三元组损失的暹罗网络

什么是暹罗网络?

暹罗网络是一种包含两个或更多相同子网络的神经网络架构。这些子网络旨在为每个输入生成特征向量,然后可以比较这些向量以估计相似度。关键思想是使用相同的网络处理每个输入,确保输出一致且可比较。

这种架构特别适合于检测重复项、寻找异常和面部识别等任务。在我们将要探索的实现中,网络设置有三个相同的子网络。每个网络处理三张图像中的一张:锚点图像、正样本(与锚点相似)和负样本(与锚点无关)。

什么是三元组损失?

为了有效地训练暹罗网络,我们使用三元组损失函数。这种损失函数鼓励网络在特征空间中拉近锚点和正样本的距离,同时将锚点和负样本推得更远。损失函数定义如下:

L(A, P, N) = max(‖f(A) — f(P)‖² — ‖f(A) — f(N)‖² + margin, 0)

这里,A是锚点图像,P是正图像,N是负图像。函数f(x)代表网络生成的embedding,而margin是一个小的正值,有助于确保网络不会将所有嵌入压缩到同一点。

设置暹罗网络

在这次实现中,我们首先加载Totally Looks Like数据集,其中包含我们用来创建训练网络的三元组图像。

1. 数据准备

使用TensorFlow的tf.data API处理数据集以创建图像三元组。这涉及到设置一个数据管道,其中每个三元组由锚点、正样本和负样本图像组成。通过调整图像大小到目标形状并归一化像素值来预处理图像。

复制
def preprocess_image(filename):
    image_string = tf.io.read_file(filename)
    image = tf.image.decode_jpeg(image_string, channels=3)
    image = tf.image.convert_image_dtype(image, tf.float32)
    image = tf.image.resize(image, target_shape)
    return image

def preprocess_triplets(anchor, positive, negative):
    return (
        preprocess_image(anchor),
        preprocess_image(positive),
        preprocess_image(negative),
    )

以下是从数据集中生成的三元组示例,每行的前两张图像相似(锚点和正样本),第三张不同(负样本):

图像相似度估计 | 结合三元组损失的暹罗网络

图1:在数据准备期间生成的三元组。每行的前两张图像相似(锚点和正样本),第三张不同(负样本)

2.构建 embedding 生成器

我们暹罗网络的核心是嵌入生成器,它使用在ImageNet上预训练的ResNet50模型构建。通过冻结ResNet50中的大部分层的权重,并且仅微调最后几层,我们可以利用迁移学习来减少训练时间并提高性能。

复制
base_cnn = resnet.ResNet50(
    weights="imagenet", input_shape=target_shape + (3,), include_top=False
)

flatten = layers.Flatten()(base_cnn.output)
dense1 = layers.Dense(512, activation="relu")(flatten)
dense1 = layers.BatchNormalization()(dense1)
dense2 = layers.Dense(256, activation="relu")(dense1)
dense2 = layers.BatchNormalization()(dense2)
output = layers.Dense(256)(dense2)

embedding = Model(base_cnn.input, output, name="Embedding")

# Freeze all layers until the layer conv5_block1_out
trainable = False
for layer in base_cnn.layers:
    if layer.name == "conv5_block1_out":
        trainable = True
    layer.trainable = trainable

3.构建暹罗网络

暹罗网络设置为一次输入三张图像(锚点、正样本和负样本)。自定义的DistanceLayer计算锚点-正样本对和锚点-负样本对之间的距离。然后训练模型以最小化相似图像之间的距离,并最大化不相似图像之间的距离。

复制
class DistanceLayer(layers.Layer):
    def call(self, anchor, positive, negative):
        ap_distance = tf.reduce_sum(tf.square(anchor - positive), -1)
        an_distance = tf.reduce_sum(tf.square(anchor - negative), -1)
        return (ap_distance, an_distance)

anchor_input = layers.Input(name="anchor", shape=target_shape + (3,))
positive_input = layers.Input(name="positive", shape=target_shape + (3,))
negative_input = layers.Input(name="negative", shape=target_shape + (3,))

distances = DistanceLayer()(
    embedding(resnet.preprocess_input(anchor_input)),
    embedding(resnet.preprocess_input(positive_input)),
    embedding(resnet.preprocess_input(negative_input)),
)

siamese_network = Model(
    inputs=[anchor_input, positive_input, negative_input], outputs=distances
)

4.训练和评估

模型使用自定义训练循环进行训练,其中计算三元组损失并用于更新网络的权重。仔细监控训练过程,并通过对学习到的嵌入进行检查来评估模型的性能。

复制
class SiameseModel(Model):
    def __init__(self, siamese_network, margin=0.5):
        super(SiameseModel, self).__init__()
        self.siamese_network = siamese_network
        self.margin = margin
        self.loss_tracker = metrics.Mean(name="loss")

    def train_step(self, data):
        with tf.GradientTape() as tape:
            loss = self._compute_loss(data)
        gradients = tape.gradient(loss, self.siamese_network.trainable_weights)
        self.optimizer.apply_gradients(
            zip(gradients, self.siamese_network.trainable_weights)
        )
        self.loss_tracker.update_state(loss)
        return {"loss": self.loss_tracker.result()}

    def _compute_loss(self, data):
        ap_distance, an_distance = self.siamese_network(data)
        loss = ap_distance - an_distance
        loss = tf.maximum(loss + self.margin, 0.0)
        return loss

5.检查结果

训练完成后,我们可以通过比较锚点-正样本对和锚点-负样本对的嵌入之间的余弦相似度来评估网络学习分离相似和不相似图像的能力。

复制
cosine_similarity = metrics.CosineSimilarity()

positive_similarity = cosine_similarity(anchor_embedding, positive_embedding)
print("Positive similarity:", positive_similarity.numpy())

negative_similarity = cosine_similarity(anchor_embedding, negative_embedding)
print("Negative similarity:", negative_similarity.numpy())

以下是经过训练的模型评估的三元组示例。网络成功识别出图像之间的相似性和差异:

图像相似度估计 | 结合三元组损失的暹罗网络

图2:经过训练的暹罗网络的输出,其中每行的前两张图像被模型识别为相似,第三张为不同

结论

本文展示了使用三元组损失的暹罗网络如何有效地估计图像相似度。通过使用预训练的ResNet50模型并微调其层,我们可以创建一个可以应用于需要相似度估计的各种任务。

完整代码和解释,参考:https://github.com/elcaiseri/Siamese-Network

相关资讯

描述液体和软物质的AI方法,开启密度泛函理论新篇章

编辑 | 白菜叶拜罗伊特大学(Universität Bayreuth)的科学家开发了一种利用人工智能研究液体和软物质的新方法,开启了密度泛函理论的新篇章。我们生活在一个高度技术化的世界,在这个密集而复杂的相互关联的网络中,基础研究是创新发展的引擎。这里的新方法,可以对广泛的模拟技术产生巨大影响,从而可以在计算机上更快、更精确、更深入地研究复杂物质。将来,这可能会对产品和工艺设计产生影响。新制定的神经数学关系可以很好地表示液体的结构,这一事实是一项重大突破,为获得深入的物理见解开辟了一系列可能性。「在这项研究中,我

一种实现符号钢琴音乐声音和谱表分离的GNN新方法

译者 | 朱先忠审校 | 重楼本文涵盖了我最近在ISMIR 2024上发表的论文《聚类和分离:一种用于乐谱雕刻的声音和谱表预测的GNN方法》的主要内容。 简介以MIDI等格式编码的音乐,即使包含量化音符、拍号或小节信息,通常也缺少可视化的重要元素,例如语音和五线谱信息。 这种限制也适用于音乐生成、转录或编曲系统的输出。

时间序列模型的演变:人工智能引领新的预测时代

译者 | 布加迪审校 | 重楼我们正处于这样一个时代:大型基础模型(大规模通用神经网络以无监督的方式使用大量不同的数据进行预训练)彻底改变计算机视觉、自然语言处理以及最近的时间序列预测等领域。 这种模型通过实现零样本预测来重塑时间序列预测领域,允许使用新的、未见过的数据进行预测,无需针对每个数据集进行重新训练。 这一突破显著缩减了开发时间和成本,简化了为不同任务创建和微调模型的过程。