#!/usr/bin/env python

# File   : $HeadURL$
# Version: $Id$

import sys

from Ai import Ai
import common as com

# Value passed to ai.train() as `epsilon`; it is held fixed for every
# candidate tried by main().
EPSILON = 1e-3

# Previously recorded optimum: epsilon=1.00e-03, C=2.50, kernel_width=24.00
# (presumably the outcome of an earlier search -- TODO confirm it is current).

# Candidate values main() passes to ai.train() as `c`.
C_LIST = [2.5, 3.0, 3.5]
# Candidate values main() passes to ai.train() as `kernel_width`.
KERNEL_WIDTH_LIST = [22.0, 24.0, 26.0]

# Argument handed to ai.enable_validation(); presumably the fraction of the
# training data used for fitting, the rest being held out -- TODO confirm
# against Ai.
VALIDATION_TRAIN_FRAC = 0.80

def to_str(e):
    """Format a parameter/result record as a one-line, human-readable string.

    ``e`` is a mapping with the keys 'error', 'epsilon', 'c' and
    'kernel_width'.  The error is only included in the output when it is
    not greater than 1.0; larger values are treated as "no error measured
    yet" and are left out.
    """
    if e['error'] > 1.0:
        error_prefix = ""
    else:
        error_prefix = "error=%.4f, " % e['error']

    fields = (error_prefix, e['epsilon'], e['c'], e['kernel_width'])
    return "%sepsilon=%.2e, C=%.2f, kernel_width=%.2f" % fields

def main(argv):
    """Grid-search `c` and `kernel_width`, then train and write the final SVM.

    Every combination of KERNEL_WIDTH_LIST x C_LIST is trained (with the
    fixed EPSILON) on an ``Ai`` instance that has validation enabled, and the
    combination with the lowest ``ai.get_test_error()`` is kept.  The training
    data is then loaded again, the model is trained with the winning
    parameters and ``ai.write_svm()`` is called.

    argv: command-line arguments; currently unused.

    Returns 0 on success, or 1 when no combination produced an error below
    the initial sentinel (so there are no valid parameters to train with).
    """
    ai = Ai()
    ai.load_train_data(com.TRAIN_X_FNAME, com.TRAIN_Y_FNAME)
    ai.enable_validation(VALIDATION_TRAIN_FRAC)

    # Sentinel record: error 2.0 is above the 1.0 threshold to_str() uses to
    # decide whether an error has been measured, and the negative parameters
    # are placeholders that must never reach ai.train().
    best_error = {'error': 2.0, 'epsilon': -1.0, 'c': -1.0,
                  'kernel_width': -1.0}
    found_candidate = False

    # NOTE: print(...) is always called with a single argument so that the
    # output is identical under Python 2 (print statement) and Python 3
    # (print function).
    for kernel_width in KERNEL_WIDTH_LIST:
        for c in C_LIST:
            cur = {'error': 2.0, 'epsilon': EPSILON, 'c': c,
                   'kernel_width': kernel_width}

            print("Trying: %s" % to_str(cur))

            ai.train(kernel_width=kernel_width, c=c,
                     epsilon=EPSILON)
            print("")
            cur['error'] = ai.get_test_error()
            print("")

            # A NaN error compares False here and is therefore never chosen.
            if cur['error'] < best_error['error']:
                best_error = cur
                found_candidate = True
                print("New best: %s" % to_str(best_error))
            else:
                print("Result: %s" % to_str(cur))
                print("Best: %s" % to_str(best_error))

            print("")

    if not found_candidate:
        # Without this guard the final model would be trained with the
        # sentinel values (c=-1.0, kernel_width=-1.0).
        sys.stderr.write("No parameter combination produced a usable "
                         "validation error; final model not trained.\n")
        return 1

    print("Finally using parameters: %s" % to_str(best_error))
    # The data is loaded again before the final training -- presumably to
    # undo the validation split so the whole set is used; confirm against Ai.
    ai.load_train_data(com.TRAIN_X_FNAME, com.TRAIN_Y_FNAME)
    ai.train(kernel_width=best_error['kernel_width'],
             c=best_error['c'], epsilon=best_error['epsilon'])
    ai.write_svm()
    print("")
    print("Finished :DD")
    return 0

# Run the parameter search only when this file is executed as a script;
# whatever main() returns is handed to sys.exit().
if __name__ == '__main__':
    sys.exit(main(sys.argv))
