改写 caffe convert_mnist_data.cpp

匿名 (未验证) 提交于 2019-12-03 00:27:02

以下代码来自 caffe 的 examples/mnist 目录(convert_mnist_data.cpp),去掉了 gflags(Google 命令行参数库)相关的内容以及命令行参数,适合入门阅读。

该代码能将 MNIST 数据集的 images 和 labels 文件转换成 LMDB 数据库,目前只在 Ubuntu Linux 下测试过。

#include <gflags/gflags.h>
#include <glog/logging.h>
#include <google/protobuf/text_format.h>
// NOTE: the guard below is commented out, so the LevelDB and LMDB headers are
// included unconditionally and both libraries must be installed to build.
//#if defined(USE_LEVELDB) && defined(USE_LMDB)
#include <leveldb/db.h>
#include <leveldb/write_batch.h>
#include <lmdb.h>
//#endif

#include <stdint.h>
#include <sys/stat.h>

#include <fstream>  // NOLINT(readability/streams)
#include <string>
#include <vector>

#include "boost/scoped_ptr.hpp"
#include "caffe/proto/caffe.pb.h"
#include "caffe/util/db.hpp"
#include "caffe/util/format.hpp"

using namespace caffe;  // NOLINT(build/namespaces)
using boost::scoped_ptr;
using std::string;

//DEFINE_string(backend, "lmdb", "The backend for storing the result");

/// Reverses the byte order of a 32-bit value.
///
/// Kept for compatibility with existing callers. Header parsing in this file
/// uses read_big_endian_u32() instead, which decodes the bytes explicitly and
/// therefore does not depend on the host's byte order.
uint32_t swap_endian(uint32_t val) {
    val = ((val << 8) & 0xFF00FF00) | ((val >> 8) & 0xFF00FF);
    return (val << 16) | (val >> 16);
}

namespace {

// Magic numbers that open MNIST IDX files:
// 0x00000803 for image files (unsigned-byte data, 3 dimensions) and
// 0x00000801 for label files (unsigned-byte data, 1 dimension).
const uint32_t kImageMagic = 2051;
const uint32_t kLabelMagic = 2049;

// Number of records buffered in a transaction before it is committed.
const int kCommitInterval = 1000;

// Reads one 4-byte big-endian integer from `file` and returns it in host
// byte order. Aborts via glog CHECK if the file ends before 4 bytes are read.
// `filename` is used only for the error message.
uint32_t read_big_endian_u32(std::ifstream& file, const char* filename) {
  unsigned char bytes[4];
  file.read(reinterpret_cast<char*>(bytes), sizeof(bytes));
  CHECK(file) << "Unexpected end of file while reading the header of "
              << filename;
  return (static_cast<uint32_t>(bytes[0]) << 24) |
         (static_cast<uint32_t>(bytes[1]) << 16) |
         (static_cast<uint32_t>(bytes[2]) << 8) |
         static_cast<uint32_t>(bytes[3]);
}

}  // namespace

/// Converts a pair of MNIST IDX files into a Caffe database of Datum records.
///
/// @param image_filename path to the IDX image file (e.g. train-images.idx3-ubyte)
/// @param label_filename path to the matching IDX label file
/// @param db_path        database location; opened with db::NEW
/// @param db_backend     backend name passed to db::GetDB (e.g. "lmdb")
///
/// Each record is keyed by its item index formatted with caffe::format_int
/// (width 8). It stores a Datum of 1 channel x rows x cols holding the raw
/// pixel bytes and the label. The function aborts via glog CHECK if:
///   - either file cannot be opened,
///   - a magic number is wrong,
///   - the image and label counts differ, or
///   - the data is truncated.
void convert_dataset(const char* image_filename, const char* label_filename,
        const char* db_path, const string& db_backend) {
  // Open files
  std::ifstream image_file(image_filename, std::ios::in | std::ios::binary);
  std::ifstream label_file(label_filename, std::ios::in | std::ios::binary);
  CHECK(image_file) << "Unable to open file " << image_filename;
  CHECK(label_file) << "Unable to open file " << label_filename;

  // Read and validate the big-endian IDX headers. Catching a wrong or swapped
  // file here is much cheaper than discovering garbage records during training.
  const uint32_t image_magic = read_big_endian_u32(image_file, image_filename);
  CHECK_EQ(image_magic, kImageMagic)
      << "Incorrect image file magic in " << image_filename;
  const uint32_t label_magic = read_big_endian_u32(label_file, label_filename);
  CHECK_EQ(label_magic, kLabelMagic)
      << "Incorrect label file magic in " << label_filename;

  const uint32_t num_items = read_big_endian_u32(image_file, image_filename);
  const uint32_t num_labels = read_big_endian_u32(label_file, label_filename);
  CHECK_EQ(num_items, num_labels)
      << "Image and label files contain different numbers of items";

  const uint32_t rows = read_big_endian_u32(image_file, image_filename);
  const uint32_t cols = read_big_endian_u32(image_file, image_filename);
  CHECK_GT(rows, 0u) << "Image height must be positive";
  CHECK_GT(cols, 0u) << "Image width must be positive";
  const size_t image_size = static_cast<size_t>(rows) * cols;

  scoped_ptr<db::DB> db(db::GetDB(db_backend));
  db->Open(db_path, db::NEW);
  scoped_ptr<db::Transaction> txn(db->NewTransaction());

  // Storing to db. The vector owns the pixel buffer, so nothing leaks on any
  // exit path.
  std::vector<char> pixels(image_size);
  char label;
  int count = 0;
  string value;

  Datum datum;
  datum.set_channels(1);
  datum.set_height(rows);
  datum.set_width(cols);

  for (uint32_t item_id = 0; item_id < num_items; ++item_id) {
    image_file.read(&pixels[0], static_cast<std::streamsize>(image_size));
    label_file.read(&label, 1);
    CHECK(image_file) << "Truncated image data at item " << item_id;
    CHECK(label_file) << "Truncated label data at item " << item_id;

    datum.set_data(&pixels[0], image_size);
    // IDX labels are unsigned bytes; go through unsigned char so values above
    // 127 are not sign-extended into negative labels.
    datum.set_label(static_cast<unsigned char>(label));
    string key_str = caffe::format_int(static_cast<int>(item_id), 8);
    datum.SerializeToString(&value);

    txn->Put(key_str, value);

    if (++count % kCommitInterval == 0) {
      txn->Commit();
    }
  }
  // write the last batch
  if (count % kCommitInterval != 0) {
    txn->Commit();
  }
  //LOG(INFO) << "Processed " << count << " files.";
  db->Close();
}

int main() {
  const string db_backend = "lmdb";

  const char* my_image_filename = "/mnt/e/ccc/ubuntu/lib/caffe_mnist/train-images.idx3-ubyte";
  const char* my_label_filename = "/mnt/e/ccc/ubuntu/lib/caffe_mnist/train-labels.idx1-ubyte";
  // Directory in which the LMDB database is created. It must NOT exist
  // beforehand. Caffe creates it with mkdir when the database is opened in
  // db::NEW mode, so a pre-existing directory of the same name makes that fail.
  const char* my_db_path = "/mnt/e/ccc/ubuntu/lib/caffe_mnist/005";

  convert_dataset(my_image_filename, my_label_filename, my_db_path, db_backend);

  return 0;
}

标签
易学教程内所有资源均来自网络或用户发布的内容,如有违反法律规定的内容欢迎反馈
该文章没有解决你所遇到的问题?点击提问,说说你的问题,让更多的人一起探讨吧!