Windows 7 64位 + VS2013 下用Caffe训练MNIST的操作步骤

1.        使用 http://www.voidcn.com/article/p-klesixtj-ux.html 中生成的Caffe静态库;

2.        使用 http://www.voidcn.com/article/p-pjxaevtt-va.html 中生成的LMDB数据库文件;

3.        新建一个train_mnist控制台工程;

4.        修改源文件中的caffe/examples/mnist/lenet_solver.prototxt文件:

(1)、net: "E:/GitCode/Caffe/src/caffe/caffe/examples/mnist/lenet_train_test.prototxt"

(2)、snapshot_prefix:"E:/GitCode/Caffe/src/caffe/caffe/examples/mnist/lenet"

(3)、solver_mode: CPU

5.        修改源文件中的caffe/examples/mnist/lenet_train_test.prototxt文件,指定LMDB数据库文件存放位置:

(1)、训练数据层(TRAIN阶段)的source:"E:/GitCode/Caffe/src/caffe/caffe/data/mnist/lmdb/train"

(2)、测试数据层(TEST阶段)的source:"E:/GitCode/Caffe/src/caffe/caffe/data/mnist/lmdb/test"

6.        train_mnist.cpp文件的内容如下(在caffe/tools/caffe.cpp的基础上修改而来):

#include "stdafx.h"
#include <iostream>

#include <glog/logging.h>
#include <cstring>
#include <map>
#include <string>
#include <vector>

#include "caffe/common.hpp"
#include "boost/algorithm/string.hpp"
#include "caffe/caffe.hpp"
#include "caffe/util/io.hpp"
#include "caffe/blob.hpp"
#include "caffe/layer_factory.hpp"
#include "boost/smart_ptr/shared_ptr.hpp"

using caffe::Blob;
using caffe::Caffe;
using caffe::Net;
using caffe::Layer;
using caffe::Solver;
using caffe::shared_ptr;
using caffe::string;
using caffe::Timer;
using caffe::vector;
using std::ostringstream;

// Command-line flags (gflags).
//
// --solver defaults to the local LeNet solver so the program can be started
// without any arguments.
DEFINE_string(solver, "E:/GitCode/Caffe/src/caffe/caffe/examples/mnist/lenet_solver.prototxt",
	"The solver definition protocol buffer text file.");
// --snapshot and --weights default to empty, as in upstream tools/caffe.cpp.
// An empty value means "not requested": code that resumes or finetunes tests
// FLAGS_snapshot.size() / FLAGS_weights.size(). The former defaults (a
// solverstate that only exists after a complete run, and a placeholder
// "xxxx.caffemodel") were always non-empty. With them, the "snapshot or
// weights, but not both" check could never pass, and resume/finetune would
// always have tried to load those files.
DEFINE_string(snapshot, "",
	"Optional; the snapshot solver state to resume training.");
DEFINE_string(weights, "",
	"Optional; the pretrained weights to initialize finetuning, "
	"separated by ','. Cannot be set simultaneously with snapshot.");

// A simple registry for caffe commands: maps an action name (e.g. "train")
// to the function implementing it. Each action takes no arguments and returns
// the process exit code.
typedef int(*BrewFunction)();
typedef std::map<caffe::string, BrewFunction> BrewMap;
// Defined before any RegisterBrewFunction() use in this file. Within a single
// translation unit, dynamic initialization follows definition order, so the
// map is constructed before the registerer objects insert into it.
BrewMap g_brew_map;

// RegisterBrewFunction(func) adds `func` to g_brew_map under the key "func".
// It defines a helper class and one instance of it in an anonymous namespace;
// the instance's constructor performs the insertion during static
// initialization.
// The helper class is named Registerer_<func>: names beginning with "__" are
// reserved for the C++ implementation, so the former __Registerer_<func>
// spelling was not allowed in user code.
#define RegisterBrewFunction(func) \
namespace { \
class Registerer_##func { \
 public: /* NOLINT */ \
  Registerer_##func() { \
    g_brew_map[#func] = &func; \
  } \
}; \
Registerer_##func g_registerer_##func; \
}

// Returns the brew action registered under `name`.
//
// When no such action exists, every registered action name is logged at
// ERROR level, then LOG(FATAL) reports the unknown name and terminates.
// The trailing NULL return only exists to satisfy compilers that cannot see
// that LOG(FATAL) does not return.
static BrewFunction GetBrewFunction(const caffe::string& name) {
	BrewMap::const_iterator found = g_brew_map.find(name);
	if (found != g_brew_map.end()) {
		return found->second;
	}

	LOG(ERROR) << "Available caffe actions:";
	for (BrewMap::const_iterator entry = g_brew_map.begin();
		entry != g_brew_map.end(); ++entry) {
		LOG(ERROR) << "\t" << entry->first;
	}
	LOG(FATAL) << "Unknown action: " << name;
	return NULL;  // not reachable, just to suppress old compiler warnings.
}

// Load the weights from the specified caffemodel(s) into the train and test nets.
//
// solver:     solver whose training net and test nets receive the weights.
// model_list: comma-separated list of .caffemodel paths. Each path is applied
//             in list order to the training net (solver->net()) and then to
//             every net in solver->test_nets().
void CopyLayers(caffe::Solver<float>* solver, const std::string& model_list) {
	std::vector<std::string> model_names;
	boost::split(model_names, model_list, boost::is_any_of(","));
	// size_t indices: comparing a signed int with size() is a signed/unsigned
	// mismatch (MSVC warning C4018).
	for (size_t i = 0; i < model_names.size(); ++i) {
		const std::string& model_name = model_names[i];
		LOG(INFO) << "Finetuning from " << model_name;
		solver->net()->CopyTrainedLayersFrom(model_name);
		for (size_t j = 0; j < solver->test_nets().size(); ++j) {
			solver->test_nets()[j]->CopyTrainedLayersFrom(model_name);
		}
	}
}

// Train / Finetune a model.
int train() {
	CHECK_GT(FLAGS_solver.size(), 0) << "Need a solver definition to train.";
	//CHECK(!FLAGS_snapshot.size() || !FLAGS_weights.size()) << "Give a snapshot to resume training or weights to finetune but not both.";

	caffe::SolverParameter solver_param;
	caffe::ReadProtoFromTextFileOrDie(FLAGS_solver, &solver_param);

	Caffe::set_mode(Caffe::CPU);

	shared_ptr<Solver<float> > solver(caffe::GetSolver<float>(solver_param));

	//if (FLAGS_snapshot.size()) { // resume training
	//	LOG(INFO) << "Resuming from " << FLAGS_snapshot;
	//	solver->Restore(FLAGS_snapshot.c_str());
	//}
	//else if (FLAGS_weights.size()) { // finetune
	//	CopyLayers(solver.get(), FLAGS_weights);
	//}

	LOG(INFO) << "Starting Optimization";
	solver->Solve();

	LOG(INFO) << "Optimization Done.";
	return 0;
}
RegisterBrewFunction(train);

// Builds the argument vector handed to gflags.
//
// The program is meant to be launched straight from Visual Studio with no
// arguments. The "train" command is therefore inserted after the program name
// unless the caller already supplied a command, i.e. a first argument that
// does not start with '-'. All caller arguments are forwarded, so flags such
// as --solver=... still take effect. The result is NULL-terminated like a real
// argv. Its strings are either the caller's or function-local statics, so they
// stay valid for the rest of the program.
static std::vector<char*> BuildTrainArgs(int argc, char* argv[]) {
	static char default_program_name[] = "train_mnist";
	static char train_command[] = "train";

	std::vector<char*> args;
	args.push_back(argc > 0 && argv[0] != NULL ? argv[0] : default_program_name);
	if (argc < 2 || argv[1][0] == '-') {
		args.push_back(train_command);
	}
	for (int i = 1; i < argc; ++i) {
		args.push_back(argv[i]);
	}
	args.push_back(NULL);
	return args;
}

// Entry point: runs the "train" brew action (see BuildTrainArgs).
// Returns the action's exit code, or 1 after printing usage when the command
// line does not reduce to exactly one command.
int main(int argc, char* argv[])
{
	// Use a private argv rather than overwriting the process's argv[]. The
	// previous code had three problems:
	//  - Forcing argc = 2 put argv[argc], which should be the NULL terminator,
	//    past the end of the array when no arguments were given.
	//  - Assigning string literals to char* is ill-formed in standard C++.
	//  - Any flags typed by the user were silently discarded.
	std::vector<char*> args = BuildTrainArgs(argc, argv);
	int brew_argc = static_cast<int>(args.size()) - 1;  // exclude the NULL terminator
	char** brew_argv = &args[0];

	// glog settings. They are assigned before flag parsing so the command line
	// (e.g. --log_dir=...) can still override them, and before
	// InitGoogleLogging so logging starts with the final values.
	// The log directory must already exist.
	FLAGS_log_dir = "E:\\GitCode\\Caffe";
	FLAGS_max_log_size = 1024;  // MB
	// Print output to stderr (while still logging).
	FLAGS_alsologtostderr = 1;

	// Usage message.
	gflags::SetUsageMessage("command line brew\n"
		"usage: train_mnist [train] [--flag=value ...]\n\n"
		"commands:\n"
		"  train           train or finetune a model (default)\n");

	// Parse the command line. Recognised flags are removed from brew_argv,
	// leaving the program name followed by the command.
	gflags::ParseCommandLineFlags(&brew_argc, &brew_argv, true);

	// InitGoogleLogging() must be called at least once per process,
	// otherwise no log file is produced.
	google::InitGoogleLogging(brew_argv[0]);

	if (brew_argc == 2) {
		return GetBrewFunction(caffe::string(brew_argv[1]))();
	}

	// gflags matches the restrict string against the path of the file that
	// defines each flag. This program's flags live in train_mnist.cpp, so the
	// former "tools/caffe" filter listed none of them.
	gflags::ShowUsageWithFlagsRestrict(brew_argv[0], "train_mnist");
	return 1;
}

7.        执行完train_mnist后,会在snapshot_prefix所指定的目录(E:/GitCode/Caffe/src/caffe/caffe/examples/mnist/)下生成四个文件:lenet_iter_5000.caffemodel、lenet_iter_5000.solverstate、lenet_iter_10000.caffemodel、lenet_iter_10000.solverstate

8.        运行结果如下图(原文截图未能保留):

相关文章
相关标签/搜索