function [cnn_feat, savefile, valid] = get_CNN_features(opts, dataset, split, filename)
% GET_CNN_FEATURES  Compute CNN features for one dataset split, with caching.
%
% If the cache file for this (dataset, split, network) combination exists,
% its contents are returned. Otherwise every *.png / *.jpg image found in
% fullfile(opts.datadir, dataset, split) is forwarded through the Caffe
% network and the result is written to the cache file.
%
% input:
%   opts      options structure; fields read here: cachedir, datadir,
%             modeldir, cnn_arch, cnn_idstr, cnn_use_fc7
%   dataset   dataset name (sub-directory of opts.datadir)
%   split     'Train' or 'Test' (case-sensitive: it is also used verbatim
%             as a directory name and in the cache file name)
%   filename  not read by this function; the image list always comes from
%             the directory listing. Kept so existing callers keep working.
% output:
%   cnn_feat  Nxfeat_dim matrix
%   savefile  full path of the cache .mat file
%   valid     1xN logical, all true when features are freshly computed;
%             when loaded from a cache file without a 'valid' variable, []
%
assert(ismember(split, {'Train', 'Test'}), ...
    'split must be ''Train'' or ''Test'' (case-sensitive)');

% build the cache file name and pick the network definition / output size
savefile = fullfile(opts.cachedir, [dataset '_' split '_' opts.cnn_arch]);
if ~isempty(opts.cnn_idstr)
    savefile = [savefile '_' opts.cnn_idstr];
end
if opts.cnn_use_fc7
    % note: need to make *_fc7.prototxt on our own
    savefile = [savefile '_fc7.mat'];
    net_str  = [opts.cnn_arch '_fc7'];
    feat_dim = 4096;
else
    savefile = [savefile '_deploy.mat'];
    net_str  = 'deploy';
    if strcmp(opts.cnn_arch, 'VGG16_SalObjSub')
        feat_dim = 5;
    elseif strcmp(opts.cnn_arch, 'STATIC_ft') || strcmp(opts.cnn_arch, 'VizWiz_ft')
        feat_dim = 2;
    else
        feat_dim = 1000;
    end
end

% cache hit: return stored features
if exist(savefile, 'file')
    d = load(savefile);
    cnn_feat = d.cnn_feat;
    if isfield(d, 'valid')
        valid = d.valid;
    else
        valid = [];
    end
    return;
end

fprintf('Computing CNN features: %s (%s)\n', dataset, split);

% Select mean image and crop size first, so that an unsupported
% architecture is reported before any Caffe state is created.
if ~isempty(strfind(opts.cnn_arch, 'VGG')) || strcmp(opts.cnn_arch, 'STATIC_ft') || ...
        strcmp(opts.cnn_arch, 'VizWiz_ft')
    % VGG: convert mean pixel (BGR!) to mean image (simply replicate)
    crop_dim = 224;
    mean_data = zeros(256, 256, 3);
    mean_data(:, :, 1) = 103.939;
    mean_data(:, :, 2) = 116.779;
    mean_data(:, :, 3) = 123.68;
elseif strcmp(opts.cnn_arch, 'bvlc_reference_caffenet')
    % load mean image for alexnet
    d = load('extern/caffe/matlab/+caffe/imagenet/ilsvrc_2012_mean.mat');
    mean_data = d.mean_data;
    crop_dim = 227;
else
    error(['unsupported cnn_arch: ' opts.cnn_arch]);
end

addpath('extern/caffe/matlab');

% prepare network
net_model = sprintf('%s/%s/%s.prototxt', opts.modeldir, opts.cnn_arch, net_str);
net_weights = sprintf('%s/%s/%s', opts.modeldir, opts.cnn_arch, opts.cnn_arch);
if ~isempty(opts.cnn_idstr)
    net_weights = [net_weights '_' opts.cnn_idstr];
end
net = matcaffe_init(1, net_model, [net_weights '.caffemodel']);

% prepare list of images
imdir = fullfile(opts.datadir, dataset, split);
l = [dir([imdir '/*.png']); dir([imdir '/*.jpg'])];
filename = {l.name};
image_files = cellfun(@(f) fullfile(imdir, strtrim(f)), filename, 'uniform', false);
valid = true(1, length(filename));

% do forward; matcaffe_batch returns feat_dim x N, callers expect N x feat_dim
batch_size = 10;
cnn_feat = matcaffe_batch(net, image_files, batch_size, ...
    feat_dim, crop_dim, mean_data);
cnn_feat = cnn_feat';

% cache (create the cache directory on first use)
cache_dir = fileparts(savefile);
if ~isempty(cache_dir) && ~exist(cache_dir, 'dir')
    mkdir(cache_dir);
end
save(savefile, 'filename', 'cnn_feat', 'valid');
end


%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
% init Caffe Matlab interface
function net = matcaffe_init(use_gpu, net_model, net_weights)
% Select the Caffe compute mode and construct a network in 'test' phase.
%
% input
%   use_gpu      true: GPU mode on device 0; false: CPU mode
%   net_model    network definition passed to caffe.Net
%   net_weights  weights file passed to caffe.Net
% output
%   net          object returned by caffe.Net
%
if use_gpu
    first_gpu = 0;  % always the first GPU
    caffe.set_mode_gpu();
    caffe.set_device(first_gpu);
else
    caffe.set_mode_cpu();
end
phase = 'test';
net = caffe.Net(net_model, net_weights, phase);
end


%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
function [scores, list_im] = matcaffe_batch(net, list_im, batch_size, ...
    FEAT_DIM, CROPPED_DIM, IMAGE_MEAN)
% forward batch of images through caffe
%
% input
%   net          Caffe net (already initialized by matcaffe_init)
%   list_im      cell array of image file names, or a char path to a file
%                containing that list (read with read_cell)
%   batch_size   number of images per forward pass; should match the
%                batch size in the network's prototxt
%   FEAT_DIM     dim of resulting feature (rows of the output)
%   CROPPED_DIM  target reshape (center crop) size
%   IMAGE_MEAN   mean image to subtract
%
% output
%   scores       FEAT_DIM x num_images output matrix (single)
%   list_im      the image list as a cell array
%
% Side effect: calls caffe.reset_all() before returning, so `net` must not
% be used afterwards.
%
% You may need to do the following before you start matlab:
%  $ export LD_LIBRARY_PATH=/opt/intel/mkl/lib/intel64:/usr/local/cuda/lib64
%  $ export LD_PRELOAD=/usr/lib/x86_64-linux-gnu/libstdc++.so.6
% Or the equivalent based on where things are installed on your system
%
if ischar(list_im)
    % Assume it is a file containing the list of images
    list_im = read_cell(list_im);
end

num_images = length(list_im);
% Adjust the batch size and dim to match with deploy.prototxt
if mod(num_images, batch_size)
    warning(['Assuming batches of ' num2str(batch_size) ', rest will be filled with zeros'])
end

% prepare output
scores = zeros(FEAT_DIM, num_images, 'single');
num_batches = ceil(num_images / batch_size);

initic = tic;
for bb = 1 : num_batches
    % indices of the images in this batch (last batch may be shorter)
    range = 1+batch_size*(bb-1) : min(num_images, batch_size*bb);
    input_data = prepare_batch(list_im(range), CROPPED_DIM, IMAGE_MEAN, batch_size);
    fprintf('Batch %d/%d  %.2f%% Complete  ETA %.2f sec\n',...
        bb,num_batches,bb/num_batches*100,toc(initic)/bb*(num_batches-bb));

    output_data = net.forward({input_data});
    % Flatten to FEAT_DIM x batch. reshape (instead of squeeze) keeps the
    % feature axis as rows even when FEAT_DIM == 1, where squeeze would
    % return a column vector and break the column indexing below.
    output_data = reshape(output_data{1}, FEAT_DIM, []);
    % only the first numel(range) columns belong to real images; the
    % remaining ones come from the zero padding of a partial batch
    scores(:, range) = output_data(:, 1:numel(range));
end
toc(initic);

% call caffe.reset_all() to reset caffe
caffe.reset_all();
end


%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
function images = prepare_batch(image_files, CROPPED_DIM, IMAGE_MEAN, batch_size)
% Load images and convert them into one network input batch.
%
% Each image is read, converted to single, resized to 256x256, expanded to
% three channels if it has one, reordered RGB -> BGR, mean-subtracted,
% center-cropped to CROPPED_DIM x CROPPED_DIM and stored with its first
% two dimensions swapped.
%
% input
%   image_files  cell array of image file names
%   CROPPED_DIM  side length of the center crop
%   IMAGE_MEAN   mean image subtracted after the BGR reorder; must be
%                compatible with a 256x256x3 array
%   batch_size   (optional) size of the 4th output dimension; defaults to
%                the number of files. Slots beyond the number of files
%                stay zero.
% output
%   images       CROPPED_DIM x CROPPED_DIM x 3 x batch_size single array;
%                an image that fails to load leaves its slot at zero and
%                raises a warning
%
num_images = length(image_files);
if nargin < 4
    % batch_size is the 4th argument, so it is missing when nargin < 4
    batch_size = num_images;
end

IMAGE_DIM = 256;
% first row/column of the center crop
indices = [0, IMAGE_DIM-CROPPED_DIM] + 1;
center = floor(indices(2) / 2)+1;

images = zeros(CROPPED_DIM, CROPPED_DIM, 3, batch_size, 'single');

parfor i=1:num_images
    % read file
    fprintf('%c %s\n',13,image_files{i});
    try
        im = imread(image_files{i});
        % resize to fixed input size
        im = single(im);
        im = imresize(im, [IMAGE_DIM IMAGE_DIM], 'bilinear');
        % Transform GRAY to RGB
        if size(im, 3) == 1
            im = cat(3,im,im,im);
        end
        % permute from RGB to BGR (IMAGE_MEAN is already BGR)
        im = im(:, :, [3 2 1]) - IMAGE_MEAN;
        % Crop the center of the image and swap rows/columns
        images(:,:,:,i) = permute(im(center:center+CROPPED_DIM-1,...
            center:center+CROPPED_DIM-1,:),[2 1 3]);
    catch err
        % best effort: keep the zero-filled slot, but say which file failed
        warning('Problems with file %s: %s', image_files{i}, err.message);
    end
end
end
