亚马逊AWS官方博客

使用 AWS EC2 上的 Apache MXNet 和 Multimedia Commons 数据集来估计图像位置

作者:Jaeyoung Choi 和 Kevin Li | 原文链接

这是由国际计算机科学研究院的 Jaeyoung Choi 和加州大学伯克利分校的 Kevin Li 所著的一篇访客文章。本项目演示学术研究人员如何利用我们的 AWS Cloud Credits for Research Program 实现科学突破。

当您拍摄照片时,现代移动设备可以自动向图像分配地理坐标。不过,网络上的大多数图像仍缺少该位置元数据。图像定位是估计图像位置并应用位置标签的过程。根据您的数据集大小以及提出问题的方式,分配的位置标签可以是建筑物或地标名称或实际地理坐标 (纬度、经度)。

在本文中,我们会展示如何使用通过 Apache MXNet 创建的预训练模型对图像进行地理分类。我们使用的数据集包含拍摄于全球各地的数百万张 Flickr 图像。我们还会展示如何将结果制成地图以直观地显示结果。

我们的方法

图像定位方法可以分为两类:图像检索搜索法和分类法。(该博文将对这两个类别中最先进的方法进行比较。)

Weyand 等人近期的作品提出图像定位是一个分类问题。在这种方法中,作者将地球表面细分为数千个地理单元格,并利用带地理标记的图像训练了深层神经网路。有关他们的试验更通俗的描述,请参阅该文章

由于作者没有公开他们的训练数据或训练模型 (即 PlaNet),因此我们决定训练我们自己的图像定位器。我们训练模型的场景灵感来自于 Weyand 等人描述的方法,但是我们对几个设置作了改动。

我们在单个 p2.16xlarge 实例上使用 MXNet 来训练我们的模型 LocationNet,该实例包含来自 AWS Multimedia Commons 数据集的带有地理标记的图像。

我们将训练、验证和测试图像分离,以便同一人上传的图像不会出现在多个集合中。我们使用 Google 的 S2 Geometry Library 通过训练数据创建类。该模型经过 12 个训练周期后收敛,完成 p2.16xlarge 实例训练大约花了 9 天时间。GitHub 上提供了采用 Jupyter Notebook 的完整教程

下表对用于训练和测试 LocationNet 和 PlaNet 的设置进行了比较。

LocationNet PlaNet
数据集来源 Multimedia Commons 从网络抓取的图像
训练集 3390 万 9100 万
验证 180 万 3400 万
S2 单元分区 t1=5000, t2=500
→ 15,527 个单元格
t1=10,000, t2=50
→ 26,263 个单元格
模型 ResNet-101 GoogleNet
优化 使用动量和 LR 计划的 SGD Adagrad
训练时间 采用 16 个 NVIDIA K80 GPU (p2.16xlarge EC2 实例) 时为 9 天
12 个训练周期
采用 200 个 CPU 内核时为两个半月
框架 MXNet DistBelief
测试集 Placing Task 2016 测试集 (150 万张 Flickr 图像) 230 万张有地理标记的 Flickr 图像

在推理时,LocationNet 会输出地理单元格间的概率分布。单元格中概率最高的图像的质心地理坐标会被分配为查询图像的地理坐标。

LocationNet 会在 MXNet Model Zoo 中公开分享。

下载 LocationNet

现在下载 LocationNet 预训练模型。LocationNet 已使用 AWS Multimedia Commons 数据集中带地理标记的图像子集进行了训练。Multimedia Commons 数据集包含 3900 多万张图像和 15000 个地理单元格 (类)。

LocationNet 包括两部分:一个包含模型定义的 JSON 文件和一个包含参数的二进制文件。我们从 S3 加载必要的软件包并下载文件。

import os

import urllib

import mxnet as mx

import logging

import numpy as np

from skimage import io, transform

from collections import namedtuple

from math import radians, sin, cos, sqrt, asin

path = 'https://s3.amazonaws.com/mmcommons-tutorial/models/'

model_path = 'models/'

if not os.path.exists(model_path):

os.mkdir(model_path)

urllib.urlretrieve(path+'RN101-5k500-symbol.json', model_path+'RN101-5k500-symbol.json')

urllib.urlretrieve(path+'RN101-5k500-0012.params', model_path+'RN101-5k500-0012.params')

然后,加载下载的模型。如果您没有可用 GPU,请将 mx.gpu() 替换为 mx.cpu():

# Load the pre-trained model

prefix = "models/RN101-5k500"

load_epoch = 12

sym, arg_params, aux_params = mx.model.load_checkpoint(prefix, load_epoch)

mod = mx.mod.Module(symbol=sym, context=mx.gpu())

mod.bind([('data', (1,3,224,224))], for_training=False)
mod.set_params(arg_params, aux_params, allow_missing=True)

grids.txt 文件包含用于训练模型的地理单元格。

第 i 行是第 i 个类,列分别代表:S2 单元格标记、纬度和经度。我们将标签加载到名为 grids 的列表中。

# Download and load grids file

urllib.urlretrieve('https://raw.githubusercontent.com/multimedia-berkeley/tutorials/master/grids.txt','grids.txt')

# Load labels.

grids = []

with open('grids.txt', 'r') as f:

for line in f:

line = line.strip().split('\t')

lat = float(line[1])

lng = float(line[2])

grids.append((lat, lng))

该模型使用半径公式来测量点 p1 和 p2 之间的大圆弧距离,以千米为单位:

def distance(p1, p2):

R = 6371 # Earth radius in km

lat1, lng1, lat2, lng2 = map(radians, (p1[0], p1[1], p2[0], p2[1]))

dlat = lat2 - lat1

dlng = lng2 - lng1

a = sin(dlat * 0.5) ** 2 + cos(lat1) * cos(lat2) * (sin(dlng * 0.5) ** 2)
return 2 * R * asin(sqrt(a))

在将图像提供给深度学习网络之前,该模型会通过裁剪以及减去均值来预处理图像:

# mean image for preprocessing

mean_rgb = np.array([123.68, 116.779, 103.939])

mean_rgb = mean_rgb.reshape((3, 1, 1))


def PreprocessImage(path, show_img=False):

# load image.

img = io.imread(path)

# We crop image from center to get size 224x224.

short_side = min(img.shape[:2])

yy = int((img.shape[0] - short_side) / 2)

xx = int((img.shape[1] - short_side) / 2)

crop_img = img[yy : yy + short_side, xx : xx + short_side]

resized_img = transform.resize(crop_img, (224,224))

if show_img:

io.imshow(resized_img)

# convert to numpy.ndarray

sample = np.asarray(resized_img) * 256

# swap axes to make image from (224, 224, 3) to (3, 224, 224)

sample = np.swapaxes(sample, 0, 2)

sample = np.swapaxes(sample, 1, 2)

# sub mean

normed_img = sample - mean_rgb

normed_img = normed_img.reshape((1, 3, 224, 224))
return [mx.nd.array(normed_img)]

评估并比较模型

为了进行评估,我们使用两个数据集:IM2GPS 数据集和 Flickr 图像测试数据集,后者用于 MediaEval Placing 2016 基准测试

IM2GPS 测试集结果

以下值表示 IM2GPS 测试集中正确位于与实际位置的每个距离内的图像的百分比。


Flickr 图像结果

由于 PlaNet 中使用的测试集图像尚未公开发布,因此不能直接比较这些结果。这些值表示测试集中正确位于与实际位置的每个距离内的图像的百分比。


通过目测检查定位图像,我们可以看到该模型不仅在地标位置方面表现出色,而且也能准确定位非标志性场景。

使用 URL 估算图像的地理位置

现在我们试着用 URL 对网页上的图像进行定位。

Batch = namedtuple('Batch', ['data'])

def predict(imgurl, prefix='images/'):

download_url(imgurl, prefix)

imgname = imgurl.split('/')[-1]

batch = PreprocessImage(prefix + imgname, True)

#predict and show top 5 results

mod.forward(Batch(batch), is_train=False)

prob = mod.get_outputs()[0].asnumpy()[0]

pred = np.argsort(prob)[::-1]

result = list()

for i in range(5):

pred_loc = grids[int(pred[i])]

res = (i+1, prob[pred[i]], pred_loc)

print('rank=%d, prob=%f, lat=%s, lng=%s' \

% (i+1, prob[pred[i]], pred_loc[0], pred_loc[1]))

result.append(res[2])

return result


def download_url(imgurl, img_directory):

if not os.path.exists(img_directory):

os.mkdir(img_directory)

imgname = imgurl.split('/')[-1]

filepath = os.path.join(img_directory, imgname)

if not os.path.exists(filepath):

filepath, _ = urllib.urlretrieve(imgurl, filepath)

statinfo = os.stat(filepath)

print('Succesfully downloaded', imgname, statinfo.st_size, 'bytes.')
return filepath

来看看我们的模型如何处理东京塔图片。以下代码从 URL 下载图像,并输出模型的位置预测。

#download and predict geo-location of an image of Tokyo Tower

url = 'https://farm5.staticflickr.com/4275/34103081894_f7c9bfa86c_k_d.jpg'
result = predict(url)

结果列出了置信度分数 (概率) 排在前五位的输出以及地理坐标:

rank=1, prob=0.139923, lat=35.6599344486, lng=139.728919109

rank=2, prob=0.095210, lat=35.6546613641, lng=139.745685815

rank=3, prob=0.042224, lat=35.7098435803, lng=139.810458528

rank=4, prob=0.032602, lat=35.6641725688, lng=139.746648114

rank=5, prob=0.023119, lat=35.6901996892, lng=139.692857396

仅通过原始纬度和经度值,很难判断地理位置输出的质量。我们可以通过将输出制成地图来直观地显示结果。

在 Jupyter Notebook 上使用 Google Maps 直观显示结果

为了直观地显示预测结果,我们可以在 Jupyter Notebook 中使用 Google Maps。它让您能够看到预测是否有意义。我们使用一个名为 gmaps 的插件,它允许我们在 Jupyter Notebook 中使用 Google Maps。要安装 gmaps,请按照 gmaps GitHub 页面上的安装说明操作。

使用 gmaps 直观显示结果只需几行代码。请在您的 Notebook 输入以下内容:

import gmaps


gmaps.configure(api_key="") # Fill in with your API key


fig = gmaps.figure()


for i in range(len(result)):

marker = gmaps.marker_layer([result[i]], label=str(i+1))

fig.add_layer(marker)

fig

事实上,排在第一位的定位估算结果就是东京塔所在的位置。

现在,试着对您选择的图像进行定位吧!

鸣谢

在 AWS 上训练 LocationNet 的工作得到了 AWS 研究与教育计划的大力支持。我们还要感谢 AWS 公共数据集计划托管 Multimedia Commons 数据集以供公众使用。我们的工作也得到了劳伦斯·利弗莫尔国家实验室领导的合作 LDRD 的部分支持 (美国能源部合同 DE-AC52-07NA27344)。