#!/usr/bin/env python3
"""
Phase 3: Translate book from English to Russian using Ollama.
Resumable, chunk-based translation with glossary injection.
v2 — improved validation, math-block preservation, fallback handling.
"""

import json
import re
import time
import sys
import requests
import yaml
from pathlib import Path

# === CONFIGURATION ===
# Script-wide settings. main() reads the file names and chunking values and
# passes the whole dict on to translate_chunk() / call_ollama().
CONFIG = {
    "input_file": "book_cleaned2.md",          # English Markdown source read by main()
    "glossary_file": "glossary.yaml",          # YAML glossary flattened by load_glossary()
    "state_file": "translation_state.jsonl",   # append-only JSONL log, one entry per chunk attempt outcome
    "output_file": "book_translated_ru.md",    # final file written by assemble_translation()
    "ollama_url": "http://127.0.0.1:11434",    # base URL; "/api/generate" and "/api/tags" are appended
    "model": "gemma4:26b",                     # model name sent in the request payload
    "chunk_size_target": 1500,                 # preferred chunk size, in whitespace-separated words
    "chunk_size_max": 2000,                    # word count above which a chunk is closed before adding more
    "context_paragraphs": 3,                   # trailing paragraphs of the previous translation shown as context
    "max_retries": 3,                          # retries after the first try (total attempts = max_retries + 1)
    "temperature": 0.3,                        # sampling temperature passed in the request "options"
    "request_timeout": 9000,                   # passed as `timeout` to requests.post (logged as seconds)
    "fallback_to_source": False,               # if True, a totally failed chunk stores the source text with a marker
}

# System prompt sent with every translation request: translate_chunk() hands it
# to call_ollama(), which places it in the payload's "system" field.
# This is a regular (non-raw) string, so the doubled backslashes below
# (e.g. \\sqrt) reach the model as single-backslash LaTeX commands.
# NOTE(review): rule 9 mentions an "existing Russian translation" in the source,
# although the source book is English — confirm the intended wording.
SYSTEM_PROMPT = """You are a professional literary translator from English to Russian.
You are translating "The Man Who Loved Only Numbers" by Paul Hoffman — a biography of mathematician Paul Erdős, intended for a general audience.

RULES:
1. Translate into natural, fluent literary Russian. Match the tone and style of the original — it is lively, witty, and accessible.
2. Preserve all paragraph breaks. Each blank-line-separated paragraph in the source must produce exactly one paragraph in your translation, in the same order. Do NOT merge paragraphs even if the source seems to break mid-sentence (this is a deliberate artifact of the source layout — keep the break).
3. Do NOT translate or alter mathematical content. This includes:
   - Anything between $...$ or $$...$$ delimiters — copy character-for-character.
   - LaTeX commands like \\sqrt, \\pi, \\aleph, \\begin{array}, \\end{array}, \\frac, \\sum, etc.
   - Inline numeric expressions like "714 × 715 = 2 × 3 × 5 × 7 × 11 × 13 × 17", "2^67 - 1", "1/171".
4. Do NOT modify Markdown image references. Copy them VERBATIM, including the exact filename in the parentheses. For example, `![](_page_53_Picture_2.jpeg)` must remain `![](_page_53_Picture_2.jpeg)` — never replace it with `image.jpg` or any other placeholder.
5. Preserve all other Markdown formatting: headings (#, ##), bold (**), italic (*), tables, lists, blockquotes.
6. Use the provided glossary for all names, places, and terms. If a term is in the glossary, you MUST use the specified Russian translation.
7. Do not add, remove, summarize, or paraphrase content. Translate everything faithfully.
8. Do not add translator's notes, commentary, or bracketed annotations like [P1], [int], [Pло] — these are forbidden.
9. Keep Latin phrases (like "Non numerantur, sed ponderantur") in Latin, followed by the existing Russian translation in parentheses if present in the source.
10. Output ONLY the Russian translation. No preamble, no explanation, no "Here is the translation:" prefix."""


def load_glossary(path):
    """Load the glossary YAML and flatten it into prompt-ready lines.

    The file is expected to map category names to ``{english: russian}``
    mappings. Every pair becomes one ``"  english → russian"`` line;
    categories whose value is not a mapping are skipped.

    Args:
        path: Path to the glossary YAML file.

    Returns:
        The glossary lines joined with newlines. An empty string is returned
        when the file is empty or its top level is not a mapping.
    """
    with open(path, 'r', encoding='utf-8') as f:
        data = yaml.safe_load(f)

    # safe_load yields None for an empty document and may yield a list or a
    # scalar for an unexpected layout; none of those have .items().
    if not isinstance(data, dict):
        if data is not None:
            print(f"  WARNING: Glossary top level is {type(data).__name__}, "
                  f"expected a mapping — ignoring it")
        return ''

    return '\n'.join(
        f"  {eng} → {rus}"
        for entries in data.values()
        if isinstance(entries, dict)
        for eng, rus in entries.items()
    )


def split_into_chunks(text, target_words=1500, max_words=2000):
    """
    Group blank-line-separated paragraphs into translation chunks.

    Paragraphs are never split. The chunk being built is closed before a
    paragraph is added when that paragraph starts with '#', when adding it
    would exceed max_words, or when adding it would exceed target_words and
    the chunk already holds more than half of target_words. Paragraphs inside
    a chunk are re-joined with one blank line.
    """
    finished = []
    pending = []          # paragraphs of the chunk under construction
    pending_words = 0

    for block in re.split(r'\n\n+', text):
        block = block.strip()
        if not block:
            continue

        n_words = len(block.split())
        projected = pending_words + n_words

        # A flush only makes sense when there is something to flush.
        must_flush = bool(pending) and (
            block.startswith('#')
            or projected > max_words
            or (projected > target_words and pending_words > target_words // 2)
        )
        if must_flush:
            finished.append('\n\n'.join(pending))
            pending = []
            pending_words = 0

        pending.append(block)
        pending_words += n_words

    if pending:
        finished.append('\n\n'.join(pending))

    return finished


def load_state(path):
    """Load translation state from a JSONL file (last entry per chunk_id wins).

    Blank lines are ignored. Lines that are not valid JSON, or that lack a
    usable 'chunk_id', are reported and skipped so that one damaged line
    (for example a write cut short by an interruption) does not make the
    whole run un-resumable.

    Args:
        path: Path to the JSONL state file; it may not exist yet.

    Returns:
        Dict mapping chunk_id to the most recent entry for that chunk.
    """
    state = {}
    if not Path(path).exists():
        return state

    with open(path, 'r', encoding='utf-8') as f:
        for line_no, line in enumerate(f, start=1):
            line = line.strip()
            if not line:
                continue
            try:
                entry = json.loads(line)
                state[entry['chunk_id']] = entry
            except (json.JSONDecodeError, KeyError, TypeError) as e:
                # TypeError covers non-object JSON values and unhashable ids.
                print(f"  WARNING: Skipping unreadable state line {line_no}: {e!r}")
    return state


def save_state_entry(path, entry):
    """Serialize *entry* as one JSON line and append it to the state file."""
    with open(path, mode='a', encoding='utf-8') as sink:
        # ensure_ascii=False stores non-ASCII text as-is instead of \u escapes.
        serialized = json.dumps(entry, ensure_ascii=False)
        sink.write(f"{serialized}\n")


def get_context(state, chunk_index, n_paragraphs=3):
    """Return the last n paragraphs of the previous chunk's translation.

    Args:
        state: Dict of chunk_id -> state entry.
        chunk_index: Index of the chunk about to be translated.
        n_paragraphs: How many trailing paragraphs to return; zero or a
            negative value disables context.

    Returns:
        The selected paragraphs joined by blank lines, or "" when there is
        no previous chunk, it has no stored translation, it only holds the
        untranslated source fallback, or context is disabled.
    """
    # Guard n <= 0 explicitly: a slice of [-0:] would return *every* paragraph.
    if chunk_index == 0 or n_paragraphs <= 0:
        return ""

    prev = state.get(f"chunk_{chunk_index - 1:04d}")
    if not prev:
        return ""

    # A fallback entry stores the source text, not a translation, so it must
    # not be presented to the model as previously translated text.
    if prev.get('status') == 'failed_fallback':
        return ""

    prev_text = prev.get('translated_text')
    if not prev_text:
        return ""

    return '\n\n'.join(prev_text.split('\n\n')[-n_paragraphs:])


def call_ollama(prompt, system_prompt, config):
    """Send one non-streaming generate request to Ollama.

    Args:
        prompt: User prompt text.
        system_prompt: System prompt text.
        config: Settings dict. Reads 'ollama_url', 'model', 'temperature',
            'request_timeout' and, optionally, 'num_ctx' (default 32768).

    Returns:
        The stripped text of the reply's "response" field ("" if the field
        is absent or null), or None when the request failed, timed out, or
        returned a body that is not a JSON object.
    """
    url = f"{config['ollama_url']}/api/generate"

    payload = {
        "model": config["model"],
        "prompt": prompt,
        "system": system_prompt,
        "stream": False,
        "options": {
            "temperature": config["temperature"],
            "num_ctx": config.get("num_ctx", 32768),
        }
    }

    try:
        response = requests.post(url, json=payload, timeout=config["request_timeout"])
        response.raise_for_status()
        result = response.json()
    except requests.exceptions.Timeout:
        print(f"    ERROR: Request timed out ({config['request_timeout']}s)")
        return None
    except requests.exceptions.RequestException as e:
        print(f"    ERROR: {e}")
        return None
    except ValueError as e:
        # Older requests releases raise a plain ValueError for a non-JSON body.
        print(f"    ERROR: Invalid JSON in Ollama response: {e}")
        return None

    if not isinstance(result, dict):
        print(f"    ERROR: Unexpected Ollama response type: {type(result).__name__}")
        return None

    # `or ""` also covers an explicit null in the "response" field.
    return (result.get("response") or "").strip()


def build_prompt(source_text, glossary_str, context_text):
    """Compose the user prompt: glossary, optional context, then the source.

    Sections are newline-joined; the glossary and context sections are each
    followed by one empty line. The context section is omitted entirely when
    context_text is empty.
    """
    sections = [
        "GLOSSARY (use these translations for all names and terms):",
        glossary_str,
        "",
    ]

    if context_text:
        sections += [
            "CONTEXT (previous translated text, for continuity — do NOT re-translate this):",
            context_text,
            "",
        ]

    sections += [
        "TRANSLATE THE FOLLOWING TEXT (preserve every blank-line paragraph break exactly):",
        source_text,
    ]

    return '\n'.join(sections)


# === GARBAGE-TOKEN DETECTION ===
# Patterns that indicate the model hallucinated a stray artifact like
# [Pло4], [P1зом], [intP5], [P\n\n8], "прошлоstring", "2ло 220".
# Each entry is a regex source string; detect_garbage_tokens() scans the
# translation with every pattern and reports each match.
# NOTE(review): the а-я / А-Я ranges used here do not include ё/Ё, unlike the
# Cyrillic-ratio check in validate_translation() — confirm whether artifacts
# containing ё should also be caught.
GARBAGE_PATTERNS = [
    r'\[P[а-яА-Я0-9\s]{1,8}\]',          # [Pло4], [P1зом]
    r'\[int[А-Яа-я0-9]{1,5}\]',          # [intP5]
    r'\[[A-Za-z]{1,4}\d+[а-яА-Я]+\]',    # mixed-script bracket tokens
    r'[а-яА-Я]{2,}[a-z]{4,}\b',          # "прошлоstring" — Cyrillic glued to English
    r'\b\d+ло\s+\d+',                    # "2ло 220" — digit-Cyrillic-digit garbage
]


def detect_garbage_tokens(text):
    """Collect every substring of *text* matched by a GARBAGE_PATTERNS regex.

    Matches are listed pattern by pattern, in order of appearance within
    each pattern; the same text may be reported more than once.
    """
    return [
        hit.group(0)
        for pattern in GARBAGE_PATTERNS
        for hit in re.finditer(pattern, text)
    ]

def strip_model_preamble(text):
    """Drop a leading 'Here is the translation:'-style line, then trim.

    Falsy input (None or "") is returned unchanged. The patterns are applied
    one after another, each anchored at the start of the current text and
    matched case-insensitively.
    """
    if not text:
        return text

    leading_boilerplate = (
        r'^(Here is|Here\'s)\s+the\s+(Russian\s+)?translation[:.]?\s*\n+',
        r'^(Вот|Вот\s+перевод)[:.]?\s*\n+',
        r'^Translation[:.]?\s*\n+',
        r'^Перевод[:.]?\s*\n+',
    )

    result = text
    for regex in leading_boilerplate:
        result = re.compile(regex, re.IGNORECASE).sub('', result)
    return result.strip()


def extract_math_blocks(text):
    """List the $$...$$ blocks found in *text*, followed by the $...$ spans.

    Display blocks may span lines. An inline span stays on one line, must not
    begin or end with whitespace inside the delimiters, and its delimiters
    must not be adjacent to another '$'.
    """
    display_re = re.compile(r'\$\$.+?\$\$', re.DOTALL)
    # The whitespace lookarounds keep most currency amounts such as "$5 and $10"
    # from being read as a formula.
    inline_re = re.compile(r'(?<!\$)\$(?!\s)[^\$\n]+?(?<!\s)\$(?!\$)')
    return [*display_re.findall(text), *inline_re.findall(text)]


def extract_image_refs(text):
    """List every Markdown image reference `![alt](target)` found in *text*, verbatim."""
    image_re = re.compile(r'!\[[^\]]*\]\([^)]+\)')
    return [hit.group(0) for hit in image_re.finditer(text)]


def validate_translation(source_text, translated_text):
    """
    Run heuristic quality checks on one translated chunk.

    Checks performed: paragraph count (a difference of up to 2 is tolerated),
    verbatim presence of every source math block and image reference,
    share of Cyrillic letters, hallucinated garbage tokens, and the
    translation/source length ratio.

    Severity only ever escalates: once a check reports 'fatal', a later
    'warn' finding is recorded in the issue list but cannot downgrade it.

    Returns (is_valid, issues_list, severity).
    severity: 'ok' | 'warn' | 'fatal'
    is_valid is True only when severity is 'ok'.
    """
    if not translated_text:
        return False, ["Empty translation"], 'fatal'

    issues = []
    severity = 'ok'
    rank = {'ok': 0, 'warn': 1, 'fatal': 2}

    def report(message, level='warn'):
        """Record an issue and raise the overall severity if *level* is worse."""
        nonlocal severity
        issues.append(message)
        if rank[level] > rank[severity]:
            severity = level

    # --- Paragraph count: WARN only, never fatal ---
    source_paras = [p.strip() for p in source_text.split('\n\n') if p.strip()]
    trans_paras = [p.strip() for p in translated_text.split('\n\n') if p.strip()]

    # Tolerate ±2 paragraphs (model often legitimately joins mid-sentence breaks)
    if abs(len(source_paras) - len(trans_paras)) > 2:
        report(f"Paragraph count mismatch: source={len(source_paras)}, translation={len(trans_paras)}")

    # --- Math blocks: must be preserved verbatim ---
    for formula in extract_math_blocks(source_text):
        if formula not in translated_text:
            report(f"Math block altered or missing: {formula[:60]}...")

    # --- Image references: must be preserved EXACTLY (with original filename) ---
    trans_images = extract_image_refs(translated_text)
    # Common failure mode: the model swaps the real filename for image.jpg etc.
    placeholder_used = any('image.jpg' in t or 'image.jpeg' in t or 'image.png' in t
                           for t in trans_images)
    for img in extract_image_refs(source_text):
        if img not in translated_text:
            if placeholder_used:
                report(f"Image filename replaced with placeholder (expected: {img})")
            else:
                report(f"Image reference missing: {img}")

    # --- Cyrillic ratio: must be mostly Russian ---
    cyrillic = len(re.findall(r'[а-яА-ЯёЁ]', translated_text))
    latin = len(re.findall(r'[a-zA-Z]', translated_text))
    if cyrillic == 0:
        report("No Cyrillic characters found — translation likely failed", 'fatal')
    elif latin > 0 and cyrillic / (cyrillic + latin) < 0.4:
        report(f"Low Cyrillic ratio: {cyrillic}/{cyrillic + latin}")

    # --- Garbage tokens: hallucinated artifacts like [Pло4], [intP5] ---
    garbage = detect_garbage_tokens(translated_text)
    if garbage:
        # Cap report to first 5 to keep logs sane
        sample = ', '.join(repr(g) for g in garbage[:5])
        report(f"Garbage tokens detected ({len(garbage)}): {sample}")

    # --- Length sanity: translation shouldn't be wildly shorter/longer ---
    src_len = len(source_text)
    if src_len > 200:
        ratio = len(translated_text) / src_len
        # Russian is typically 1.0–1.3x English in chars; flag <0.6 or >2.0
        if ratio < 0.6:
            report(f"Translation suspiciously short: ratio={ratio:.2f}")
        elif ratio > 2.0:
            report(f"Translation suspiciously long: ratio={ratio:.2f}")

    return severity == 'ok', issues, severity


def translate_chunk(chunk_index, source_text, glossary_str, context_text, config, state):
    """Translate a single chunk with retries and graceful fallback.

    Every outcome is appended to the state file and mirrored into *state*.

    Outcomes (the stored 'status'):
      - 'done': an attempt passed validation.
      - 'done_with_issues': no attempt was clean, but the best one had only
        warn-level issues; it is kept together with its issue list.
      - 'failed_fallback': nothing usable and config['fallback_to_source'] is
        true; the source text is stored behind an UNTRANSLATED marker.
      - 'failed': nothing usable; no text is stored.
    An attempt whose validation is 'fatal' (empty output, no Cyrillic) never
    counts as usable, so such a chunk is not treated as done on a later run.

    Args:
        chunk_index: Zero-based index of the chunk; determines its chunk_id.
        source_text: English text of the chunk.
        glossary_str: Flattened glossary for the prompt.
        context_text: Tail of the previous translation, or "".
        config: Settings dict ('max_retries', 'state_file', optional
            'fallback_to_source', plus whatever call_ollama reads).
        state: Dict of chunk_id -> entry; updated in place.

    Returns:
        The stored translated text, or None for a 'failed' chunk.
    """
    chunk_id = f"chunk_{chunk_index:04d}"

    # Skip if already done (or done_with_issues — treat as done for resume purposes)
    if chunk_id in state and state[chunk_id].get('status') in ('done', 'done_with_issues'):
        return state[chunk_id]['translated_text']

    total_attempts = config['max_retries'] + 1

    def record(status, translated_text, elapsed, attempts, **extra):
        """Persist the outcome for this chunk and return the stored text."""
        entry = {
            'chunk_id': chunk_id,
            'chunk_index': chunk_index,
            'status': status,
            'source_text': source_text,
            'translated_text': translated_text,
            **extra,
            'timestamp': time.time(),
            'elapsed_seconds': elapsed,
            'attempts': attempts,
        }
        save_state_entry(config['state_file'], entry)
        state[chunk_id] = entry
        return translated_text

    prompt = build_prompt(source_text, glossary_str, context_text)

    best_attempt = None        # (translated, issues, severity, elapsed)
    last_elapsed = 0

    for attempt in range(total_attempts):
        has_budget = attempt < config['max_retries']
        print(f"  Attempt {attempt + 1}/{total_attempts}...")
        start_time = time.time()

        translated = call_ollama(prompt, SYSTEM_PROMPT, config)

        elapsed = time.time() - start_time
        last_elapsed = elapsed
        print(f"  Completed in {elapsed:.1f}s")

        if translated is None:
            print("  Failed — API error")
            if has_budget:
                print("  Retrying in 10s...")
                time.sleep(10)
            continue  # on the last attempt this simply ends the loop

        # Strip any "Here is the translation:" preamble the model added
        translated = strip_model_preamble(translated)

        is_valid, issues, severity = validate_translation(source_text, translated)
        if is_valid:
            return record('done', translated, elapsed, attempt + 1)

        # Not perfectly valid — keep the best attempt so far
        print(f"  Validation [{severity}]: {issues}")
        if best_attempt is None or _is_better(severity, best_attempt[2]):
            best_attempt = (translated, issues, severity, elapsed)

        # Both fatal and warn-level results are retried while budget remains.
        if has_budget:
            print("  Retrying...")
            time.sleep(2)

    # === All retries exhausted ===
    # Only a warn-level attempt is worth keeping. A fatal one (empty or
    # non-Russian output) must not be marked done, or it would never be
    # retried on resume.
    if best_attempt is not None and best_attempt[2] != 'fatal':
        translated, issues, severity, elapsed = best_attempt
        print(f"  WARNING: Saving best attempt despite issues ({severity})")
        return record('done_with_issues', translated, elapsed, total_attempts,
                      issues=issues, severity=severity)

    # Total failure: no usable response from any attempt
    failure_issues = list(best_attempt[1]) if best_attempt is not None else []

    if config.get('fallback_to_source', False):
        print("  FALLBACK: Storing source text verbatim so EPUB stays buildable")
        return record(
            'failed_fallback',
            f"<!-- UNTRANSLATED — needs manual work -->\n\n{source_text}",
            last_elapsed,
            total_attempts,
            issues=["All retries failed — stored source as fallback"] + failure_issues,
        )

    return record(
        'failed',
        None,
        last_elapsed,
        total_attempts,
        issues=failure_issues or ["All retries failed — no usable response"],
    )


def _is_better(new_severity, old_severity):
    """Severity ordering: ok > warn > fatal. Lower severity is 'better'."""
    rank = {'ok': 0, 'warn': 1, 'fatal': 2}
    return rank.get(new_severity, 3) < rank.get(old_severity, 3)


def assemble_translation(state, output_path):
    """Write the translated chunks, in index order, to *output_path*.

    Chunks are read as chunk_0000, chunk_0001, ... until the first id that is
    absent from *state*. A chunk without translated text is replaced by a
    visible failure marker. Pieces are separated by a blank line.

    Returns the number of consecutive chunk ids found in *state*.
    """
    pieces = []
    count = 0
    key = f"chunk_{count:04d}"

    while key in state:
        text = state[key].get('translated_text')
        if not text:
            text = f"\n\n[!!! CHUNK {key} FAILED — NEEDS MANUAL TRANSLATION !!!]\n\n"
        pieces.append(text)
        count += 1
        key = f"chunk_{count:04d}"

    with open(output_path, 'w', encoding='utf-8') as out:
        out.write('\n\n'.join(pieces))

    return count

def main():
    """Run the whole translation pipeline from the settings in CONFIG.

    Steps: load the source text and glossary, split the text into chunks,
    load any saved state, check that Ollama is reachable, translate every
    chunk that is not already done, assemble the output file, and print a
    summary. Exits with status 1 when an input file is missing, the source
    yields no chunks, or Ollama cannot be reached.
    """
    print("=" * 60)
    print("BOOK TRANSLATION - Phase 3 (v2)")
    print("=" * 60)

    # Load input
    input_path = Path(CONFIG['input_file'])
    if not input_path.exists():
        print(f"Error: {CONFIG['input_file']} not found.")
        sys.exit(1)

    print(f"\nLoading source text from {CONFIG['input_file']}...")
    with open(input_path, 'r', encoding='utf-8') as f:
        source_text = f.read()

    # Load glossary
    glossary_path = Path(CONFIG['glossary_file'])
    if not glossary_path.exists():
        print(f"Error: {CONFIG['glossary_file']} not found.")
        sys.exit(1)

    print(f"Loading glossary from {CONFIG['glossary_file']}...")
    glossary_str = load_glossary(glossary_path)
    print(f"  Glossary: {len(glossary_str.splitlines())} entries")

    # Split into chunks
    print(f"\nSplitting text into chunks (target: {CONFIG['chunk_size_target']} words)...")
    chunks = split_into_chunks(source_text,
                               target_words=CONFIG['chunk_size_target'],
                               max_words=CONFIG['chunk_size_max'])
    print(f"  Created {len(chunks)} chunks")

    # An empty or whitespace-only source gives no chunks; the statistics
    # below (min/max/average) cannot be computed for an empty list.
    if not chunks:
        print(f"Error: {CONFIG['input_file']} contains no text to translate.")
        sys.exit(1)

    word_counts = [len(c.split()) for c in chunks]
    print(f"  Words per chunk: min={min(word_counts)}, max={max(word_counts)}, avg={sum(word_counts)//len(word_counts)}")
    print(f"  Total words: {sum(word_counts)}")

    # Load existing state (for resuming)
    print(f"\nLoading state from {CONFIG['state_file']}...")
    state = load_state(CONFIG['state_file'])
    done_count = sum(1 for v in state.values()
                     if v.get('status') in ('done', 'done_with_issues'))
    failed_count = sum(1 for v in state.values()
                       if v.get('status') in ('failed', 'failed_fallback'))
    print(f"  Already translated: {done_count}/{len(chunks)} chunks")
    if failed_count:
        print(f"  Previously failed:  {failed_count} chunks (will be retried)")

    # If everything is already done, just (re)assemble and exit
    if done_count == len(chunks):
        print("\nAll chunks already translated! Assembling output...")
        n = assemble_translation(state, CONFIG['output_file'])
        print(f"  Assembled {n} chunks into {CONFIG['output_file']}")
        print("Done!")
        return

    # Test Ollama connectivity
    print(f"\nTesting Ollama connection ({CONFIG['ollama_url']})...")
    try:
        r = requests.get(f"{CONFIG['ollama_url']}/api/tags", timeout=10)
        r.raise_for_status()
        models = [m['name'] for m in r.json().get('models', [])]
        # Substring match so a configured name without an explicit tag still counts.
        if CONFIG['model'] not in models and not any(CONFIG['model'] in m for m in models):
            print(f"  WARNING: Model '{CONFIG['model']}' not found. Available: {models}")
        else:
            print(f"  OK — model '{CONFIG['model']}' available")
    except Exception as e:
        print(f"  ERROR: Cannot connect to Ollama: {e}")
        sys.exit(1)

    # Estimate time
    remaining = len(chunks) - done_count
    est_seconds_per_chunk = 300  # rough average from prior runs
    est_hours = (remaining * est_seconds_per_chunk) / 3600
    print(f"\n  Remaining: {remaining} chunks")
    print(f"  Estimated time: {est_hours:.1f} hours (rough)")
    print(f"  Press Ctrl+C at any time — progress is saved automatically.\n")

    # Main translation loop
    print("-" * 60)
    start_time = time.time()
    translated_count = 0
    issues_this_run = 0
    failed_this_run = 0

    for i, chunk in enumerate(chunks):
        chunk_id = f"chunk_{i:04d}"

        # Skip already done / done_with_issues
        if chunk_id in state and state[chunk_id].get('status') in ('done', 'done_with_issues'):
            continue

        # Get context from previous chunk
        context = get_context(state, i, CONFIG['context_paragraphs'])

        # Progress info: ETA is based on the average time per chunk translated this session
        elapsed_total = time.time() - start_time
        if translated_count > 0:
            avg_time = elapsed_total / translated_count
            eta = avg_time * (remaining - translated_count)
            eta_str = f"ETA: {eta/3600:.1f}h"
        else:
            eta_str = "ETA: calculating..."

        print(f"\n[{i+1}/{len(chunks)}] Translating {chunk_id} "
              f"({len(chunk.split())} words) — {eta_str}")

        result = translate_chunk(i, chunk, glossary_str, context, CONFIG, state)

        # Tally outcomes from the freshly written state
        final_status = state.get(chunk_id, {}).get('status')
        if final_status == 'done':
            translated_count += 1
            preview = (result or '')[:100].replace('\n', ' ')
            print(f"  ✓ Preview: {preview}...")
        elif final_status == 'done_with_issues':
            translated_count += 1
            issues_this_run += 1
            preview = (result or '')[:100].replace('\n', ' ')
            print(f"  ⚠ Saved with issues. Preview: {preview}...")
        else:
            failed_this_run += 1
            print(f"  ✗ FAILED — needs manual translation")

    # Final assembly
    print("\n" + "=" * 60)
    print("TRANSLATION COMPLETE")
    print("=" * 60)

    total_elapsed = time.time() - start_time
    print(f"  Time this session: {total_elapsed/3600:.1f} hours")
    print(f"  Chunks translated this session: {translated_count}")
    if issues_this_run:
        print(f"  Chunks saved with issues: {issues_this_run}")
    if failed_this_run:
        print(f"  Chunks that failed: {failed_this_run}")

    # Assemble final output
    print(f"\nAssembling final output to {CONFIG['output_file']}...")
    n = assemble_translation(state, CONFIG['output_file'])
    print(f"  Assembled {n} chunks.")

    # Overall report (across all sessions, not just this run)
    total_done = sum(1 for v in state.values() if v.get('status') == 'done')
    total_issues = sum(1 for v in state.values() if v.get('status') == 'done_with_issues')
    total_failed = sum(1 for v in state.values() if v.get('status') == 'failed')
    total_fallback = sum(1 for v in state.values() if v.get('status') == 'failed_fallback')

    print(f"\n  Overall status:")
    print(f"    ✓ done:              {total_done}")
    if total_issues:
        print(f"    ⚠ done_with_issues:  {total_issues} (review recommended)")
    if total_fallback:
        print(f"    ⚠ failed_fallback:   {total_fallback} (source kept as placeholder)")
    if total_failed:
        print(f"    ✗ failed:            {total_failed} (no translation stored)")

    # Per-chunk issue summary so you know exactly what to review
    if total_issues or total_fallback or total_failed:
        print(f"\n  Chunks needing review:")
        for cid in sorted(state.keys()):
            entry = state[cid]
            st = entry.get('status')
            if st in ('done_with_issues', 'failed', 'failed_fallback'):
                issues = entry.get('issues') or []
                # Fall back to the status name when no issue text was stored.
                first = issues[0] if issues else st
                print(f"    {cid} [{st}]: {first}")

    print(f"\nOutput file: {CONFIG['output_file']}")
    print("Done!")


# Run the pipeline only when executed as a script, not when imported.
if __name__ == "__main__":
    main()




