#!/usr/bin/env python

from . import index
from . import ops


class TextSearch(object):
    """Keyword-in-context search over a block of text.

    The text is tokenized and indexed once, via ``index.index_text``, when the
    object is built; ``search`` then answers queries from that index.
    """

    # ``unicode`` only exists on Python 2; on Python 3 fall back to ``str``
    # alone so the argument check in ``search`` does not raise NameError.
    try:
        _STRING_TYPES = (str, unicode)
    except NameError:
        _STRING_TYPES = (str,)

    def __init__(self, text):
        # ``self.words`` is the token list; ``self.index`` maps a word to the
        # positions (indices into ``self.words``) at which it occurs.
        self.words, self.index = index.index_text(text)

    def search(self, word, context, asstring=True):
        """Return every occurrence of *word* together with its surrounding words.

        :param word: the word to look up (``str``, or ``unicode`` on Python 2).
        :param context: number of words to include on each side of a match;
            must be a non-negative ``int``.
        :param asstring: currently ignored; kept for backward compatibility.
        :returns: a list with one entry per occurrence, each entry being the
            list of words in the window around that occurrence.  The first
            and last words of each window are passed through
            ``ops.remove_leading_punctuation`` and
            ``ops.remove_trailing_punctuation`` respectively.  An empty list
            is returned when *word* is not in the index.
        :raises AssertionError: if *word* is not a string or *context* is not
            a non-negative ``int``.
        """
        assert isinstance(word, self._STRING_TYPES)
        assert isinstance(context, int)
        assert context >= 0
        if word not in self.index:
            return []
        results = []
        for i in self.index[word]:
            # Clamp the start at 0: a negative slice start would wrap around to
            # the end of the list and produce a wrong (or empty) window for
            # matches near the beginning of the text.  With context == 0 this
            # yields a one-word window, so no special case is needed.
            window = self.words[max(0, i - context):i + context + 1]
            window[0] = ops.remove_leading_punctuation(window[0])
            window[-1] = ops.remove_trailing_punctuation(window[-1])
            results.append(window)
        return results


def test():
    """Smoke-test TextSearch on a short sentence, including its argument checks."""
    ts = TextSearch(".foo bar baz. bam")

    # One word of context on each side; the window's outer punctuation is
    # expected to be stripped ('.foo' -> 'foo', 'baz.' -> 'baz').
    rs = ts.search('bar', 1)
    assert len(rs) == 1
    assert ' '.join(rs[0]) == 'foo bar baz'

    rs = ts.search('bar', 0)
    assert len(rs) == 1
    assert ' '.join(rs[0]) == 'bar'

    # Invalid arguments must be rejected.  The ``else`` clause makes the test
    # fail if search() accepts them, instead of silently passing.
    for bad_args in (('bar', '1'), (1, 2), ('bar', -1)):
        try:
            ts.search(*bad_args)
        except AssertionError:
            pass
        else:
            raise AssertionError(
                'search%r should have raised AssertionError' % (bad_args,))
