#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
Created on Tue Jan 11 15:27:24 2022

@author: kurata
"""
import argparse
import os
import pickle
import sys

import pandas as pd

def load_pickle(filename):
    """Return the first object pickled in the file at *filename*.

    NOTE(security): ``pickle.load`` can execute arbitrary code while
    unpickling, so only pass files that come from a trusted source.
    """
    with open(filename, mode="rb") as handle:
        return pickle.load(handle)

def output_csv_pandas(filename, data):
    """Write the pandas object *data* to *filename* as CSV.

    pandas' default ``to_csv`` settings are used, so the index is written as
    the first column.
    """
    data.to_csv(path_or_buf=filename)

def main(argv=None):
    """Export one attention head of one sample from a pickled attention file to CSV.

    The pickled object is indexed by sample (1-based on the command line).
    Each entry is either the marker string ``"n"`` (the sample was not
    predicted correctly) or an array-like indexed as
    ``[layer, head, query, key]``. The selected ``query x key`` matrix is
    written to ``<out>/attention_weights_sample_<s>_layer<l>_head<h>.csv``
    with rows labelled ``query_<i>`` and columns labelled ``key_<j>``.

    Args:
        argv: Command-line arguments to parse; ``None`` means ``sys.argv[1:]``.

    Exits with status 2 (via argparse) for invalid arguments, and with
    status 1 and a message on stderr for any other user-facing error.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('-i', '--input', type=str, help='Path of file (pkl)', required=True)
    parser.add_argument('-o', '--out', type=str, help='Path of output directory', required=True)
    parser.add_argument('-s', '--sample_number', type=int, help='Sample number in the attention (1-based)', required=True)
    parser.add_argument('-l', '--layer_number', type=int, choices=[1, 2, 3], help='Layer number in the attention (Smaller values closer to the input layer)', required=True)
    parser.add_argument('-n', '--head_number', type=int, choices=[1, 2, 3, 4], help='head number in the attention', required=True)

    # Parse argv once and reuse the resulting namespace.
    args = parser.parse_args(argv)
    sample_num = args.sample_number
    layer_num = args.layer_number
    head_num = args.head_number

    # Sample numbers are 1-based; 0 or a negative value would otherwise
    # silently select a sample from the end of the list via negative indexing.
    if sample_num < 1:
        parser.error("the sample number must be 1 or greater")

    # NOTE(security): unpickling can execute arbitrary code; only load trusted files.
    try:
        attention = load_pickle(args.input)
    except Exception as err:  # unpickling fails in many ways; Ctrl-C / SystemExit still propagate
        sys.exit("Can not open the specified file. Please confirm the path and file format ({})".format(err))

    if len(attention) < sample_num:
        sys.exit("The specified sample number is more than the number of query.")
    attention = attention[sample_num - 1]

    # Check the type before comparing: if the entry is an array (e.g. NumPy),
    # `entry == "n"` is element-wise, and on recent NumPy versions using that
    # result in an `if` raises "truth value of an array is ambiguous".
    if isinstance(attention, str) and attention == "n":
        sys.exit("The specified sample was not predicted correctly.")

    try:
        attention = attention[layer_num - 1, head_num - 1, :, :]
    except IndexError:
        sys.exit("The specified layer or head number does not exist in this sample.")

    res = pd.DataFrame(
        attention,
        index=["query_" + str(i) for i in range(attention.shape[0])],
        columns=["key_" + str(i) for i in range(attention.shape[1])],
    )

    filename = ("attention_weights_sample_" + str(sample_num)
                + "_layer" + str(layer_num)
                + "_head" + str(head_num) + ".csv")
    try:
        output_csv_pandas(os.path.join(args.out, filename), res)
    except OSError as err:  # missing directory, permission denied, ...
        sys.exit("Can not output. Please confirm the output directory ({})".format(err))


if __name__ == '__main__':
    main()
    

























































