from chrysacase import ChrysalideTestCase
from pychrysalide import SourceEndian
from pychrysalide.analysis import BinContent
from pychrysalide.arch import vmpa


class CustomContent(BinContent):
    """Synthetic binary content whose hooks are implemented in Python.

    No real data backs this content: every hook returns a value derived
    from the constructor argument, from the requested position, or from a
    fixed constant.
    """

    def __init__(self, size):
        """Store a fixed start offset (10) and `size` multiplied by 8."""
        super(CustomContent, self).__init__()
        self._size = size * 8
        self._start = 10

    def _describe(self, full):
        """Return 'my_desc', with '_full' appended when `full` is truthy."""
        suffix = '_full' if full else ''
        return 'my_desc' + suffix

    def _compute_checksum(self, checksum):
        """Feed the constant bytes b'xxxxx' to the given `checksum` object."""
        payload = b'xxxxx'
        checksum.update(payload)

    def _compute_size(self):
        """Return the stored size divided by 8, as an int."""
        return int(self._size / 8)

    def _compute_start_pos(self):
        """Build a vmpa from the start offset and the NO_VIRTUAL marker."""
        no_virtual = vmpa.VmpaSpecialValue.NO_VIRTUAL
        return vmpa(self._start, no_virtual)

    def _compute_end_pos(self):
        """Build a vmpa from start offset + stored size and NO_VIRTUAL.

        The stored size is the scaled one (constructor argument times 8),
        not the value returned by `_compute_size`.
        """
        no_virtual = vmpa.VmpaSpecialValue.NO_VIRTUAL
        return vmpa(self._start + self._size, no_virtual)

    def _seek(self, addr, length):
        """Apply `addr += length` and unconditionally return True."""
        # NOTE(review): whether the caller observes the move depends on vmpa
        # supporting in-place addition -- confirm against the bindings.
        addr += length
        return True

    def _read_uxxx(self, addr, sizeof):
        """Shared body of the fixed-width readers.

        The returned value is the physical distance between `addr` and the
        start position, divided by `sizeof` and truncated to an int.
        """

        # The requested position has to lie in [start_pos, end_pos).
        assert addr >= self.start_pos and addr < self.end_pos

        distance = addr - self.start_pos
        value = int(distance.phys / sizeof)

        # NOTE(review): presumably meant to move the caller's position past
        # the bytes just "read"; relies on vmpa in-place addition -- confirm.
        addr += sizeof

        return value

    def _read_u8(self, addr):
        """Delegate to `_read_uxxx` with a width of 1."""
        return self._read_uxxx(addr, sizeof=1)

    def _read_u16(self, addr, endian):
        """Delegate to `_read_uxxx` with a width of 2; `endian` is unused."""
        return self._read_uxxx(addr, sizeof=2)

    def _read_u32(self, addr, endian):
        """Delegate to `_read_uxxx` with a width of 4; `endian` is unused."""
        return self._read_uxxx(addr, sizeof=4)

    def _read_u64(self, addr, endian):
        """Delegate to `_read_uxxx` with a width of 8; `endian` is unused."""
        return self._read_uxxx(addr, sizeof=8)

    def _read_uleb128(self, addr):
        """Return the constant 128, whatever `addr` is."""
        return 128

    def _read_leb128(self, addr):
        """Return the constant -128, whatever `addr` is."""
        return -128


class TestCustomContent(ChrysalideTestCase):
    """Exercise an analysis.BinContent subclass implemented in Python."""

    def testBasicImplementations(self):
        """Check description, checksum, size and seek of a custom content."""

        content = CustomContent(1)

        short_desc = content._describe(False)
        self.assertEqual(short_desc, 'my_desc')

        full_desc = content._describe(True)
        self.assertEqual(full_desc, 'my_desc_full')

        # Reference digest of the bytes fed by the custom checksum hook:
        # $ echo -n 'xxxxx' | sha256sum
        # eaf16bc07968e013f3f94ab1342472434a39fc3475f11cf341a6c3965974f8e9  -

        expected = 'eaf16bc07968e013f3f94ab1342472434a39fc3475f11cf341a6c3965974f8e9'

        self.assertEqual(content.checksum, expected)

        self.assertEqual(content.size, 1)

        # Seeking is expected to move the provided position in place.
        cursor = content.start_pos
        shift = 13

        content.seek(cursor, shift)

        self.assertEqual(cursor.phys, content.start_pos.phys + shift)


    def testReadImplementations(self):
        """Walk the whole content with each fixed-width reader.

        For every reader, the first value read must be 0 and each following
        value must be the previous one plus 1.
        """

        content = CustomContent(8)

        # Each entry pairs a reader with the extra arguments it needs.
        readers = (
            (content.read_u8, ()),
            (content.read_u16, (SourceEndian.LITTLE,)),
            (content.read_u32, (SourceEndian.LITTLE,)),
            (content.read_u64, (SourceEndian.LITTLE,)),
        )

        for reader, extra in readers:

            previous = None

            cursor = content.start_pos

            # The loop terminates only if each read advances `cursor`.
            while cursor < content.end_pos:

                value = reader(cursor, *extra)

                if previous is None:
                    self.assertEqual(value, 0)
                else:
                    self.assertEqual(previous + 1, value)

                previous = value


    def testLEB128Implementations(self):
        """Check the constants returned by the [U]LEB128 readers."""

        content = CustomContent(1)

        position = content.start_pos

        unsigned = content.read_uleb128(position)
        self.assertEqual(unsigned, 128)

        signed = content.read_leb128(position)
        self.assertEqual(signed, -128)