try:
  import unittest2 as unittest
except ImportError:
  import unittest

from mock import patch
from mock import MagicMock as Mock

import stat

import arklib.ark_utils as ark_utils

class ark_utils_test(unittest.TestCase):
  """Unit tests for the helper functions exposed by arklib.ark_utils.

  Every collaborator that would touch S3, the filesystem or a subprocess is
  replaced with a mock, so the tests only check how ark_utils calls them.
  """

  @patch('arklib.ark_utils.ark_s3')
  def test_download_from_s3(self, m_s3):
    """download_from_s3 builds an ark_s3 client and downloads bucket/key."""
    # Setup: credentials object exposing the two attributes asserted below.
    _cred = Mock()
    _cred.aws_access_key = 'access'
    _cred.aws_secret_key = 'secret'

    # Calling the patched ark_s3 yields this mock client.
    s3 = Mock()
    m_s3.return_value = s3

    s3.trim_s3_scheme.return_value = 'trimmed_url'
    s3.parse_bucket_key.return_value = ('bucket', 'key')

    # Run
    ark_utils.download_from_s3('url', 'local_file', _cred)

    # Asserts
    m_s3.assert_called_with('access', 'secret')
    s3.trim_s3_scheme.assert_called_with('url')
    s3.parse_bucket_key.assert_called_with('trimmed_url')
    s3.download_file.assert_called_with('bucket', 'key', 'local_file')

  # NOTE(review): '__builtin__' is the Python 2 name of the builtins module;
  # under Python 3 the patch target would be 'builtins.open' -- confirm which
  # interpreter this suite is expected to run on.
  @patch('__builtin__.open')
  @patch('os.stat')
  @patch('os.chmod')
  def test_write_script(self, m_chmod, m_stat, m_open):
    """write_script writes the content, closes the file and adds S_IEXEC."""
    # Setup: open() hands back a mock file, os.stat() a mock with a known mode.
    # (MagicMock creates .write/.close on first access, so no priming needed.)
    m_file = Mock()
    m_open.return_value = m_file
    m_st = Mock()
    m_stat.return_value = m_st
    m_st.st_mode = 10

    # Run
    ark_utils.write_script("script_path", "content")

    # Asserts
    m_open.assert_called_with("script_path", 'w')
    m_file.write.assert_called_with("content")
    m_file.close.assert_called_with()
    m_stat.assert_called_with("script_path")
    # The new mode must be the stat'ed mode with the owner-execute bit OR-ed in.
    m_chmod.assert_called_with("script_path", m_st.st_mode | stat.S_IEXEC)

  @patch('subprocess.call')
  def test_exec_script(self, sp_call):
    """exec_script forwards the argument list to subprocess.call, shell=False."""
    ark_utils.exec_script(["script_path"])
    sp_call.assert_called_with(["script_path"], shell=False)

  @patch('os.remove')
  def test_delete_script(self, os_remove):
    """delete_script removes the given path through os.remove."""
    ark_utils.delete_script("script_path")
    os_remove.assert_called_with("script_path")

  def test_random_word(self):
    """random_word returns a value whose length equals the requested length."""
    rw = ark_utils.random_word(10)
    self.assertEqual(len(rw), 10)
    # Boundary: a zero-length request yields an empty result.
    rw = ark_utils.random_word(0)
    self.assertEqual(len(rw), 0)

  def test_guess_region_from_az(self):
    """guess_region_from_az maps 'us-east-1a' to the region 'us-east-1'."""
    az = 'us-east-1a'
    region = ark_utils.guess_region_from_az(az)
    self.assertEqual(region, 'us-east-1')

