最近研究了一下kd-tree,並且用python實作了,程式碼見文末連結。
kd-tree是k-dimensional tree的縮寫,主要用來做空間索引,包括範圍搜尋和最近鄰搜尋,2分尋找樹可以看做是kd-tree在一維情況下的一個特例。
構造K-d tree
2分尋找樹是對1維數據做分類索引,但擴充套件到2維的情況下就不太適用。例如要尋找距離(2, 3)最近的點,如果只根據x的值進行2分尋找可能會得到錯誤的結果,因為距離還和y的值有關。這時候就需要用到kd-tree,對x和y 2個維度做劃分。
kd-tree的構造過程是交替的根據當前數據的不同維度進行切分,例如假設為2維度數據(x,y),root節點先根據x軸進行劃分,得到左右2個部份,然後再分別對左右2個部份根據y軸進行劃分,依次進行,直到劃分的區域內只有一個節點。以下是2維K-d tree的示意圖,同理可以擴充套件到N維度空間。
上述構造過程種最需要註意的是如何對數據進行切分,要平均劃分就需要找到當前維度的中間值,有3種獲取中間的方法。
- 根據堆排序或者歸並排序對陣列進行排序,時間復雜度 O(nlog^2n)
- 透過中位數演算法 [1] 找到中間值,時間復雜度 O(n*logn)
- 實作對k維的數據進行排序並且保存,然後利用排序好的資訊。時間復雜度 O(k*n*logn)
下面的範例程式碼是采用方法1構造K-d tree:
def
get_k_median
(
arr
,
axis
):
arr
.
sort
(
key
=
itemgetter
(
axis
))
return
len
(
arr
)
//
2
def
_construct
(
self
,
points
,
depth
)
->
Node
:
if
not
points
:
return
None
axis
=
depth
%
self
.
_k
median
=
get_k_median
(
points
,
axis
)
node
=
Node
(
points
[
median
],
axis
)
node
.
left
=
self
.
_construct
(
points
[:
median
],
depth
+
1
)
node
.
right
=
self
.
_construct
(
points
[
median
+
1
:],
depth
+
1
)
return
node
尋找
k-d tree有以下2種尋找方式:
尋找最近的節點
尋找最近的節點需要用到回溯演算法,先找到葉子節點,計算2者之間的距離,由於葉子節點不一定是最近的節點,因此還需要判斷是否相鄰區域是否比當前葉子節點的距離更短。
def
_query_nearest
(
point
,
node
,
nearest
,
distance
):
if
not
node
:
return
if
point
[
node
.
axis
]
<=
node
.
value
:
_query_nearest
(
point
,
node
.
left
,
nearest
,
distance
)
else
:
_query_nearest
(
point
,
node
.
right
,
nearest
,
distance
)
if
point
[
node
.
axis
]
<=
node
.
value
:
if
point
[
node
.
axis
]
+
distance
[
0
]
>
node
.
value
:
# check parent node
t_distance
=
euclidean_distance
(
node
.
element
,
point
)
if
t_distance
<
distance
[
0
]:
distance
[
0
]
=
t_distance
nearest
[
0
]
=
node
.
element
_query_nearest
(
point
,
node
.
right
,
nearest
,
distance
)
else
:
if
point
[
node
.
axis
]
-
distance
[
0
]
<
node
.
value
:
# check parent node
t_distance
=
euclidean_distance
(
node
.
element
,
point
)
if
t_distance
<
distance
[
0
]:
distance
[
0
]
=
t_distance
nearest
[
0
]
=
node
.
element
_query_nearest
(
point
,
node
.
left
,
nearest
,
distance
)
尋找半徑內的節點
和尋找最近的節點型別,只要把距離改成固定的值,如果尋找的節點和給定點的距離在半徑範圍內,則添加,反之亦然,最後得到在指定點半徑範圍內的所有點。
def
_query_radius
(
point
,
node
,
plist
,
radius
):
if
not
node
:
return
if
point
[
node
.
axis
]
<=
node
.
value
:
_query_radius
(
point
,
node
.
left
,
plist
,
radius
)
else
:
_query_radius
(
point
,
node
.
right
,
plist
,
radius
)
if
point
[
node
.
axis
]
<=
node
.
value
:
if
point
[
node
.
axis
]
+
radius
>=
node
.
value
:
t_distance
=
euclidean_distance
(
node
.
element
,
point
)
if
t_distance
<=
radius
:
plist
.
append
(
node
.
element
)
_query_radius
(
point
,
node
.
right
,
plist
,
radius
)
else
:
if
point
[
node
.
axis
]
-
radius
<=
node
.
value
:
t_distance
=
euclidean_distance
(
node
.
element
,
point
)
if
t_distance
<=
radius
:
plist
.
append
(
node
.
element
)
_query_radius
(
point
,
node
.
left
,
plist
,
radius
)
總結
上述只考慮了構造和尋找的過程,還沒有包括添加、刪除和再平衡的方法。當然有一些極端情況也需要補充,例如節點中某一維度有相等值的情況。
kd-tree的完整程式碼實作
參考
- ^https://en.wikipedia.org/wiki/Median_of_medians