16BJWC结业测试T1 - 数星星 star

2017/01/20

16北京冬令营结业测试第一题。

做法好迷。

题目描述

给定二维平面上的 N 个点,求距离第 K 小的点对间距离。 其中 1 <= K, N <= 1E5,点座标非负且为 int。

做法

听说看 std 可以涨智商。

  1. 读入点。
  2. 预处理:对点集进行遍历,找到一个小于 30 的 L 值 ,将各点座标 (x, y) 进行 trim 变换 (x >> L, y >> L) ,使得各点在 trim 座标系下邻接九宫格内元素个数和大于等于 K+1(换句话说,用玄学把非常离散的数据集中一下,再建个哈希索引)
  3. 玄学搜索:建二维 vector 存点(第一维是哈希化的座标),对点集 (x, y) 进行遍历, 取其 trim 后的索引 (x >> L, y >> L) 在邻接 49 宫格所属的 vector 里进行搜索。
  4. 维护一个堆,堆中存前 K+1 小的最近距离。那么分为这样两种情况
  5. 当前搜索宫格到中心宫格的最小距离大于堆顶(目前第 K+1 小元素),直接剪枝。(类似KD树的剪枝)
  6. 遍历当前搜索宫格中的元素,计算到搜索点之间的距离,更新堆中答案。
  7. 把当前遍历点扔到它对应的二维 vector 中。

icecathy 大佬说当时讲评的时候说的是切格子,然后我看代码理解出来是这样的 ↑

代码

  1
  2
  3
  4
  5
  6
  7
  8
  9
 10
 11
 12
 13
 14
 15
 16
 17
 18
 19
 20
 21
 22
 23
 24
 25
 26
 27
 28
 29
 30
 31
 32
 33
 34
 35
 36
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
#include <iostream>
#include <algorithm>
#include <vector>
#include <queue>
#include <cstring>
using namespace std;
const int MAX_N = 2E5;
typedef unsigned long long ULL;
typedef long long LL;
int N, K;

struct point {
    int x, y;
    point(int a, int b) : x(a), y(b) { }
    point() { }
} Raw[MAX_N];

ULL pack_int(int a, int b) {
    return (ULL(a)<<32) | ULL(b);
}

const int BSIZ = 100003;
struct hashmap {         // ULL -> int
    struct node {
        ULL k; int v; node* n;
    };
    node* BKT[BSIZ];
    node *POOL, *cur;

    hashmap(int N) {
        POOL = new node[N];
        cur = POOL;
        memset(BKT, 0, sizeof(BKT));
    }

    void set(ULL key) {
        ULL x = key % BSIZ;
        node* &p = BKT[x];
        while(p && p->k != key) p = p->n;
        if(p) {
            p->v += 1;
        } else {
            p = cur++;
            p->k = key; p->v = 1;
        }
    }

    int get(ULL key) {
        ULL x = key % BSIZ;
        node* p = BKT[x];
        while(p && p->k != key) p = p->n;
        return p ? p->v : 0;
    }

    ~hashmap() {
        delete []POOL;
    }
};

int L = 0;
void getL() {
    while(L < 30) {
        hashmap cnt(N);
        ULL s = 0;
        for(int i=0;i<N;i++) {
            for(int dx=-1;dx<=1;dx++) {
                for(int dy=-1;dy<=1;dy++) {
                    long long trkey = pack_int((Raw[i].x>>L) + dx, (Raw[i].y>>L) + dy);
                    s += cnt.get(trkey);
                }
            }
            long long key = pack_int(Raw[i].x>>L, Raw[i].y>>L);
            cnt.set(key);
        }

        if(s > ULL(K)) break;
        L += 1;
    }
}

void work() {
    priority_queue<ULL> Q;
    vector<vector<point> > B(BSIZ);

    for(int i=0;i<N;i++) {
        int kx=Raw[i].x>>L, ky = Raw[i].y>>L;
        for(int dx=-3;dx<=3;dx++) {
            for(int dy=-3;dy<=3;dy++) {
                if(int(Q.size()) == K+1 && (ULL(max(0,abs(dx)-1)*max(0,abs(dx)-1)+max(0,abs(dy)-1)*max(0,abs(dy)-1))<<(L<<1)) >= Q.top())         // 类 KD 树剪枝
                    continue;
                ULL key = pack_int(kx + dx, ky + dy);
                vector<point> &C = B[key % BSIZ];
                for(int k=0;k<int(C.size());k++) {
                    LL ofx = LL(Raw[i].x - C[k].x);
                    LL ofy = LL(Raw[i].y - C[k].y);
                    ULL dis = ULL(ofx * ofx + ofy * ofy);
                    if(int(Q.size()) == K+1 && dis < Q.top()) {
                        Q.pop(); Q.push(dis);
                    } else if(int(Q.size()) <= K) {
                        Q.push(dis);
                    }
                }
            }
        }

        vector<point> &C = B[pack_int(kx, ky) % BSIZ];
        C.push_back(Raw[i]);
    }

    cout << Q.top() << endl;
}

int main() {
    scanf("%d %d", &N, &K);

    for(int i=0;i<N;i++) {
        scanf("%d %d", &Raw[i].x, &Raw[i].y);
    }

    getL();

    work();
}

数据生成器

没有这题的数据(但有 std),所以写个了数据生成器。

 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
#!/usr/bin/env python3
# Usage: ./gen.py {DATA_SCALE} > star.in
import sys
import random

if len(sys.argv) < 2:
    sys.exit(-1)

DATA_SIZE = int(sys.argv[1])

K_TH = random.randrange(1, DATA_SIZE - 1)

print(DATA_SIZE, K_TH)

for i in range(DATA_SIZE):
    print(random.randrange(1, 1E9), random.randrange(1, 1E9))