forked from nagadomi/kaggle-lshtc
-
Notifications
You must be signed in to change notification settings - Fork 0
/
ncc_cache.hpp
115 lines (105 loc) · 2.45 KB
/
ncc_cache.hpp
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
#ifndef NCC_CACHE_HPP
#define NCC_CACHE_HPP
#include <vector>
#include <map>
#include <algorithm>
#include <cstdio>
// Cache for Nearest Centroid Classifier Results.
//
// Maps a query document id to the vector of ints stored for it and can
// persist the whole map to / restore it from a binary file.
//
// File layout (native byte order and native sizeof(size_t)/sizeof(int),
// so files are only portable between compatible builds):
//   size_t entry_count
//   entry_count times:
//     int    query_doc_id
//     size_t result_count
//     int    results[result_count]
//
// When compiled with OpenMP, set() and get() serialize access through the
// named critical section `ncc_cache`; save() and load() take no lock.
class NCCCache
{
private:
  std::map<int, std::vector<int> > m_cache;

  // Writes `count` elements of `elem_size` bytes; true when all were written.
  // A zero count is a successful no-op (and never touches `ptr`).
  static bool
  write_items(const void *ptr, size_t elem_size, size_t count, FILE *fp)
  {
    return count == 0 || std::fwrite(ptr, elem_size, count, fp) == count;
  }

  // Parses every entry from an already opened stream into m_cache.
  // Reports the first malformed field on stderr and returns false; the
  // caller owns `fp` and is responsible for closing it.
  bool
  read_entries(FILE *fp, const char *file)
  {
    // Upper bound on the ints requested per fread(). The payload size comes
    // from the file, so the buffer grows only as data actually arrives: a
    // corrupted size field yields a read error, not a giant allocation.
    const size_t max_chunk = static_cast<size_t>(1) << 16;

    size_t cache_num = 0;
    if (std::fread(&cache_num, sizeof(cache_num), 1, fp) != 1) {
      std::fprintf(stderr, "NCCCache: %s: invalid format 1\n", file);
      return false;
    }
    for (size_t i = 0; i < cache_num; ++i) {
      int query_doc_id = 0;
      size_t vec_size = 0;

      if (std::fread(&query_doc_id, sizeof(query_doc_id), 1, fp) != 1) {
        std::fprintf(stderr, "NCCCache: %s: invalid format 2\n", file);
        return false;
      }
      if (std::fread(&vec_size, sizeof(vec_size), 1, fp) != 1) {
        std::fprintf(stderr, "NCCCache: %s: invalid format 3\n", file);
        return false;
      }

      // Read straight into heap storage (no stack array, no extra copy).
      std::vector<int> results;
      size_t done = 0;
      while (done < vec_size) {
        const size_t want = std::min(vec_size - done, max_chunk);
        results.resize(done + want);
        if (std::fread(&results[done], sizeof(int), want, fp) != want) {
          std::fprintf(stderr, "NCCCache: %s: invalid format 4\n", file);
          return false;
        }
        done += want;
      }
      // insert() keeps the first entry if the file repeats an id.
      m_cache.insert(std::make_pair(query_doc_id, std::move(results)));
    }
    return true;
  }

public:
  NCCCache(){}

  // Stores `results` for `query_doc_id`.
  // Uses map::insert, so an id that is already cached keeps its old value.
  void
  set(unsigned int query_doc_id,
      const std::vector<int> &results)
  {
#ifdef _OPENMP
#pragma omp critical (ncc_cache)
#endif
    {
      m_cache.insert(std::make_pair(query_doc_id, results));
    }
  }

  // Looks up `query_doc_id`.
  // On a hit copies the cached vector into `results` and returns true;
  // on a miss leaves `results` untouched and returns false.
  bool
  get(unsigned int query_doc_id, std::vector<int> &results) const
  {
    bool ret = false;
#ifdef _OPENMP
#pragma omp critical (ncc_cache)
#endif
    {
      auto cache = m_cache.find(query_doc_id);
      if (cache != m_cache.end()) {
        results = cache->second;
        ret = true;
      }
    }
    return ret;
  }

  // Writes the whole cache to `file` (created/truncated, binary mode).
  // Returns false if the file cannot be opened, if any write comes up
  // short, or if closing (i.e. flushing) the stream fails.
  bool
  save(const char *file) const
  {
    FILE *fp = std::fopen(file, "wb");
    if (fp == nullptr) {
      return false;
    }

    size_t size = m_cache.size();
    bool ok = write_items(&size, sizeof(size), 1, fp);
    for (auto cache = m_cache.begin(); ok && cache != m_cache.end(); ++cache) {
      const int query_doc_id = cache->first;
      size = cache->second.size();
      ok = write_items(&query_doc_id, sizeof(query_doc_id), 1, fp)
        && write_items(&size, sizeof(size), 1, fp)
        && write_items(cache->second.data(), sizeof(int), size, fp);
    }
    // Buffered data is flushed on close, so its result matters as well.
    if (std::fclose(fp) != 0) {
      ok = false;
    }
    return ok;
  }

  // Replaces the cache contents with the entries stored in `file`.
  // Returns false if the file cannot be opened (cache left unchanged) or
  // if it is truncated/malformed; in the latter case a diagnostic is
  // printed to stderr and the cache holds only the entries that were read
  // completely before the error.
  bool
  load(const char *file)
  {
    FILE *fp = std::fopen(file, "rb");
    if (fp == nullptr) {
      return false;
    }
    m_cache.clear();
    const bool ok = read_entries(fp, file);
    std::fclose(fp);
    return ok;
  }
};
#endif