import sys
import klog
import emtools

# Register a custom "STATE" log level with klog so that
# AVREmulator.log_state() can emit its CPU dump at that level.
# NOTE(review): the meaning of 0.1 (priority / verbosity threshold?) is
# defined by klog.Level -- confirm against the klog module.
klog.LevelManager.add(klog.Level("STATE", 0.1))

# Binary size multipliers, in bytes.
KB = 2 ** 10
MB = 2 ** 20
GB = 2 ** 30

# Bit indices into the status register (SREG), from bit 0 upward.
# The order matches the flag letters "cznvshti" rendered by
# AVREmulator.get_flag_str().
FC, FZ, FN, FV, FS, FH, FT, FI = range(8)

def format_size(n):
  """Return a human-readable string for the byte count ``n``.

  ``n`` is divided by 1024 until it is below 1024 or the largest known
  unit ("Gb") is reached. The result is printed as an integer when it is
  whole and with one decimal place otherwise.

  >>> format_size(512)
  '512 bytes'
  >>> format_size(1536)
  '1.5 Kb'
  """
  mods = ["bytes", "Kb", "Mb", "Gb"]

  modnum = 0
  # Stop at the last unit: without the bound, sizes >= 1024 Gb would push
  # modnum past the end of ``mods`` and raise IndexError.
  while n >= 1024 and modnum < len(mods) - 1:
    modnum += 1
    n /= 1024.0

  if int(n) == n:
    return "%d %s" % (n, mods[modnum])
  return "%.1f %s" % (n, mods[modnum])

class Platform:
  """Catalogue of supported AVR targets.

  Each entry is a tuple ``(name, flashsize, sramsize, eepromsize)``; the
  three sizes are byte counts (they are logged via ``format_size``).
  """

  ATMEGA168 = ("atmega168", KB * 16, KB, 0x200)

class AVREmulator(object):
  """Minimal instruction-level emulator for an AVR microcontroller.

  Only LDI (load immediate) and ADD (add without carry) are decoded so far.
  Any other opcode is fetched and skipped without effect.
  """

  def __init__(self, platform=Platform.ATMEGA168):
    """Build the CPU state and memories for ``platform`` and log its sizes.

    Args:
      platform: a ``(name, flashsize, sramsize, eepromsize)`` tuple such as
        ``Platform.ATMEGA168``.
    """
    self.platform = platform

    self.running = False
    # Byte offset one past the last loaded program byte; -1 until load().
    self.program_end = -1

    self.logger = klog.Logger("AVRem")
    self.logger.log("main", "INFO", "Platform: " + platform[0])
    self.logger.log("main", "INFO", "  Flash: " + format_size(platform[1]))
    self.logger.log("main", "INFO", "  SRAM: " + format_size(platform[2]))
    self.logger.log("main", "INFO", "  EEPROM: " + format_size(platform[3]))

    self.sreg = emtools.Register(8)
    self.pc = emtools.Register(16)

    self.program_memory = emtools.ROMemory(self.platform[1])

    # Data space: r0-r31 at 0x00, I/O ports from 0x20, then SRAM from 0x60.
    self.data_memory = emtools.MemoryMap()
    self.registers = self.data_memory.map(0x0000, 0x0020, True)
    self.ioports = self.data_memory.map(0x0020, 0x0060, True)
    # NOTE(review): the two calls above read like map(start, end, ...), but
    # this one passes the SRAM *size* as the second argument. Confirm the
    # emtools.MemoryMap.map signature. If it takes an end address, this
    # should probably be 0x0060 + self.platform[2].
    self.sram = self.data_memory.map(0x0060, self.platform[2], True)

  def load(self, program):
    """Replace program memory with ``program`` and rewind the PC to 0.

    Args:
      program: the program image accepted by ``ROMemory.load``. Its
        ``len()`` is used as the program size, in the same byte units the
        PC advances in.
    """
    self.program_memory.clear()
    self.program_memory.load(program)
    self.program_end = len(program)
    # Execute the new program from its start. Without this, a second
    # load() would resume at the PC left over from the previous run.
    self.pc.set(0)

    self.logger.log("main", "INFO", "Loaded program (%s)" % format_size(self.program_end))

  def get_flag_str(self):
    """Return the SREG flags as "cznvshti", upper-casing each flag that is set."""
    return "".join(c.upper() if self.sreg.getbit(i) else c
                   for i, c in enumerate("cznvshti"))

  def log_state(self):
    """Log the SREG flags and r0-r31 (eight per line) at the STATE level."""
    def regs(first):
      # One row of eight registers: "rNN=XXh, ...".
      return ", ".join(["r%02d=%02Xh" % (i, self.registers.get(i))
                        for i in range(first, first + 8)])

    lines = [
      "Current CPU state:",
      "  Flags: %s" % self.get_flag_str(),
      "  Registers: %s" % regs(0),
      "             %s" % regs(8),
      "             %s" % regs(16),
      "             %s" % regs(24),
    ]

    self.logger.log("main", "STATE", "\n".join(lines))

  def _halt(self):
    """Stop the run() loop after the current instruction."""
    self.running = False

  def get_next_instruction(self):
    """Fetch the word at PC, advance PC by 2 and return the word.

    PC is a byte offset (it advances by 2 per fetched word). ``running`` is
    cleared once PC reaches ``program_end``, so run() stops after executing
    the instruction just fetched.
    """
    pc = self.pc.get()
    inst = self.program_memory.get_word(pc)
    self.logger.log("main", "DEBUG", "[%04Xh] Processing %04Xh" % (pc, inst))

    pc += 2
    self.pc.set(pc)
    if pc >= self.program_end:
      self.running = False

    return inst

  def run(self):
    """Execute instructions from the current PC until the program ends."""
    # Fetch nothing when no instruction is left to run (before load(),
    # for an empty program, or when PC is already past the end).
    self.running = self.pc.get() < self.program_end
    while self.running:
      self.run_one()

  def run_one(self):
    """Fetch and execute a single instruction."""
    inst = self.get_next_instruction()

    if (inst & 0xF000) == 0xE000:
      # LDI Rd, K -- encoding 1110 KKKK dddd KKKK, with d selecting r16..r31.
      # The mask must cover the full top nibble: 0xE000 alone would also
      # match the 0xFxxx opcodes and misdecode them as LDI.
      k = ((inst & 0x0F00) >> 4) | (inst & 0x000F)
      d = ((inst & 0x00F0) >> 4) + 16
      self.registers.set(d, k)

    elif (inst & 0xFC00) == 0x0C00:
      # ADD Rd, Rr -- encoding 0000 11rd dddd rrrr.
      rnum = ((inst & 0x0200) >> 5) + (inst & 0x000F)
      dnum = (inst & 0x01F0) >> 4
      r = self.registers.get(rnum)
      d = self.registers.get(dnum)
      x = (r + d) & 0xFF
      self.registers.set(dnum, x)

      # Bit n of ``carry`` is the carry out of bit n of the addition.
      # C is taken from bit 7 and H from bit 3.
      carry = (d & r) | (r & ~x) | (~x & d)
      self.sreg.setbit(FC, carry >> 7)
      self.sreg.setbit(FZ, x == 0)
      # Pass 0/1 like the other flags (previously passed 0 or 0x80).
      self.sreg.setbit(FN, (x >> 7) & 1)
      self.sreg.setbit(FV, ((d & r & ~x) | (~d & ~r & x)) >> 7)
      self.sreg.setbit(FS, self.sreg.getbit(FN) ^ self.sreg.getbit(FV))
      self.sreg.setbit(FH, (carry >> 3) & 1)

if __name__ == "__main__":
  # Project-local helper, only needed when this file is run as a script.
  import discmd

  # Build the program image from the command-line arguments (the argument
  # format is defined by discmd.get_text), then run it and dump the CPU state.
  image = discmd.get_text(sys.argv[1:])

  emulator = AVREmulator()
  emulator.load(image)
  emulator.run()
  emulator.log_state()
