from psychopy import sound
import numpy as np
import scipy.signal as scis
from scipy.signal import butter
from scipy.signal import filtfilt

def wind(x, srate=22050, wdms=20):
    """Apply raised-cosine onset/offset ramps to a signal, in place.

    A single raised-cosine window spanning ``2 * wdms`` ms is generated; its
    rising half multiplies the start of ``x`` and its falling half multiplies
    the end of ``x``.

    Parameters
    ----------
    x : numpy.ndarray
        1-D signal. It is modified in place and also returned.
    srate : int
        Sampling rate in Hz. For the special rate 48828 Hz the window length
        is not forced to be even (kept from the original code).
    wdms : float
        Duration of each ramp in ms.

    Returns
    -------
    numpy.ndarray
        ``x`` after ramping (the same object that was passed in).
    """
    npts = len(x)
    # Float arithmetic: under Python 2, ``2*wdms/1000`` with integer
    # arguments truncates to 0, which silently disabled the window.
    wds = int(np.round(2.0 * wdms / 1000.0 * srate))
    if srate != 48828 and wds % 2 != 0:
        wds += 1
    w = np.linspace(-np.pi / 2, 1.5 * np.pi, wds)
    w = (np.sin(w) + 1) / 2
    half = int(np.round(wds / 2.0))
    x[0:half] = x[0:half] * w[0:half]
    # Take the last ``half`` window samples so both sides always have the
    # same length, even when ``wds`` is odd (possible at 48828 Hz).
    x[npts - half:npts] = x[npts - half:npts] * w[wds - half:wds]
    return x



def gener_tone(f0=2000, sampleRate=22050, dur=0.05):
    """Create a linearly ramped pure tone as a PsychoPy ``sound.Sound``.

    Parameters
    ----------
    f0 : float
        Tone frequency in Hz.
    sampleRate : int
        Sampling rate in Hz, used both to build the waveform and for the
        returned ``Sound`` object (previously the latter was fixed at 22050).
    dur : float
        Duration in s.

    Returns
    -------
    psychopy.sound.Sound
        8-bit sound object holding the tone.
    """
    nSamples = int(dur * sampleRate)
    # Exactly nSamples points n / nSamples in [0, 1); np.arange with a float
    # step can return one element too many because of rounding.
    frac = np.arange(nSamples) / float(max(nSamples, 1))
    tone = np.sin(2 * np.pi * f0 * dur * frac)
    # 110-sample linear onset/offset ramps (5 ms at 22050 Hz), shortened for
    # very short tones so the two ramps never overlap.
    nRamp = min(110, nSamples // 2)
    ramp = np.arange(nRamp) / float(max(nRamp, 1))
    tone[:nRamp] *= ramp
    tone[nSamples - nRamp:] *= ramp[::-1]
    return sound.Sound(tone, sampleRate=sampleRate, bits=8)


def pinknoisefilter(Xin):
    """Shape a signal with an approximately 1/f (pink) spectral roll-off.

    ``Xin`` is filtered with a fixed third-order IIR filter, and the start of
    the output is discarded. That discarded span, ``nT60`` samples, is
    estimated as the time for the slowest pole to decay by 60 dB.

    Parameters
    ----------
    Xin : array_like
        Input signal, typically white noise.

    Returns
    -------
    numpy.ndarray
        Filtered signal, ``nT60`` samples shorter than ``Xin`` (empty when
        ``Xin`` is not longer than the transient).
    """
    A = [1, -2.494956002, 2.017265875, -0.522189400]
    B = [0.049922035, -0.095993537, 0.050612699, -0.004408786]
    # T60 estimate: samples needed for the slowest pole to decay by 60 dB.
    nT60 = int(np.round(np.log(1000) / (1 - np.max(np.abs(np.roots(A))))))
    x = scis.lfilter(B, A, Xin)
    # Skip transient response; MATLAB ``x(nT60+1:end)`` is ``x[nT60:]``.
    return x[nT60:]

def harmcomp_create(dur=None, srate=None, f0=None, harms=None, phase_type=None, con_type=None):
    """Synthesise a harmonic complex (or a related control stimulus).

    Parameters
    ----------
    dur : float
        Duration in s.
    srate : int
        Sampling rate in Hz.
    f0 : float
        Fundamental frequency in Hz.
    harms : sequence of numbers
        Harmonic numbers to include (multiples of ``f0``).
    phase_type : str, optional
        Starting phase of each component (default ``'s'``):
        ``'n'`` zero; ``'a'`` adds pi for each successive harmonic;
        ``'r'`` uniformly random; ``'s'`` / ``'z'`` Schroeder phase
        ``+/- pi * h * (h - 1) / max(harms)``.
    con_type : str, optional
        Stimulus variant (default ``'p'``):
        ``'t'`` keeps the spectral span but fills it with every harmonic of
        the largest subharmonic ``f0 / p`` that is <= 30 Hz;
        ``'j'`` jitters each component by a uniform amount in
        [-f0/2, f0/2); ``'n'`` replaces the complex by fixed-amplitude
        random-phase noise over the same band (``farpn``); ``'r'``
        phase-randomises the complex (``phase_randomise``, half-period time
        bins); any other value gives the plain complex.

    Returns
    -------
    numpy.ndarray
        Stimulus after ``filterpitch`` band-pass filtering at ``srate``.
        For all variants except ``'n'`` the complex is normalised to an rms
        of 0.1 before filtering.
    """
    if phase_type is None:
        phase_type = 's'
    if con_type is None:
        con_type = 'p'
    # Float array: avoids integer division in the Schroeder phase under
    # Python 2 and allows the vectorised arithmetic below.
    harms = np.atleast_1d(np.asarray(harms, dtype=float))

    if con_type == 't':
        # Smallest p >= 1 such that the subharmonic f0 / p is <= 30 Hz.
        p = 1
        while f0 / float(p) > 30:
            p += 1
        # MATLAB (min*f0 : f0/p : max*f0) / (f0/p) == min*p : 1 : max*p.
        harms = np.arange(harms.min() * p, harms.max() * p + 1)
        f0 = f0 / float(p)

    if con_type == 'j':
        jit_vals = (np.random.rand(harms.size) - 0.5) * f0
        harms = (harms * f0 + jit_vals) / f0

    if con_type == 'n':
        stim_out = farpn(srate, harms.min() * f0, harms.max() * f0, dur)
    else:
        t = np.arange(0, dur, 1. / srate)
        top = harms.max()
        stim_out = np.zeros(t.shape[0])
        phase_off = 0
        for h in np.sort(harms):
            if phase_type == 'n':
                phase_off = 0
            elif phase_type == 'a':
                phase_off = phase_off + np.pi
            elif phase_type == 'r':
                phase_off = np.random.rand() * 2 * np.pi
            elif phase_type == 's':
                phase_off = np.pi * h * (h - 1) / top
            elif phase_type == 'z':
                phase_off = -1 * np.pi * h * (h - 1) / top
            stim_out = stim_out + np.sin(2 * np.pi * h * f0 * t + phase_off)
        # Normalise to an rms of 0.1.
        stim_out = 0.1 * stim_out / np.sqrt(np.mean(stim_out * stim_out))

    if con_type == 'r':
        # phase_randomise(stim_in, time_bin_ms, srate); 500 / f0 ms is half
        # the period, as its documentation recommends.
        stim_out = phase_randomise(stim_out, 500.0 / f0, srate)

    return filterpitch(stim_out, srate=srate)

def filterpitch(x, order=4, srate=22050, lo=1000, hi=4000):
    """Zero-phase Butterworth band-pass filter.

    Parameters
    ----------
    x : array_like
        Signal to filter (along its last axis).
    order : int
        Butterworth order passed to ``scipy.signal.butter``.
    srate : float
        Sampling rate in Hz.
    lo, hi : float
        Pass-band edges in Hz; the defaults reproduce the original fixed
        1000-4000 Hz band. Both must lie below ``srate / 2``.

    Returns
    -------
    numpy.ndarray
        Filtered signal, same length as ``x``. ``filtfilt`` runs the filter
        forwards and backwards, so there is no phase shift.
    """
    nyquist = srate / 2.
    b, a = butter(order, [lo / nyquist, hi / nyquist], btype='band')
    return filtfilt(b, a, x)


def farpn(fs=22050, lo=500, hi=5000, dur=1):
    '''
    Fixed-Amplitude Random-Phase Noise

    Produces noise with a flat magnitude spectrum inside the passband
    [lo, hi) Hz, zero magnitude outside it, and uniformly random phases.

    Usage:
    x = farpn(fs, lo, hi, dur)

    x = output stimulus (1-D array, 2 * (npts // 2) samples, where npts is
        the number of samples in dur seconds); rms 0.1, zero mean
    fs = sampling rate (Hz)
    lo = lower limit of passband (Hz)
    hi = upper limit of passband (Hz)
    dur = duration of output stimulus (s)

    Raises ValueError if no frequency bin falls inside the passband.

    Originally by Tim Griffiths, 2004
    '''
    t = np.arange(0, dur, 1. / fs)
    npts = len(t)
    half = npts // 2
    fbin = float(fs) / npts
    # 1-D magnitude over bins 0 .. half-1 (bin k is at k * fbin Hz). A 2-D
    # (1, N) array here made the slices below address rows, so the passband
    # was never applied.
    mag = np.ones(half)
    mag[:int(np.round(lo / fbin))] = 0
    mag[int(np.round(hi / fbin)):] = 0
    if not mag.any():
        raise ValueError('farpn: no frequency bins in passband [%s, %s) Hz' % (lo, hi))
    phase = 2 * np.pi * np.random.rand(half)
    phase[0] = 0
    # Append a zero Nyquist bin; irfft supplies the conjugate-symmetric
    # upper half, so the signal is exactly real with 2 * half samples.
    spectrum = np.concatenate((mag * np.exp(1j * phase), [0]))
    sig = np.fft.irfft(spectrum, 2 * half)
    return sig / (10 * np.std(sig))  # set rms value to 0.1
  
  
def _randomise_phase(segment):
    """Return ``segment`` with every bin's magnitude kept and its phase made random.

    An independent uniform random phase is added to each FFT bin except DC
    and Nyquist, which are left untouched so the output stays real.
    """
    spectrum = np.fft.rfft(segment)
    phase = 2 * np.pi * np.random.rand(len(spectrum))
    phase[0] = 0
    phase[-1] = 0
    return np.fft.irfft(spectrum * np.exp(1j * phase), len(segment))


def phase_randomise(stim_in, time_bin, srate=22050):
    '''
    Phase randomisation function

    Designed for use with Regular Interval Noise (RIN) stimuli

    Converts input stimulus into frequency domain, randomises phase
    values while maintaining power spectrum over time, then returns
    stimulus to time domain. Psychoacoustic effect is to retain power
    spectral fluctuations while removing pitch

    Operates over a number of user-specified time bins. A time bin length
    equal to or longer than the delay used to create the RIN will not
    remove pitch entirely. An overly short time bin will also not remove pitch
    entirely. Suggested time bin length is half the delay of the RIN. Phase
    is randomised in two layers of bins offset by half a bin and tapered
    with periodic Hann windows (which sum to one where they overlap), to
    remove sharp transitions.

    Note, output stimulus may be slightly truncated. To avoid truncation,
    ensure that the length of the input stimulus is an exact multiple of
    time_bin

    Usage:

    stim_out = phase_randomise(stim_in, time_bin, srate)

    stim_out = output waveform (1-D array)
    stim_in = input waveform (1-D array)
    time_bin = time bin length (ms) used for each phase randomisation, which
        is recommended to be half of the reciprocal of the frequency
        e.g. 100Hz frequency gives 10ms reciprocal, so 5ms time bin
    srate = stimulus sampling rate (Hz)

    Raises ValueError if time_bin is shorter than two samples.

    Part of Pitch Stimulus Design Toolbox
    Version 1.1
    Will Sedley, Newcastle Auditory Group, UK
    December 2010
    '''
    stim_in = np.asarray(stim_in, dtype=float)
    # Bin length in samples, rounded to an even number so it splits in half.
    nbin = int(2 * np.round(time_bin * srate / 2000.0))
    if nbin < 2:
        raise ValueError('phase_randomise: time_bin must span at least 2 samples')
    half = nbin // 2
    nbins = len(stim_in) // nbin
    # Drop trailing samples that do not fill a whole bin.
    stim_in = stim_in[:nbins * nbin]
    if nbins == 0:
        return np.zeros(0)

    # Periodic Hann: two copies offset by half a bin sum exactly to one.
    hanwin = 0.5 - 0.5 * np.cos(2 * np.pi * np.arange(nbin) / nbin)
    # Outer halves of the first/last bins have no overlapping partner.
    hanstart = np.concatenate((np.ones(half), hanwin[half:]))
    hanend = np.concatenate((hanwin[:half], np.ones(half)))

    # Layer A: bins aligned with the start of the stimulus.
    stim_out_a = np.zeros(len(stim_in))
    for k in range(nbins):
        if nbins == 1:
            win = np.ones(nbin)
        elif k == 0:
            win = hanstart
        elif k == nbins - 1:
            win = hanend
        else:
            win = hanwin
        seg = slice(k * nbin, (k + 1) * nbin)
        stim_out_a[seg] = win * _randomise_phase(stim_in[seg])

    # Layer B: bins shifted by half a bin, covering the joins in layer A.
    stim_out_b = np.zeros(len(stim_in))
    for k in range(nbins - 1):
        seg = slice(half + k * nbin, half + (k + 1) * nbin)
        stim_out_b[seg] = hanwin * _randomise_phase(stim_in[seg])

    return stim_out_a + stim_out_b
  
  
  
def createPsySound(arraysound, play=True, srate=22050, rfms=5):
    """Wrap a waveform in a PsychoPy pygame sound after adding linear ramps.

    Parameters
    ----------
    arraysound : array_like
        Waveform samples. A ramped float copy is made; the caller's array is
        not modified.
    play : bool
        Start playback immediately when True.
    srate : int
        Sampling rate in Hz.
    rfms : float
        Rise/fall ramp duration in ms; 0 disables the ramps. Ramps are
        shortened if the waveform is shorter than two ramps.

    Returns
    -------
    psychopy.sound.SoundPygame
        16-bit sound object.
    """
    samples = np.array(arraysound, dtype=float)
    nramp = min(int((float(rfms) / 1000) * srate), len(samples) // 2)
    if nramp > 0:
        # Exactly nramp points; np.arange with a float step can overshoot.
        ramp = np.arange(nramp) / float(nramp)
        samples[:nramp] *= ramp
        samples[len(samples) - nramp:] *= ramp[::-1]
    soundA = sound.SoundPygame(samples, sampleRate=srate, bits=16)
    if play:
        soundA.play()
    return soundA
  
def harmcomp_create2(durA=None, srate=None, f0A=None, f0B=None, harms=None, phase_type=None, con_type=None, pcvolA=100, filtering=False, catch_trial=False):
    """Build a two-part harmonic-complex stimulus (f0A followed by f0B).

    Parameters
    ----------
    durA : float
        Duration of part A in s (part B is always 2 s).
    srate : int, optional
        Sampling rate in Hz (default 22050).
    f0A, f0B : float
        Fundamentals of parts A and B in Hz.
    harms, phase_type, con_type
        Passed unchanged to ``harmcomp_create``.
    pcvolA : float
        Level of part A as a percentage; applied after any normalisation so
        that ``filtering`` does not cancel it.
    filtering : bool
        If True, each part is band-passed again with ``filterpitch``, scaled
        to a standard deviation of 0.1 and given 20 ms ``wind`` ramps.
    catch_trial : bool
        If True, return a single 10 s complex at ``f0A`` instead
        (``durA``, ``f0B`` and ``pcvolA`` are ignored).

    Returns
    -------
    numpy.ndarray
        The concatenated stimulus (or the catch-trial complex).
    """
    if srate is None:
        srate = 22050

    def _postfilter(x):
        # Extra band-pass, level normalisation and 20 ms onset/offset ramps.
        x = filterpitch(x, srate=srate)
        x = 0.1 * x / np.std(x)
        return wind(x, srate=srate, wdms=20)

    if catch_trial:
        aC = harmcomp_create(dur=10, srate=srate, f0=f0A, harms=harms, phase_type=phase_type, con_type=con_type)
        if filtering:
            aC = _postfilter(aC)
        return aC

    aA = harmcomp_create(dur=durA, srate=srate, f0=f0A, harms=harms, phase_type=phase_type, con_type=con_type)
    aB = harmcomp_create(dur=2, srate=srate, f0=f0B, harms=harms, phase_type=phase_type, con_type=con_type)
    if filtering:
        aA = _postfilter(aA)
        aB = _postfilter(aB)
    # Scale after normalisation; scaling first was undone by _postfilter.
    aA = aA * pcvolA / 100.
    return np.concatenate((aA, aB))
  
def harmcomp_create3(durA=10, srate=None, f0A=None, f0B=None, harms=None, phase_type=None, con_type=None, burst=150, rate=4, rfms=5, filtering=False):
    """Build two pulsed harmonic complexes (f0A then f0B) gated into bursts.

    Each part lasts ``durA`` cycles of ``burst`` ms sound plus
    ``(1000 - burst * rate) / rate`` ms silence. Each burst has linear
    ``rfms`` ms onset/offset ramps.

    Parameters
    ----------
    durA : int
        Number of burst cycles per part (used as an ``np.tile`` repeat count).
    srate : int, optional
        Sampling rate in Hz (default 22050).
    f0A, f0B : float
        Fundamentals of the two parts in Hz.
    harms, phase_type, con_type
        Passed unchanged to ``harmcomp_create``.
    burst : float
        Burst duration in ms.
    rate : float
        Bursts per second.
    rfms : float
        Ramp duration in ms.
    filtering : bool
        If True, band-pass the concatenated result with ``filterpitch``.

    Returns
    -------
    numpy.ndarray
        Part A followed by part B.

    Raises
    ------
    ValueError
        If ``burst * rate`` exceeds 1000 ms (negative silence).
    """
    if srate is None:
        srate = 22050
    # Silence per cycle in ms (float division, also under Python 2).
    sildur = (1000 - burst * rate) / float(rate)
    if sildur < 0:
        raise ValueError('harmcomp_create3: burst * rate must not exceed 1000 ms')
    durAtot = durA * (burst + sildur)  # total duration of each part, ms

    nramp = int((float(rfms) / 1000) * srate)
    on_sp = int((float(burst - 2 * rfms) / 1000) * srate)
    # The extra 2 samples come from the original code; presumably they keep
    # the tiled mask at least as long as the signal despite int() truncation.
    off_sp = int((float(sildur) / 1000) * srate) + 2
    ramp = np.arange(nramp) / float(max(nramp, 1))
    mask = np.concatenate((ramp, np.ones(on_sp), ramp[::-1], np.zeros(off_sp)))
    maskall = np.tile(mask, durA)

    def gate(x):
        # Multiply by the mask, zero-padding it if it is shorter than x.
        m = maskall[:len(x)]
        if len(m) < len(x):
            m = np.concatenate((m, np.zeros(len(x) - len(m))))
        return x * m

    aA = gate(harmcomp_create(dur=durAtot / 1000., srate=srate, f0=f0A, harms=harms, phase_type=phase_type, con_type=con_type))
    aB = gate(harmcomp_create(dur=durAtot / 1000., srate=srate, f0=f0B, harms=harms, phase_type=phase_type, con_type=con_type))
    aC = np.concatenate((aA, aB))
    if filtering:
        aC = filterpitch(aC, srate=srate)
    return aC

def harmcomp_create4(durA=10, srate=None, f0A=None, f0B=None, harms=None, phase_type=None, con_type=None, burstA=150, burstB=150, rate=4, rfms=5, filtering=False):
    """Build two pulsed harmonic complexes whose burst lengths may differ.

    Like ``harmcomp_create3``, but part A uses ``burstA`` ms bursts and part B
    uses ``burstB`` ms bursts at the same ``rate``. Each part lasts
    ``durA`` cycles. Bursts have linear ``rfms`` ms onset/offset ramps.

    Parameters
    ----------
    durA : int
        Number of burst cycles per part (used as an ``np.tile`` repeat count).
    srate : int, optional
        Sampling rate in Hz (default 22050).
    f0A, f0B : float
        Fundamentals of the two parts in Hz.
    harms, phase_type, con_type
        Passed unchanged to ``harmcomp_create``.
    burstA, burstB : float
        Burst durations in ms for parts A and B.
    rate : float
        Bursts per second.
    rfms : float
        Ramp duration in ms.
    filtering : bool
        If True, band-pass the concatenated result with ``filterpitch``
        (default False, which matches the previous behaviour).

    Returns
    -------
    numpy.ndarray
        Part A followed by part B.

    Raises
    ------
    ValueError
        If a burst length times ``rate`` exceeds 1000 ms.
    """
    if srate is None:
        srate = 22050
    nramp = int((float(rfms) / 1000) * srate)
    ramp = np.arange(nramp) / float(max(nramp, 1))

    def silence_ms(burst):
        # Silence per cycle in ms (float division, also under Python 2).
        sil = (1000 - burst * rate) / float(rate)
        if sil < 0:
            raise ValueError('harmcomp_create4: burst * rate must not exceed 1000 ms')
        return sil

    def burst_mask(burst):
        # One cycle: ramp up, hold, ramp down, silence; tiled durA times.
        # The extra 2 silent samples come from the original code, presumably
        # to keep the mask at least as long as the signal.
        on_sp = int((float(burst - 2 * rfms) / 1000) * srate)
        off_sp = int((float(silence_ms(burst)) / 1000) * srate) + 2
        cycle = np.concatenate((ramp, np.ones(on_sp), ramp[::-1], np.zeros(off_sp)))
        return np.tile(cycle, durA)

    def gate(x, maskall):
        # Multiply by the mask, zero-padding it if it is shorter than x.
        m = maskall[:len(x)]
        if len(m) < len(x):
            m = np.concatenate((m, np.zeros(len(x) - len(m))))
        return x * m

    # Burst + silence is 1000 / rate ms for both parts, so part A's total
    # duration is used for part B as well.
    durAtotA = durA * (burstA + silence_ms(burstA))
    aA = harmcomp_create(dur=durAtotA / 1000., srate=srate, f0=f0A, harms=harms, phase_type=phase_type, con_type=con_type)
    aA = gate(aA, burst_mask(burstA))
    aB = harmcomp_create(dur=durAtotA / 1000., srate=srate, f0=f0B, harms=harms, phase_type=phase_type, con_type=con_type)
    aB = gate(aB, burst_mask(burstB))
    aC = np.concatenate((aA, aB))
    if filtering:
        aC = filterpitch(aC, srate=srate)
    return aC
#def test():
  #stimout = harmcomp_create(dur=1, srate=22050, f0=2000, harms=[1,2,3], phase_type='n', con_type='p')
  #psysound = createPsySound(stimout,play=False)  