// -----------------------------------------------------------------------
// Author: Jianxin wu (wujx2001@gmail.com)
// This file is provided to simplify the usage of the provided
// "Visual Place Categorization" dataset.
// For academic and research purpose only.
// -----------------------------------------------------------------------
// Please read carefully the comments to variable and functions for usage.
// -----------------------------------------------------------------------
// This code is provided for your convenience.
// Use at your own risk. Absolutely no warranty is given or implied.
// -----------------------------------------------------------------------

#include <algorithm>
#include <cassert>
#include <fstream>
#include <iostream>
#include <numeric>
#include <sstream>
#include <string>
#include <vector>

#include "util.h"
#include "mdarray.h"
#include "VPC.h"

// Top-level directory where the VPC data resides. It is used as a plain string
// prefix when paths are built (e.g. VPC_BaseDir + "Home1/label.txt"), so it
// must end with a path separator.
// NOTE: make sure to change this variable to your own VPC dataset directory.
const char* VPC_BaseDir = "/home/VPC_Data/";
// Use the line below instead if you are using Windows.
//const char* VPC_BaseDir = "c:\\VPC_Data\\";

// Helper that resets a VPC_Input to the empty state: no directories, no
// labels, and num_dirs == 0.
// Clearing the outer 'labels' container already destroys every per-floor label
// vector, so no per-floor loop is needed. Dropping that loop also avoids
// indexing 'labels' with 'num_dirs', which is out of bounds whenever the two
// are out of sync.
void VPC_Input::Clear()
{
    dirs.clear();
    labels.clear();
    num_dirs = 0;
}

// Translate a semantic category name into its position in VPC_Categories.
// Returns the zero-based index of the first entry equal to 'category_name',
// or -1 when the name is not in the category list.
int VPC_CategoryToIndex(const std::string category_name)
{
    const unsigned int count = sizeof(VPC_Categories)/sizeof(std::string);
    unsigned int pos = 0;
    // linear scan: stop at the first matching entry or at the end of the table
    while(pos<count && !(category_name==VPC_Categories[pos]))
        ++pos;
    return pos<count ? static_cast<int>(pos) : -1;
}

// Read the groundtruth labels into 'input. Labels are in 'label.txt' in each directory.
// e.g. if 'home'=1, then the directory "Home1/label.txt" contains labels
// Label format please refer to the README of the VPC dataset
// Return value: 0: successful; -1: one of the category name is incorrect.
int VPC_ReadLabelFile(const int home,VPC_Input& input)
{
    assert(home>=1 && home<=VPC_numHomes);
    std::ostringstream ss;
    ss.str(""); ss<<VPC_BaseDir<<"Home"<<home<<"/label.txt"; // home=0 ==> 1st home ==> sub-directory "Home1", etc.
    // clear the old contents of VPC_Input
    if(input.num_dirs!=0) input.Clear();
    // read new information
    std::ifstream in(ss.str().c_str());
    bool finished = false;
    input.num_dirs = 0;
    do // process a sub-directory (aka a floor in a home)
    {
        std::string dir_buf,buf;
        int size;
        int start=0, end=0;
        int old_start=0,old_end=0;
        in>>dir_buf; // sub-directory name, "-1" means "stop here"
        if(dir_buf[0]=='-' && dir_buf[1]=='1')
            finished = true; // current sub-directory is finished
        else
        {
            ss.str(""); ss<<VPC_BaseDir<<"Home"<<home<<'/'<<dir_buf;
            input.dirs.push_back(ss.str()); // sub-directory name
            in>>size; assert(size>0); // number of frames in this sub-directory (floor)
            size++; // difference between 0 based array and 1 based array
            input.labels.push_back(std::vector<int>()); // push a new floor
            input.labels[input.num_dirs].resize(size); // number of frames in this floor
            std::fill_n(input.labels[input.num_dirs].begin(),size,0); // default value '0' means "transition"
            do
            {
                in>>start>>end>>buf; // starting frame index, ending frame index, and category name of a video segment
                if(start>0) // "start=-1" means ending of a floor
                {
                    assert(end>start && start==old_end+1);
                    int l = VPC_CategoryToIndex(buf);
                    if(l<0) return -1; // error case -- aborted
                    // the 'start' and 'end' are 1-based, so we need to use 'start-1' in a C array
					std::fill(&input.labels[input.num_dirs][start-1],&input.labels[input.num_dirs][end],l);
                    old_start = start; old_end = end;
                }
            }while(start>0);
            input.num_dirs++;
        }
    } while(finished==false);
    in.close();
    // successful
    return 0;
}

// Build the full pathname of one frame.
//   input   : labels/directories of a home, as filled in by VPC_ReadLabelFile
//   dir     : floor index, in [0..input.num_dirs-1]
//   image   : 1-based frame index inside that floor (files start at '00000001.jpg')
//   filename: receives the pathname; it is only written when the file exists
// Returns 0 when the image file exists, -1 otherwise.
int VPC_GetFileName(VPC_Input& input, const int dir,const int image,std::string& filename)
{
    assert(dir>=0 && dir<input.num_dirs);
    assert(image>=1 && image<=(int)input.labels[dir].size());
    // <floor directory> + frame index zero-padded to 8 digits + ".jpg"
    std::ostringstream path;
    path<<input.dirs[dir];
    path.width(8);
    path.fill('0');
    path<<image<<".jpg";
    const std::string candidate = path.str();
    // leave 'filename' untouched when there is no such image on disk
    if(FileExists(candidate.c_str())!=true) return -1;
    filename = candidate;
    return 0;
}

// Count the frames that belong to a selected home AND whose groundtruth
// category is selected.
//   homes[i]==true    : home i is used, e.g. homes[0]==true ==> "Home1/" is used
//   category[c]==true : category c is used, e.g. category[1]==true ==> "bedroom" is included
// Both vectors must have exactly VPC_numHomes / VPC_numCategory entries.
// Returns the number of matching frames (asserted to be positive).
int VPC_ValidCount(std::vector<bool>& homes,std::vector<bool>& category)
{
    assert(VPC_numHomes==(int)homes.size());
    assert(VPC_numCategory==(int)category.size());

    VPC_Input input;
    int total = 0;
    for(std::size_t h=0;h<homes.size();++h) // every home
    {
        if(!homes[h]) continue;
        // homes are numbered from 1 on disk, hence 'h+1'
        const int r = VPC_ReadLabelFile(static_cast<int>(h)+1,input);
        assert(r==0);
        for(int f=0;f<input.num_dirs;++f) // every floor of this home
        {
            const std::vector<int>& frames = input.labels[f];
            const std::size_t frame_count = frames.size();
            for(std::size_t k=0;k<frame_count;++k) // every frame of this floor
            {
                if(category[frames[k]]) ++total;
            }
        }
    }
    // at least one component of 'homes' AND 'category' must have been true
    assert(total>0);
    return total;
}

// Providing a way to sequentially traverse through specified homes and room types
// so that you can run your method on the VPC dataset
// e.g. we recommend a leave-one-out cross validation, then use the following code
// -------------------------------------
/*    std::vector<bool> homes,category;
    homes.resize(VPC_numHomes);
    category.resize(VPC_numCategory);
    // use a subset of the categories (room types)
    category[1]=category[2]=category[3]=category[5]=category[6]=true;
    for(int round=0;round<VPC_numHomes;round++)
    {   // use all homes except 'round' as training examples
        std::fill(homes.begin(),homes.end(),true); homes[round]=false;
        // put appropriate code in VPC_Traverse to do training
        VPC_Traverse(homes,category);
        // use 'round' as testing
        std::fill(homes.begin(),homes.end(),false); homes[round]=true;
        // put appropriate code in VPC_Traverse to do testing
        VPC_Traverse(homes,category);
    }*/
// The above code assumes that each frame is processed independently. If you use a different strategy,
// then at least VPC_Traverse provides a way to sequentially run through all frames.
// -------------------------------------
// Visit, in home / floor / frame order, every frame that comes from a selected home
// (homes[i]==true ==> "Home<i+1>/") and whose groundtruth category is selected
// (category[c]==true). Both vectors must have VPC_numHomes / VPC_numCategory entries.
// This is a skeleton: insert your own per-frame processing where indicated. Always returns 0.
int VPC_Traverse(std::vector<bool>& homes,std::vector<bool>& category)
{
    assert(VPC_numHomes==(int)homes.size());
    assert(VPC_numCategory==(int)category.size());

    VPC_Input input;
    std::string filename;
    for(unsigned int i=0;i<homes.size();i++) // run through all homes
    {
        if(homes[i]==false) continue;
        // NOTE: homes[0] ==> "Home1", so we use 'i+1' instead of 'i'
        int r = VPC_ReadLabelFile(i+1,input);
        assert(r==0);
        // run through all floors inside this home
        for(int j=0;j<input.num_dirs;j++)
        {
            // run through all frames inside this floor
            for(unsigned int k=0;k<input.labels[j].size();k++)
            {
                if(category[input.labels[j][k]]==true)
                {
                    // 'filename' is reused across frames and VPC_GetFileName only writes it on
                    // success, so clear it when the image is missing. Otherwise the previous
                    // frame's path would silently be processed again.
                    if(VPC_GetFileName(input,j,k+1,filename)!=0) filename.clear();
                    // NOTE: the frame you want is stored in 'filename' (empty if the image file
                    // does not exist), so apply your method here.
                    // Do something -- either extract features for training, or test on this frame
                    // Information about this frame:
                    //   home i, e.g. i==0, then from "Home1/"
                    //   floor j, the first sub-directory (usually "1/") has floor=0, etc.
                    //   frame k, e.g. k==3, then 4th frame 00000004.jpg
                    //   category index is input.labels[j][k], e.g. input.labels[j][k]==1, means bedroom
                    //   if you prefer not using C or C++, you could print out this information
                    //     to a text file, and then read from this text file using your favorite tool, e.g. Matlab, as below
                    // ---------------
                    // use std::ofstream out("database.txt") to open an output text file
                    // and write to the database:
                    //    out<<i<<" "<<j<<" "<<k<<" "<<input.labels[j][k]<<" "<<filename<<std::endl;
                    // ---------------
                }
            }
        }
    }
    return 0;
}

// This function provides a suggested way to evaluate your VPC method when frames are considered independently.
// The purpose of this "per-frame" accuracy is to provide a BASELINE for evaluation.
// --------------------------------------------------------------------------------------------------------------------
// We assume that a leave-one-out cross validation method is used, then each frame is tested exactly once.
// We assume all homes are used, but you can choose a subset of the semantic room categories.
// Test output is stored in the file 'outputfile', in the following format:
// The first line states whether each category is used,
//       e.g. if only bedroom and bathroom are used, then the 1st line should be  "0 1 1 0 0 .."
// From the 2nd line to the last line, each line holds the prediction for one frame (a number in [0..VPC_numCategory-1]).
// These predictions must be in the same order as those produced by the VPC_Traverse function.
// NOTE: ONLY frames whose groundtruth category is used will be printed in this file.
// The validity of this file is not fully checked -- you need to make sure the format is correct.
// Hint: use VPC_Traverse to produce your output file to make sure the format is right.
// Output: will be a confusion matrix, an overall accuracy, and per category accuracies.
// --------------------------------------------------------------------------------------------------------------------
// To generate such a file, codes from the VPC_Traverse function can be used,
// e.g. (category[i]=true means the i-th category is selected)
// --------------------------------------------------------------------------------------------------------------------
/*    std::ofstream out("out.txt");
    for(int i=0;i<VPC_numCategory;i++) out<<category[i]<<" "; out<<std::endl;
    for(int i=0;i<VPC_numHomes;i++)
    {
        std::fill(homes.begin(),homes.end(),true);
        homes[i]=false;
        // Train your system (may use code from VPC_Traverse)
        std::fill(homes.begin(),homes.end(),false);
        homes[i]=true;
        // Test your system (may use code from VPC_Traverse)
        // Write the prediction result for a single frame in a single line
    }
    out.close();*/
// --------------------------------------------------------------------------------------------------------------------
// NOTE: We calculate the numbers as follows. If you write your own code for evaluation, we recommend that you adopt the same strategy.
// We calculate the confusion matrix of when each home is used for testing
// Then we (element-wise) add the confusion matrices altogether, and (element-wise) divide by number of homes
// The overall accuracy is reported as average per-class accuracies, aka, average of diagonal elements of the final confusion matrix.
// --------------------------------------------------------------------------------------------------------------------
void VPC_Evaludate_Single_Frame(const char* outputfile)
{
    Array2dC<double> confusion_all(VPC_numCategory,VPC_numCategory);
    confusion_all.Zero();

    std::ifstream in(outputfile);
    // read which categories are used
    std::vector<int> category(VPC_numCategory);
    for(int i=0;i<VPC_numCategory;i++) in>>category[i];
    int num_category = 0; // how many categories are selected?
    for(int i=0;i<VPC_numCategory;i++) if(category[i]!=0) num_category++;
    // process all frames from all homes and selected categories using code from VPC_Traverse
    std::cout<<std::fixed;
    std::cout.precision(2);
    VPC_Input input;
    double acc; // add per category accuracies together
    for(int i=0;i<VPC_numHomes;i++) // run through all homes
    {
        Array2dC<double> confusion(VPC_numCategory,VPC_numCategory);
        confusion.Zero();
        int r = VPC_ReadLabelFile(i+1,input); assert(r==0);
        // run through all floors inside this home
        for(int j=0;j<input.num_dirs;j++)
        {
            // run through all frames inside this floor
            for(unsigned int k=0;k<input.labels[j].size();k++)
            {
                // only process frames from selected categories
                if(category[input.labels[j][k]]!=0)   
                {
                    int prediction;
                    in>>prediction;
                    // make sure the prediction is inside a valid range -- otherwise something is wrong!
                    assert(prediction>=0 && prediction<VPC_numCategory);
                    confusion.p[input.labels[j][k]][prediction]++;
                }
            }
        }
        // convert frequency count to percentage
        acc = 0.0;
        for(int j=0;j<confusion.nrow;j++)
        {
            double sum = std::accumulate(confusion.p[j],confusion.p[j]+confusion.ncol,0.0);
            if(sum>0)
            {
                sum = 100.0/sum;
                for(int k=0;k<confusion.ncol;k++) confusion.p[j][k] *= sum;
            }
            acc += confusion.p[j][j];
        }
        // print confusion matrix for the case: i-th home as testing, all rest as training
        std::cout<<"Result when Home"<<i+1<<"/ is used for testing."<<std::endl;
        std::cout<<"Confusion matrix in percentage: (row: groundtruth, column: predicted label)"<<std::endl;
        for(int i=0;i<confusion.ncol;i++)
        {
            for(int j=0;j<confusion.ncol;j++)
            {
                std::cout.width(6);
                std::cout<<confusion.p[i][j];
            }
            std::cout<<std::endl;
        }
        std::cout<<"Overall accuracy: ";
        std::cout<<acc/num_category<<'%'<<std::endl;
        std::cout<<"------------------------------------------------------------------------"<<std::endl;
        // add to the overall confusion matrix
        for(int j=0;j<confusion.nrow;j++)
            for(int k=0;k<confusion.ncol;k++)
                confusion_all.p[j][k] += confusion.p[j][k];
    }
    in.close();
    // Make the overall confusion matrix correct
    for(int i=0;i<confusion_all.nrow;i++)
        for(int j=0;j<confusion_all.ncol;j++)
            confusion_all.p[i][j] /= VPC_numHomes;
    // Print the overall confusion matrix
    std::cout<<"Overall results:"<<std::endl;
    std::cout<<"Confusion matrix in percentage: (row: groundtruth, column: predicted label)"<<std::endl;
    for(int i=0;i<confusion_all.nrow;i++)
    {
        for(int j=0;j<confusion_all.ncol;j++)
        {
            std::cout.width(6);
            std::cout<<confusion_all.p[i][j];
        }
        std::cout<<std::endl;
    }
    std::cout<<"------------------------------------------------------------------------"<<std::endl;
    // print the overall and per category accuracy
    acc = 0;
    for(int i=0;i<VPC_numCategory;i++)
    {
        std::cout.width(16);
        std::cout<<std::left<<VPC_Categories[i]; // print category name
        if(std::accumulate(confusion_all.p[i],confusion_all.p[i]+VPC_numCategory,0.0)>0)
        {   // print accuracy
            std::cout<<confusion_all.p[i][i]<<'%'<<std::endl;
            acc += confusion_all.p[i][i];
        }
        else
            std::cout<<"n/a"<<std::endl;
    }
    std::cout<<"------------------------------------------------------------------------"<<std::endl;
    std::cout<<"Overall accuracy: ";
    std::cout<<acc/num_category<<'%'<<std::endl;
}
