import sys
import json
import base64
from io import BytesIO
from PIL import Image
import numpy as np
from transformers import TableTransformerForObjectDetection, DetrImageProcessor
import torch

# Lade das Microsoft Table Transformer Modell und den DetrImageProcessor
detection_model = TableTransformerForObjectDetection.from_pretrained("microsoft/table-transformer-detection")
structure_model = TableTransformerForObjectDetection.from_pretrained("microsoft/table-transformer-structure-recognition")
# processor = DetrImageProcessor.from_pretrained("facebook/detr-resnet-50")  # Verwende DetrImageProcessor
processor = DetrImageProcessor()

# Zusätzlicher Rand beim Croppen der Bilder zum Erkennen von Zeilen und Spalten. Ohne Rand fehlen sonst die äußeren Zeilen und Spalten.
detection_margin = 20;
threshold = float(sys.argv[1]) if len(sys.argv) > 1 else 0.9

# Funktion zum Zuschneiden eines Tabellenbereichs aus dem Bild
def crop_table(image, box):
    x_min, y_min, x_max, y_max = map(int, box)
    return image.crop((x_min - detection_margin, y_min - detection_margin, x_max + detection_margin, y_max + detection_margin))

def process_image(image_bytes, pageIndex, first, output_base_path):
    """Verarbeitet ein Bild und erkennt Tabellen."""
    try:
        # Konvertiere Byte-Array zu PIL Image
        image = Image.open(BytesIO(image_bytes)).convert("RGB")
        # image.save("D:\\python\\XYZ\\tatr" + str(pageIndex) + "_" + str(first) + ".png", "PNG")
        # Vorbereitung des Bildes für das Modell mit DetrImageProcessor
        inputs = processor(images=image, return_tensors="pt")
        
        # Initialisiere das Ergebnis-Objekt
        tables = []

        # Schritt 1: Erkenne Tabellen im Bild
        #inputs = processor(images=image, return_tensors="pt")
        with torch.no_grad():
            detection_outputs = detection_model(**inputs)

        # Post-Processing der Tabellenerkennung
        target_sizes = torch.tensor([image.size[::-1]])
        detection_results = processor.post_process_object_detection(detection_outputs, target_sizes=target_sizes, threshold=threshold)[0]

        if output_base_path:
            from torchvision.utils import draw_bounding_boxes
            import torchvision.transforms as T
            import os

            # Erstelle den Dateinamen für die Seite
            base, ext = os.path.splitext(output_base_path)
            output_path = f"{base}_page_{pageIndex}_{first}{ext}"

            image_tensor = T.ToTensor()(image) * 255
            image_tensor = image_tensor.to(torch.uint8)
            boxes = torch.tensor(detection_results["boxes"], dtype=torch.float)
            labels = [f"{detection_model.config.id2label[label.item()]}: {score.item():.2f}" for label, score in zip(detection_results["labels"], detection_results["scores"])]

            annotated_image = draw_bounding_boxes(
                image_tensor, boxes, labels=labels, colors="red", width=2
            )
            T.ToPILImage()(annotated_image).save(output_path)
            # print(f"Seite {pageIndex}: Annotiertes Bild gespeichert unter: {output_path}")

        # Schritt 2: Verarbeite jede Tabelle und erkenne Zellen
        for table_idx, (table_score, table_label, table_box) in enumerate(zip( detection_results["scores"], detection_results["labels"], detection_results["boxes"])):
            # Erstelle Tabellen-Eintrag
            table_entry = {
                "score": float(table_score),  # Konvertiere in Python float
                "label": detection_model.config.id2label[table_label.item()],
                # Hinweis: die Werte müssen keine nativen Zahlen sein, es könnte z.B. auch ein torch.tensor sein - daher die Konvertierung
                # Verfügbare Methoden gibt es hier: https://pytorch.org/docs/stable/tensors.html
                "x": int(table_box[0]),
                "y": int(table_box[1]),
                "width": int(table_box[2]-table_box[0]),
                "height": int(table_box[3]-table_box[1]),
                "cells": []
            }

            # Schneide die Tabelle aus dem Originalbild
            cropped_table = crop_table(image, table_box)

            # Vorverarbeitung des zugeschnittenen Tabellenbildes
            structure_inputs = processor(images=cropped_table, return_tensors="pt")

            # Führe die Strukturerkennung durch
            with torch.no_grad():
                structure_outputs = structure_model(**structure_inputs)

            # Post-Processing der Strukturerkennung
            structure_target_sizes = torch.tensor([cropped_table.size[::-1]])
            structure_results = processor.post_process_object_detection(structure_outputs, target_sizes=structure_target_sizes, threshold=threshold * 0.75)[0]

            # Korrigiere die Bounding Boxes der Zellen, um sie relativ zum Originalbild zu machen
            x_offset, y_offset = map(int, table_box[:2])
            adjusted_boxes = structure_results["boxes"].clone()
            adjusted_boxes[:, [0, 2]] += x_offset - detection_margin  # Korrigiere x-Koordinaten
            adjusted_boxes[:, [1, 3]] += y_offset - detection_margin # Korrigiere y-Koordinaten

            # Füge Zellen zum Tabellen-Eintrag hinzu
            for cell_score, cell_label, cell_box in zip(
                structure_results["scores"], structure_results["labels"], adjusted_boxes
            ):
                cell_entry = {
                    "score": float(cell_score),  # Konvertiere in Python float
                    "label": structure_model.config.id2label[cell_label.item()],
                    "x": int(cell_box[0]),
                    "y": int(cell_box[1]),
                    "width": int(cell_box[2]-cell_box[0]),
                    "height": int(cell_box[3]-cell_box[1])
                }
                table_entry["cells"].append(cell_entry)

            # Füge die Tabelle zum Ergebnis hinzu
            tables.append(table_entry)
                
            
        return tables
    except Exception as e:
        return {"error": str(e)}

def main():
    """Liest JSON-Objekte von stdin und verarbeitet Bilder."""
    buffer = ""
    
    while True:
        try:
            # Lies Zeilen von stdin
            line = sys.stdin.readline()
            if not line:  # EOF
                break
            
            buffer += line.strip()
            
            # Versuche, JSON-Objekte zu parsen
            while buffer:
                try:
                    obj, idx = json.JSONDecoder().raw_decode(buffer)
                    buffer = buffer[idx:].strip()
                    
                    # Verarbeite das JSON-Objekt
                    if "imageData" in obj:
                        # Dekodiere Base64-Bild
                        image_bytes = base64.b64decode(obj["imageData"])
                        pageIndex = obj["pageIndex"]
                        first = obj["first"]
                        id = obj["id"]
                        # Erkenne Tabellen
                        tables = process_image(image_bytes, pageIndex, first, "D:\\python\\XYZ\\tatr.png")
                        # Gib Ergebnis als JSON aus
                        result = {"tables": tables, "pageIndex": pageIndex, "first": first, "id": id}
                        print(json.dumps(result), flush=True)
                    else:
                        print(json.dumps({"error": "No image field in JSON"}), flush=True)
                        
                except json.JSONDecodeError:
                    # Incomplete JSON, wait for more input
                    break
                
        except Exception as e:
            print(json.dumps({"error": str(e)}), flush=True)
            break

if __name__ == "__main__":
    main()