The four base stations are located at the following coordinates:
center1 = (80, 80)
center2 = (80, 280)
center3 = (280, 80)
center4 = (280, 280)
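Consider a point pos = [x, y].
To build a high-dimensional region representation (embedding), the point is expanded to 6 dimensions: pos' = [x, y, x', y', w, h] (width first, then height, which is the order in which the embedding code reads them).
Take one base station as an example. The station splits the plane into four quadrants, and the quadrant containing the point is identified from the sign of (x - center_x) * (y - center_y) together with the sign of x - center_x. The second corner (x', y') is then placed 2 pixels closer to the station along each axis, which yields a rectangle 2 pixels wide that leans toward the center.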
import torch

def get_box(self, pos, center, w):
    """
    pos: [K, 2] tensor of points [x, y]
    center: [x0, y0], position of the base station
    w: constant used as both the width and the height entry of the box
    return: [K, 6] tensor of [x, y, x', y', w, w]
    """
    result = []
    for x, y in pos.tolist():
        # Offset of the point relative to the base station
        X = x - center[0]
        Y = y - center[1]
        condition = X * Y
        if condition > 0:
            if X > 0:
                # x > x0 and y > y0
                result.append([x - 2, y - 2, w, w])
            else:
                # x < x0 and y < y0
                result.append([x + 2, y + 2, w, w])
        else:
            if X > 0:
                # x > x0 and y <= y0
                result.append([x - 2, y + 2, w, w])
            else:
                # x <= x0 (points on an axis through the center also land here)
                result.append([x + 2, y - 2, w, w])
    # Build the extra columns with the same dtype and device as the input
    result = torch.tensor(result, dtype=pos.dtype, device=pos.device)
    return torch.cat([pos, result], -1)
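The quadrant rule in get_box can also be written without a Python loop. The following is only a sketch under a few assumptions: the name get_box_vectorized and the shift parameter are hypothetical and do not appear in the original code, pos is assumed to be a [K, 2] tensor, and the branch behavior (including points lying exactly on an axis through the center) is intended to match the loop version.

```python
import torch

def get_box_vectorized(pos, center, w, shift=2):
    """Loop-free sketch of get_box (hypothetical helper, not the original code).

    pos: [K, 2] tensor of points, center: (x0, y0), w: constant box size.
    Returns a [K, 6] tensor of [x, y, x', y', w, w].
    """
    center_t = torch.as_tensor(center, dtype=pos.dtype, device=pos.device)
    offset = pos - center_t                  # [K, 2] offsets from the station
    X, Y = offset[:, 0], offset[:, 1]

    step = torch.full_like(X, shift)
    # x always moves toward the center: -shift if X > 0, +shift otherwise
    dx = torch.where(X > 0, -step, step)
    # y moves the same way as x when X*Y > 0, the opposite way otherwise
    dy = torch.where(X * Y > 0, dx, -dx)

    corner = pos + torch.stack([dx, dy], dim=-1)          # [K, 2] -> (x', y')
    size = torch.full((pos.shape[0], 2), w,
                      dtype=pos.dtype, device=pos.device)  # [K, 2] -> (w, w)
    return torch.cat([pos, corner, size], dim=-1)

# One point in each quadrant around center1 = (80, 80)
points = torch.tensor([[100., 100.], [60., 60.], [100., 60.], [60., 100.]])
boxes = get_box_vectorized(points, (80, 80), w=2)
print(boxes.shape)  # torch.Size([4, 6])
```

In every quadrant the second corner ends up 2 pixels closer to the station along both axes, for example (100, 100) maps to (98, 98) and (60, 100) maps to (62, 98).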
The result is then fed into the embedding code below (reference: https://zhuanlan.zhihu.com/p/351299548):
def coordinate_embeddings(self, boxes, dim):
    """
    Coordinate embeddings of bounding boxes
    :param boxes: [K, 6] ([x1, y1, x2, y2, w_image, h_image])
    :param dim: sin/cos embedding dimension
    :return: [K, 4, 2 * dim]
    """
    num_boxes = boxes.shape[0]
    w = boxes[:, 4]
    h = boxes[:, 5]

    # transform to (x_c, y_c, w, h) format
    boxes_ = boxes.new_zeros((num_boxes, 4))
    boxes_[:, 0] = (boxes[:, 0] + boxes[:, 2]) / 2
    boxes_[:, 1] = (boxes[:, 1] + boxes[:, 3]) / 2
    boxes_[:, 2] = boxes[:, 2] - boxes[:, 0]
    boxes_[:, 3] = boxes[:, 3] - boxes[:, 1]
    boxes = boxes_

    # position
    pos = boxes.new_zeros((num_boxes, 4))
    pos[:, 0] = boxes[:, 0] / w * 100
    pos[:, 1] = boxes[:, 1] / h * 100
    pos[:, 2] = boxes[:, 2] / w * 100
    pos[:, 3] = boxes[:, 3] / h * 100

    # sin/cos embedding
    dim_mat = 1000 ** (torch.arange(dim, dtype=boxes.dtype, device=boxes.device) / dim)
    sin_embedding = (pos.view((num_boxes, 4, 1)) / dim_mat.view((1, 1, -1))).sin()
    cos_embedding = (pos.view((num_boxes, 4, 1)) / dim_mat.view((1, 1, -1))).cos()

    return torch.cat((sin_embedding, cos_embedding), dim=-1)
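To check the output shape of the embedding step in isolation, the same arithmetic can be restated as a plain function and run on a single 6-dimensional box. This is a sketch only: embed_boxes is a hypothetical standalone name, and the value 360 used for the last two entries is an assumed normalizing size chosen for illustration, not a value taken from the original notes.

```python
import torch

def embed_boxes(boxes, dim):
    """Standalone restatement of the sin/cos box embedding (hypothetical helper).

    boxes: [K, 6] float tensor of [x1, y1, x2, y2, w, h].
    Returns a [K, 4, 2 * dim] tensor.
    """
    w, h = boxes[:, 4], boxes[:, 5]
    # Center and extent of each box
    cx = (boxes[:, 0] + boxes[:, 2]) / 2
    cy = (boxes[:, 1] + boxes[:, 3]) / 2
    bw = boxes[:, 2] - boxes[:, 0]
    bh = boxes[:, 3] - boxes[:, 1]
    # Normalize by w and h, then scale by 100
    pos = torch.stack([cx / w, cy / h, bw / w, bh / h], dim=-1) * 100  # [K, 4]
    # Frequencies 1000 ** (i / dim) for i in 0 .. dim - 1
    dim_mat = 1000 ** (torch.arange(dim, dtype=boxes.dtype, device=boxes.device) / dim)
    angles = pos.unsqueeze(-1) / dim_mat.view(1, 1, -1)                # [K, 4, dim]
    return torch.cat((angles.sin(), angles.cos()), dim=-1)             # [K, 4, 2 * dim]

# One box: point (100, 100) with its second corner at (98, 98); 360 is assumed
boxes = torch.tensor([[100., 100., 98., 98., 360., 360.]])
emb = embed_boxes(boxes, dim=8)
print(emb.shape)  # torch.Size([1, 4, 16])
```

Each of the four normalized quantities gets dim sine values followed by dim cosine values, so the last axis has length 2 * dim.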