#!/usr/bin/env python

import logging
from optimizer import Optimizer
from tqdm import tqdm

# Route every record (DEBUG and above) to log.txt, prefixed with a
# 12-hour timestamp and the level name.
logging.basicConfig(
    filename='log.txt',
    level=logging.DEBUG,
    datefmt='%m/%d/%Y %I:%M:%S %p',
    format='%(asctime)s - %(levelname)s - %(message)s'
)

def train_networks(networks):
    """Train each network, advancing a progress bar after every one.

    Args:
        networks (list): Current population of networks; ``train()`` is
            called on each item in order.
    """
    # Using the bar as a context manager guarantees it is closed even if
    # one of the train() calls raises.
    with tqdm(total=len(networks)) as pbar:
        for network in networks:
            network.train()
            pbar.update(1)

def get_average_accuracy(networks):
    """Get the average accuracy for a group of networks.

    Args:
        networks (list): List of networks; each item's ``accuracy``
            attribute is read.

    Returns:
        float: The average accuracy of a population of networks, or 0.0
            when the list is empty.

    """
    count = len(networks)
    if count == 0:
        # An empty population has no meaningful average; report 0.0
        # rather than failing with ZeroDivisionError.
        return 0.0

    return sum(network.accuracy for network in networks) / count

def generate(generations, population, nn_param_choices):
    """Generate a network with the genetic algorithm.

    Builds an initial population, then for every generation trains all
    networks, logs the generation's average accuracy and evolves the
    population (skipping the evolution step after the final generation).
    The five most accurate networks of the final population are then
    passed to print_networks().

    Args:
        generations (int): Number of times to evolve the population
        population (int): Number of networks in each generation
        nn_param_choices (dict): Parameter choices for networks

    """
    optimizer = Optimizer(nn_param_choices)
    networks = optimizer.create_population(population)

    # Evolve the generation.
    for i in range(generations):
        # Pass values as logging arguments so formatting is deferred to
        # the logging framework instead of being done eagerly here.
        logging.info("***Doing generation %d of %d***", i + 1, generations)

        # Train and get accuracy for networks.
        train_networks(networks)

        # Get the average accuracy for this generation.
        average_accuracy = get_average_accuracy(networks)

        # Log the average accuracy each generation, as a percentage.
        logging.info("Generation average: %.2f%%", average_accuracy * 100)
        logging.info('-'*80)

        # Evolve, except on the last iteration: the final population is
        # the one that gets ranked below.
        if i != generations - 1:
            networks = optimizer.evolve(networks)

    # Sort our final population, most accurate first.
    networks = sorted(networks, key=lambda x: x.accuracy, reverse=True)

    # Print out the top 5 networks.
    print_networks(networks[:5])

def print_networks(networks):
    """Log a separator line, then ask every network to print itself.

    Args:
        networks (list): The networks to print; ``print_network()`` is
            called on each one in order.

    """
    separator = '-' * 80
    logging.info(separator)

    for net in networks:
        net.print_network()

def main():
    """Evolve a network.

    Defines the run size and the hyper-parameter choices, logs the run
    settings, and hands everything to generate().
    """
    generations = 7  # Number of times to evolve the population.
    population = 5  # Number of networks in each generation.

    nn_param_choices = {
        'nb_neurons': [8, 16, 32, 64, 128, 256, 512, 768, 1024],
        'nb_layers': [1, 2, 3, 4, 5, 6, 7, 8, 9, 10],
        'activation': ['relu', 'elu', 'tanh', 'sigmoid'],
        # Disabled alternative layout, kept for reference: an explicit
        # optimizer choice with its options nested under one key.
        #'optimizer': ['rmsprop', 'adam', 'sgd', 'adagrad',
        #              'adadelta', 'adamax', 'nadam'],
        #'optimizer_opts': {'lr': [0.1, 0.5, 1.0, 10.0, 100.0],
        #                    'decay': [1e-1, 1e-2, 1e-3, 1e-4, 1e-5, 1e-6, 1e-7],
        #                    'momentum': [0.0, 0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7,
        #                        0.8, 0.9, 1.0]},
        'lr': [0.1, 0.5, 1.0, 10.0, 100.0],
        'decay': [1e-1, 1e-2, 1e-3, 1e-4, 1e-5, 1e-6, 1e-7],
        'momentum': [0.0, 0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9, 1.0],
    }

    # Pass values as logging arguments so formatting is deferred to the
    # logging framework instead of being done eagerly here.
    logging.info("***Evolving %d generations with population %d***",
                 generations, population)

    generate(generations, population, nn_param_choices)

# Start the evolution only when this file is executed as a script,
# not when it is imported as a module.
if __name__ == "__main__":
    main()