Skip to content

Commit

Permalink
RALI - fix build error with tot (#339)
Browse files Browse the repository at this point in the history
  • Loading branch information
rrawther authored Jul 20, 2020
1 parent 38c8bb4 commit d0c672f
Show file tree
Hide file tree
Showing 2 changed files with 7 additions and 1 deletion.
3 changes: 3 additions & 0 deletions rali/rali/include/tf_record_reader.h
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@ THE SOFTWARE.
#include <iterator>
#include <algorithm>
#include "reader.h"
#include "timing_debug.h"
#include <google/protobuf/message_lite.h>
#include "example.pb.h"
#include "feature.pb.h"
Expand Down Expand Up @@ -60,6 +61,7 @@ class TFRecordReader : public Reader
std::string id() override { return _last_id;};

unsigned count() override;
unsigned long long get_shuffle_time() {return _shuffle_time.get_timing();};

~TFRecordReader() override;

Expand Down Expand Up @@ -106,5 +108,6 @@ class TFRecordReader : public Reader
Reader::Status read_image(unsigned char* buff, std::string record_file_name, uint file_size);
Reader::Status read_image_names(std::ifstream &file_contents, uint file_size);
std::map <std::string, uint> _image_record_starting;
TimingDBG _shuffle_time;
};

5 changes: 4 additions & 1 deletion rali/rali/source/tf_record_reader.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,8 @@ THE SOFTWARE.

namespace filesys = boost::filesystem;

TFRecordReader::TFRecordReader()
TFRecordReader::TFRecordReader():
_shuffle_time("shuffle_time", DBG_TIMING)
{
_src_dir = nullptr;
_sub_dir = nullptr;
Expand Down Expand Up @@ -69,8 +70,10 @@ Reader::Status TFRecordReader::initialize(ReaderConfig desc)
_shuffle = desc.shuffle();
ret = folder_reading();
//shuffle dataset if set
_shuffle_time.start();
if (ret == Reader::Status::OK && _shuffle)
std::random_shuffle(_file_names.begin(), _file_names.end());
_shuffle_time.end();
return ret;
}

Expand Down

0 comments on commit d0c672f

Please sign in to comment.