# This file is part of Neuroinfo Toolkit.
#
# Neuroinfo Toolkit is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# Neuroinfo Toolkit is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with Neuroinfo Toolkit.  If not, see <http://www.gnu.org/licenses/>.

from neuro.exceptions import IllegalArgumentException
from neuro.exceptions import DimensionMismatchException
from shogun.Features import *
from shogun.Classifier import *
from shogun.Kernel import *
from numpy import ndarray, array

class SVM:
	'''
	Support Vector Machine classification model (deprecated)

	Thin wrapper around shogun's ``LibSVM`` with a ``GaussianKernel``.
	Data sets are column-major: each column of the training / test matrix
	is one example and each row is one feature, so the training set must
	have one column per label.
	'''
	
	def __init__(self):
		'''
		Constructor

		Starts with no data, cost 1.0 and Gaussian kernel width 2.1. The
		training set, labels and test set must be supplied through the
		setters before calling classify().
		'''
		self._training = None  # float64 matrix, one column per training example
		self._labels = None    # float64 array, one label per training column
		self._test = None      # float64 matrix, one column per test example
		self._cost = 1.0       # SVM cost parameter passed to LibSVM
		self._width = 2.1      # Gaussian kernel width
		
	@staticmethod
	def _asFloatMatrix(data, name):
		'''
		Validate a feature matrix and return it as a float64 array.

		:param data: candidate feature matrix
		:type data: ndarray
		:param name: human-readable name of the set, used in the error message
		:type name: str
		:returns: ``data`` itself if already float64, otherwise a float64 copy
		:rtype: ndarray
		:raises IllegalArgumentException: if data is not a 2-D ndarray
		'''
		if not isinstance(data, ndarray):
			raise IllegalArgumentException(name + " must be an instance of array")
		
		if data.ndim != 2:
			raise IllegalArgumentException("Set must be 2-D")
		
		# Inspect the dtype rather than probing data[0][0]: the probe crashed
		# on empty matrices and let object arrays of Python floats through
		# without conversion.
		if data.dtype != float:
			data = data.astype(float)
		
		return data
		
	def setCost(self, cost=1.0):
		'''
		Set the SVM cost parameter
		
		:param cost: cost passed to LibSVM; must be positive
		:type cost: int, float
		:raises IllegalArgumentException: if cost is not an int/float or is not positive
		'''
		# A value can never be both int and float, so the type test must be a
		# single isinstance against the tuple of accepted types.
		if not isinstance(cost, (int, float)):
			raise IllegalArgumentException("Cost must be an instance of int or float")
		
		if cost <= 0:
			raise IllegalArgumentException("Cost must be positive")
		
		self._cost = float(cost)
	
	def setWidth(self, width=2.1):
		'''
		Set the SVM [Gaussian] kernel width
		
		:param width: Gaussian kernel width; must be positive
		:type width: int, float
		:raises IllegalArgumentException: if width is not an int/float or is not positive
		'''
		if not isinstance(width, (int, float)):
			raise IllegalArgumentException("Kernel width must be an instance of int or float")
		
		if width <= 0:
			raise IllegalArgumentException("Kernel width must be positive")
		
		self._width = float(width)
		
	def setTrainingSet(self, set):
		'''
		Set SVM training set, column-major
		
		:param set: feature matrix, one column per training example
		:type set: ndarray
		:raises IllegalArgumentException: if set is not a 2-D ndarray
		'''
		self._training = SVM._asFloatMatrix(set, "Training set")
		
	def setLabels(self, labels):
		'''
		Set SVM training set labels
		
		The caller's list is not modified; a float64 copy is stored.
		
		:param labels: one label per training example (column)
		:type labels: list
		:raises IllegalArgumentException: if labels is not a list or has fewer than 2 entries
		:raises DimensionMismatchException: if a training set is already set and
			its number of columns differs from the number of labels
		'''
		## --- input validation
		if not isinstance(labels, list):
			raise IllegalArgumentException("Labels must be an instance of list")
		
		if len(labels) <= 1:
			raise IllegalArgumentException("Must be more than 1 label")
		
		## --- one label per training example (column); `is not None` because
		## comparing an ndarray with `!=` is element-wise
		if self._training is not None and self._training.shape[1] != len(labels):
			raise DimensionMismatchException("Number of labels do not match the number of training examples")
		
		## --- convert every label (not only the first) to a float numpy array
		self._labels = array(labels, dtype=float)
		
	def setTestSet(self, set):
		'''
		Set SVM test set, column-major
		
		:param set: feature matrix, one column per test example
		:type set: ndarray
		:raises IllegalArgumentException: if set is not a 2-D ndarray
		'''
		self._test = SVM._asFloatMatrix(set, "Test set")
		
	def classify(self):
		'''
		Classify the test set
		
		Trains LibSVM on the training set and labels using a Gaussian kernel
		of the configured width, re-initialises the kernel with the training
		and test features, and classifies.
		
		:returns: Classified labels, as returned by shogun's ``get_labels()``
		:rtype: list (NOTE(review): shogun may hand back a numpy array — confirm)
		:raises IllegalArgumentException: if the training set, labels or test set is missing
		:raises DimensionMismatchException: if the labels do not match the training
			columns or the test set has a different number of features (rows)
		'''
		if self._training is None or self._labels is None or self._test is None:
			raise IllegalArgumentException("Training set, labels and test set must be set before classifying")
		
		# Labels may have been set before the training set, so re-check here.
		if self._training.shape[1] != len(self._labels):
			raise DimensionMismatchException("Number of labels do not match the number of training examples")
		
		if self._training.shape[0] != self._test.shape[0]:
			raise DimensionMismatchException("Test set must have the same number of features as the training set")
		
		feats_train = RealFeatures(self._training)
		feats_test = RealFeatures(self._test)
		kernel = GaussianKernel(feats_train, feats_train, self._width)
		
		labels = Labels(self._labels)
		svm = LibSVM(self._cost, kernel, labels)
		svm.train()
		
		# Re-initialise the kernel against the test features; presumably this
		# makes the following classify() score the test set — confirm with the
		# shogun kernel-machine docs.
		kernel.init(feats_train, feats_test)
		
		return svm.classify().get_labels()