"""This script implements a permutation test (Fisher 1936, Koplenig 2019) for the 'sequential' measures of
lexical diversity (MSTTR, MTLD).
Each input is a tab-separated, UTF-16 encoded CSV file with a column "Token" that holds the tokens of a
sub-vocabulary, one token per line. The p-values of the permutation tests are printed to the console.
Important: The tokens in the input file must be ordered by text from which they were extracted!"""

import pandas as pd
import random
import calculate_lexical_diversity as cld
import os

# ----------------------------------------------------------------------------
# Settings -- adjust the values below if required
# ----------------------------------------------------------------------------

# Seed the random number generator so that the results are reproducible.
random.seed(42)

# Tab separated CSV files of the two corpora that are compared
# (a permutation test can only be conducted for two corpora).
# The paths are relative, so the working directory is switched first.
os.chdir("..\\input")
input_file_0 = "..\\input\\Anhang_F.csv"
input_file_1 = "..\\input\\Anhang_G.csv"

# Folders in which the files generated with create_random_sorted_files.py are stored.
input_dir_0 = "..\\random_0"
input_dir_1 = "..\\random_1"

# Number of repetitions (permutations) per test.
repetition = 10000
# Segment size used for MTLD.
mtld_seg_size = 6
# Segment size used for MSTTR.
msttr_seg_size = 25

# Load the "Token" column of each input file as a plain Python list.
token_list_0, token_list_1 = (
    pd.read_csv(csv_path, sep='\t', encoding='utf-16')["Token"].values.tolist()
    for csv_path in (input_file_0, input_file_1)
)

# Total number of tokens in each of the two files.
token_amount_0, token_amount_1 = len(token_list_0), len(token_list_1)


def calculate_msttr_diff():
    """Return the absolute difference between the MSTTR values of the two corpora.

    For each of the two directories (input_dir_0, then input_dir_1) the working
    directory is changed to that directory and cld.calc_msttr is queried for the
    segment size msttr_seg_size. The working directory is left at input_dir_1.

    return: a tuple (msttr_diff, larger_flag)
        msttr_diff: the larger MSTTR value minus the smaller one
        larger_flag: True if the value of the first corpus is strictly larger
            (first corpus is the minuend), otherwise False (second corpus is
            the minuend); used by the permutation test to fix the direction of
            the subtraction
    """
    scores = []
    for corpus_dir in (input_dir_0, input_dir_1):
        os.chdir(corpus_dir)
        # The result is indexed by segment size; only one size is requested here.
        result = cld.calc_msttr(corpus_dir, msttr_seg_size, msttr_seg_size, 1)
        scores.append(result[msttr_seg_size])

    first_score, second_score = scores
    if first_score > second_score:
        return first_score - second_score, True
    return second_score - first_score, False


def calculate_mtld_diff():
    """Return the absolute difference between the MTLD values of the two corpora.

    For each of the two directories (input_dir_0, then input_dir_1) the working
    directory is changed to that directory and cld.calc_mtld is called for it.
    The working directory is left at input_dir_1.

    return: a tuple (mtld_diff, larger_flag)
        mtld_diff: the larger MTLD value minus the smaller one
        larger_flag: True if the value of the first corpus is strictly larger
            (first corpus is the minuend), otherwise False (second corpus is
            the minuend); used by the permutation test to fix the direction of
            the subtraction
    """
    scores = []
    for corpus_dir in (input_dir_0, input_dir_1):
        os.chdir(corpus_dir)
        scores.append(cld.calc_mtld(corpus_dir))

    first_score, second_score = scores
    if first_score > second_score:
        return first_score - second_score, True
    return second_score - first_score, False


def permutation_msttr_test(token_list_0, token_list_1):
    """Conduct a permutation test for the measure MSTTR and print the p-value.

    Both corpora are cut into consecutive segments of msttr_seg_size tokens;
    tokens at the end of a corpus that do not fill a complete segment are left
    out. In each of the `repetition` rounds the pooled segments are shuffled
    and split again into two groups with the original numbers of segments, the
    MSTTR of both groups is calculated (rounded to two decimals) and their
    difference is taken in the same direction as the observed difference.
    The p-value is the share of rounds in which the permuted difference is at
    least as large as the observed difference from calculate_msttr_diff().

    param token_list_0: a list containing the tokens of the first corpus
    param token_list_1: a list containing the tokens of the second corpus
    """
    segment_amount_0 = len(token_list_0) // msttr_seg_size
    segment_amount_1 = len(token_list_1) // msttr_seg_size

    # Pool of complete segments: first all segments of corpus 0, then corpus 1.
    segments = [
        token_list_0[pos * msttr_seg_size:(pos + 1) * msttr_seg_size]
        for pos in range(segment_amount_0)
    ]
    segments += [
        token_list_1[pos * msttr_seg_size:(pos + 1) * msttr_seg_size]
        for pos in range(segment_amount_1)
    ]

    # The observed difference does not depend on the permutation, so it is
    # determined once instead of on every iteration (each call changes the
    # working directory and evaluates both corpora again).
    observed_diff, first_is_larger = calculate_msttr_diff()

    extreme_count = 0
    for _ in range(repetition):
        shuffled = segments.copy()
        random.shuffle(shuffled)

        # Flatten the two groups of segments back into token lists.
        permuted_0 = [token for segment in shuffled[:segment_amount_0] for token in segment]
        permuted_1 = [token for segment in shuffled[segment_amount_0:] for token in segment]

        msttr_0 = round(cld.calc_msttr_permutation_test(permuted_0, msttr_seg_size), 2)
        msttr_1 = round(cld.calc_msttr_permutation_test(permuted_1, msttr_seg_size), 2)

        # Subtract in the same direction as for the observed difference.
        permuted_diff = msttr_0 - msttr_1 if first_is_larger else msttr_1 - msttr_0
        if permuted_diff >= observed_diff:
            extreme_count += 1

    print("MSTTR p-value")
    print(extreme_count / repetition)


def permutation_mtld_test(token_list_0, token_list_1):
    """Conduct a permutation test for the measure MTLD and print the p-value.

    Both corpora are cut into consecutive segments of mtld_seg_size tokens;
    tokens at the end of a corpus that do not fill a complete segment are left
    out. In each of the `repetition` rounds the pooled segments are shuffled
    and split again into two groups with the original numbers of segments, the
    MTLD of both groups is calculated (rounded to two decimals) and their
    difference is taken in the same direction as the observed difference.
    The p-value is the share of rounds in which the permuted difference is at
    least as large as the observed difference from calculate_mtld_diff().

    param token_list_0: a list containing the tokens of the first corpus
    param token_list_1: a list containing the tokens of the second corpus
    """
    segment_amount_0 = len(token_list_0) // mtld_seg_size
    segment_amount_1 = len(token_list_1) // mtld_seg_size

    # Pool of complete segments: first all segments of corpus 0, then corpus 1.
    segments = [
        token_list_0[pos * mtld_seg_size:(pos + 1) * mtld_seg_size]
        for pos in range(segment_amount_0)
    ]
    segments += [
        token_list_1[pos * mtld_seg_size:(pos + 1) * mtld_seg_size]
        for pos in range(segment_amount_1)
    ]

    # Both the observed difference and the direction of the subtraction must
    # come from the MTLD values (not from MSTTR): the two measures may rank the
    # corpora differently. The values do not depend on the permutation, so
    # they are determined once instead of on every iteration.
    observed_diff, first_is_larger = calculate_mtld_diff()

    extreme_count = 0
    for _ in range(repetition):
        shuffled = segments.copy()
        random.shuffle(shuffled)

        # Flatten the two groups of segments back into token lists.
        permuted_0 = [token for segment in shuffled[:segment_amount_0] for token in segment]
        permuted_1 = [token for segment in shuffled[segment_amount_0:] for token in segment]

        mtld_0 = round(cld.calc_mtld_permutation_test(permuted_0), 2)
        mtld_1 = round(cld.calc_mtld_permutation_test(permuted_1), 2)

        # Subtract in the same direction as for the observed difference.
        permuted_diff = mtld_0 - mtld_1 if first_is_larger else mtld_1 - mtld_0
        if permuted_diff >= observed_diff:
            extreme_count += 1

    print("MTLD p-value")
    print(extreme_count / repetition)


# Run the MSTTR test first, then the MTLD test, on the two token lists.
for run_permutation_test in (permutation_msttr_test, permutation_mtld_test):
    run_permutation_test(token_list_0, token_list_1)