#!/usr/bin/env python3
import os
import sys
import shlex
import subprocess
import time
import signal
from pathlib import Path
import socket

from prompt_toolkit import PromptSession
from prompt_toolkit.history import FileHistory
from prompt_toolkit.auto_suggest import AutoSuggestFromHistory
from prompt_toolkit.completion import Completer, Completion
from prompt_toolkit.lexers import Lexer
from prompt_toolkit.styles import Style
from prompt_toolkit.formatted_text import FormattedText
from prompt_toolkit.key_binding import KeyBindings

OSH_VERSION = "4.0-py"
# Per-user history file: backs the PromptSession's FileHistory and is also
# read by `history`, history expansion (!!, !n) and completion ranking.
HISTORY_FILE = os.path.join(os.path.expanduser("~"), ".osh_history")

# ---------- command and path frequency ----------
def load_frequencies(history_file=None):
    """Count command names and path-like arguments in the shell history.

    The history file is written by prompt_toolkit's ``FileHistory`` (the
    ``PromptSession`` in ``main``). Each entry is stored as a
    ``# <timestamp>`` header followed by one ``+``-prefixed line per line of
    the entry. Entries are rebuilt from the ``+`` lines before counting;
    reading raw lines would count ``#`` and ``+ls`` as commands.

    Args:
        history_file: file to read; ``None`` means ``HISTORY_FILE``.

    Returns:
        ``(cmd_freq, dir_freq)``. These are dicts mapping a command name
        (first word of an entry) and a path-like argument (a ``cd`` target,
        or any argument containing ``/`` or equal to ``.``/``..``) to how
        often it occurs. Each argument is counted once per entry.

        Both dicts are empty when the file is missing or unreadable; the
        counts only rank completions, so that is not an error.
    """
    if history_file is None:
        history_file = HISTORY_FILE

    entries = []
    pending = []
    try:
        with open(history_file, 'rb') as f:
            for raw in f:
                text = raw.decode('utf-8', errors='replace')
                if text.startswith('+'):
                    pending.append(text[1:])
                elif pending:
                    # A non-"+" line ends the entry being collected.
                    entries.append(''.join(pending).rstrip('\n'))
                    pending = []
    except OSError:
        return {}, {}
    if pending:
        entries.append(''.join(pending).rstrip('\n'))

    cmd_freq = {}
    dir_freq = {}
    for entry in entries:
        try:
            words = shlex.split(entry)
        except ValueError:
            # Unbalanced quotes: fall back to plain whitespace splitting.
            words = entry.split()
        if not words:
            continue
        cmd = words[0]
        cmd_freq[cmd] = cmd_freq.get(cmd, 0) + 1
        for position, arg in enumerate(words[1:], 1):
            # Count the cd target and other path-looking args exactly once.
            if (cmd == 'cd' and position == 1) or '/' in arg or arg in ('.', '..'):
                dir_freq[arg] = dir_freq.get(arg, 0) + 1
    return cmd_freq, dir_freq

# ---------- signals ----------
def sigint_handler(signum, frame):
    """Handle SIGINT by emitting a bare line break and nothing else.

    Registered below for SIGINT, replacing Python's default handler (which
    raises KeyboardInterrupt).
    """
    del signum, frame  # required by the signal-handler signature, unused
    print("")


signal.signal(signal.SIGINT, sigint_handler)

# ---------- prompt ----------
def git_branch():
    """Return prompt fragments naming the current git branch.

    Runs ``git branch --show-current`` in the working directory.

    Returns:
        ``[("class:git", " (<branch>)")]``, or an empty list when git is
        unavailable, the command fails (e.g. not a repository), or it prints
        nothing (e.g. a detached HEAD).
    """
    try:
        branch = subprocess.check_output(
            ["git", "branch", "--show-current"],
            stderr=subprocess.DEVNULL,
            text=True,
        )
    except Exception:
        return []
    branch = branch.strip()
    if not branch:
        return []
    return [("class:git", f" ({branch})")]

def get_prompt():
    """Build the prompt ``[HH:MM:SS] user@host cwd (branch)$ `` as formatted text.

    The home directory prefix is shown as ``~`` only when the working
    directory *is* the home directory or lies inside it. A bare
    ``startswith`` test would also turn ``/home/alice`` into ``~ice`` for
    ``HOME=/home/al``, and every path into ``~...`` for ``HOME=/``.

    If the working directory has been deleted, ``$PWD`` (or ``?``) is shown
    instead of letting ``os.getcwd`` abort the prompt.
    """
    try:
        cwd = os.getcwd()
    except OSError:
        cwd = os.environ.get("PWD", "?")

    home = os.path.expanduser("~").rstrip(os.sep)
    # An empty `home` (HOME=/) disables the abbreviation entirely.
    if home and (cwd == home or cwd.startswith(home + os.sep)):
        cwd = "~" + cwd[len(home):]

    user = os.getenv("USER", "user")
    host = socket.gethostname()
    time_str = time.strftime("%H:%M:%S")
    parts = [
        ("class:time", f"[{time_str}] "),
        ("class:user", f"{user}@{host} "),
        ("class:cwd", cwd),
        *git_branch(),
        ("class:prompt", "$ "),
    ]
    return FormattedText(parts)

# ---------- completer ----------
class OSHCompleter(Completer):
    """Tab completion for command names and paths, ranked by usage.

    While the first word is still being typed, candidates are the entries of
    every ``$PATH`` directory. After it, candidates are entries of the
    directory named by the word under the cursor: the current directory if
    the word has no ``/``, with a leading ``~`` expanded for the lookup only.
    For ``cd`` they are restricted to directories.

    Candidates used more often in history (see ``load_frequencies``) sort
    first, then alphabetically.
    """

    # Commands whose arguments only make sense as directories.
    DIRECTORY_COMMANDS = frozenset({'cd'})

    def __init__(self):
        self.cmd_freq, self.dir_freq = load_frequencies()

    def get_completions(self, document, complete_event):
        text = document.text_before_cursor
        words = text.split()
        word = document.get_word_before_cursor(WORD=True)
        # The command word is still being typed until whitespace follows it:
        # "ls" should offer lsblk/lsof, while "vim " should offer files.
        if not words or (len(words) == 1 and not text[-1].isspace()):
            yield from self._complete_commands(word)
        else:
            yield from self._complete_paths(
                word, dirs_only=words[0] in self.DIRECTORY_COMMANDS)

    def _complete_commands(self, word):
        """Yield names from the ``$PATH`` directories that start with *word*."""
        names = set()
        for directory in os.environ.get('PATH', '').split(os.pathsep):
            try:
                names.update(os.listdir(directory))
            except OSError:
                # Empty, missing, unreadable or non-directory PATH entry.
                continue
        matches = [name for name in names if name.startswith(word)]
        matches.sort(key=lambda name: (-self.cmd_freq.get(name, 0), name))
        for name in matches:
            yield Completion(name, start_position=-len(word))

    def _complete_paths(self, word, dirs_only=False):
        """Yield paths that extend *word*, listed from *word*'s directory part."""
        head, tail = os.path.split(word)
        base = os.path.expanduser(head) if head else os.curdir
        try:
            names = os.listdir(base)
        except OSError:
            # Missing/unreadable directory, or a deleted cwd: no candidates.
            return
        candidates = []
        for name in names:
            if not name.startswith(tail):
                continue
            if dirs_only and not os.path.isdir(os.path.join(base, name)):
                continue
            # Keep the user's spelling of the directory part (e.g. "~/").
            candidates.append(os.path.join(head, name))
        candidates.sort(key=lambda path: (-self.dir_freq.get(path, 0), path))
        for path in candidates:
            yield Completion(path, start_position=-len(word))

# ---------- lexer for syntax highlighting ----------
class OSHLexer(Lexer):
    """Syntax highlighting for the input line.

    Styles:
        ``class:command``: one of ``KNOWN_COMMANDS`` in command position
        (start of the line, or right after a ``|``, ``||``, ``&&`` or ``;``
        token).
        ``class:argument``: every other word.
        ``class:default``: the text between words.
    """

    KNOWN_COMMANDS = frozenset({'cd', 'exit', 'help', 'ls', 'pwd', 'echo'})
    _SEPARATORS = frozenset({'|', '||', '&&', ';'})

    def lex_document(self, document):
        def get_line(lineno):
            line = document.lines[lineno]
            try:
                # posix=False keeps quotes in the tokens so each token can be
                # located verbatim in the raw line below.
                words = shlex.split(line, posix=False)
            except ValueError:
                # Unterminated quote while the user is still typing: fall
                # back to whitespace splitting instead of raising from the
                # renderer.
                words = line.split()
            parts = []
            pos = 0
            command_position = True
            for word in words:
                start = line.find(word, pos)
                if start == -1:
                    continue
                if pos < start:
                    parts.append(('class:default', line[pos:start]))
                if command_position and word in self.KNOWN_COMMANDS:
                    parts.append(('class:command', word))
                else:
                    parts.append(('class:argument', word))
                command_position = word in self._SEPARATORS
                pos = start + len(word)
            if pos < len(line):
                parts.append(('class:default', line[pos:]))
            return parts
        return get_line

# Colours for the lexer classes (command/argument/default) and for the
# prompt fragments produced by get_prompt (time/user/cwd/git/prompt).
style = Style.from_dict(dict(
    # input-line highlighting
    command='#ansigreen bold',
    argument='#ansiblue',
    default='#ansiwhite',
    # prompt segments
    time='#ansiblue bold',
    user='#ansigreen bold',
    cwd='#ansicyan bold',
    git='#ansiyellow',
    prompt='#ansimagenta bold',
))

# ---------- shell state ----------
# Mutable shell state shared by the builtins and the executor:
#   shell_vars - variables set with NAME=value (also written by `export`)
#   aliases    - alias name -> replacement text (`alias` builtin)
#   config     - OSH_* settings from ~/.oshrc and the `set` builtin
shell_vars, aliases, config = {}, {}, {}

# ---------- builtins ----------
def builtin_cd(args):
    """Change the working directory: ``cd [dir | -]``.

    With no argument, go to the home directory. ``cd -`` returns to the
    directory recorded in ``$OLDPWD`` and prints it. A leading ``~`` in the
    target is expanded here, because the shlex-based parser leaves ``~``
    untouched. On success ``$OLDPWD`` and ``$PWD`` are updated, so ``cd -``
    and child processes see them.

    Returns 0 on success, 1 on failure (the error is printed).
    """
    target = args[1] if len(args) > 1 else "~"
    show_target = False
    if target == "-":
        target = os.environ.get("OLDPWD")
        if not target:
            print("cd: OLDPWD not set")
            return 1
        show_target = True

    try:
        previous = os.getcwd()
    except OSError:
        previous = None  # the old directory was deleted; nothing to record

    try:
        os.chdir(os.path.expanduser(target))
    except (OSError, ValueError) as e:
        print(e)
        return 1

    if previous is not None:
        os.environ["OLDPWD"] = previous
    os.environ["PWD"] = os.getcwd()
    if show_target:
        print(os.environ["PWD"])
    return 0

def builtin_exit(args):
    """Exit the shell: ``exit [n]``.

    Without an argument the exit status is 0. A numeric ``n`` is reduced
    modulo 256, as POSIX shells do. A non-numeric argument is reported and
    the shell exits with status 2.
    """
    if len(args) < 2:
        sys.exit(0)
    try:
        code = int(args[1])
    except ValueError:
        print(f"exit: {args[1]}: numeric argument required")
        sys.exit(2)
    sys.exit(code & 0xFF)

def builtin_pwd(args):
    """Print the current working directory and return status 0."""
    cwd = os.getcwd()
    print(cwd)
    return 0

def builtin_echo(args):
    """Print the words after ``echo``, separated by single spaces; return 0."""
    print(*args[1:])
    return 0

def builtin_export(args):
    """Export variables to the environment of child processes.

    Forms:
        ``export``: list the environment as ``NAME=value`` lines.
        ``export NAME=value``: set the variable in both ``os.environ`` and
        ``shell_vars``.
        ``export NAME``: copy an existing shell variable into the
        environment.

    Names must be identifiers. An empty or malformed name (``export =x``)
    is reported instead of letting ``os.environ`` raise and abort the whole
    command line.

    Returns 0, or 1 if any name was invalid.
    """
    if len(args) == 1:
        for name, value in os.environ.items():
            print(f"{name}={value}")
        return 0

    status = 0
    for arg in args[1:]:
        name, has_value, value = arg.partition('=')
        if not name.isidentifier():
            print(f"export: `{arg}': not a valid identifier")
            status = 1
            continue
        if has_value:
            os.environ[name] = value
            shell_vars[name] = value
        elif name in shell_vars:
            os.environ[name] = shell_vars[name]
    return status

def builtin_unset(args):
    """Remove each named variable from the environment and from ``shell_vars``.

    Names that are not set are ignored. Always returns 0.
    """
    for name in args[1:]:
        os.environ.pop(name, None)
        shell_vars.pop(name, None)
    return 0

def builtin_history(args, history_file=None):
    """List history entries numbered from 1 (oldest first): ``history [N]``.

    ``history N`` shows only the last N entries, keeping their absolute
    numbers.

    The file is in prompt_toolkit ``FileHistory`` format: ``# timestamp``
    header lines, each followed by ``+``-prefixed lines holding the entry
    text. Only the entry text is shown, never the headers, separators or
    ``+`` markers.

    Args:
        args: argv of the builtin.
        history_file: file to read; ``None`` means ``HISTORY_FILE``.

    Returns 0, or 1 if N is not a number.
    """
    if history_file is None:
        history_file = HISTORY_FILE

    limit = None
    if len(args) > 1:
        try:
            limit = int(args[1])
        except ValueError:
            print(f"history: {args[1]}: numeric argument required")
            return 1

    entries = []
    pending = []
    try:
        with open(history_file, 'rb') as f:
            for raw in f:
                text = raw.decode('utf-8', errors='replace')
                if text.startswith('+'):
                    pending.append(text[1:])
                elif pending:
                    # A non-"+" line ends the entry being collected.
                    entries.append(''.join(pending).rstrip('\n'))
                    pending = []
    except FileNotFoundError:
        pass
    if pending:
        entries.append(''.join(pending).rstrip('\n'))

    first = 0 if limit is None else max(len(entries) - max(limit, 0), 0)
    for number, entry in enumerate(entries[first:], first + 1):
        print(f"{number:4} {entry}")
    return 0

def builtin_alias(args):
    """Define or display aliases.

    Forms:
        ``alias``: list every alias.
        ``alias name=value``: define one.
        ``alias name``: show one; an undefined name is reported and yields
        status 1 instead of silently succeeding.

    Returns 0 if every operand succeeded, otherwise 1.
    """
    if len(args) == 1:
        for name, value in aliases.items():
            print(f"alias {name}='{value}'")
        return 0

    status = 0
    for arg in args[1:]:
        name, has_value, value = arg.partition('=')
        if has_value:
            aliases[name] = value
        elif name in aliases:
            print(f"alias {name}='{aliases[name]}'")
        else:
            print(f"alias: {name}: not found")
            status = 1
    return status

def builtin_which(args):
    """Print the full path of each named executable found on ``$PATH``.

    The lookup is delegated to ``shutil.which``, which also handles names
    that already contain a directory part.

    Returns 0 if every name was found, 1 otherwise (including when no name
    is given).
    """
    import shutil

    if len(args) < 2:
        print("which: missing argument")
        return 1

    status = 0
    for cmd in args[1:]:
        # Pass PATH explicitly so an unset PATH means "search nowhere"
        # rather than shutil's fallback to a default search path.
        found = shutil.which(cmd, path=os.environ.get('PATH', ''))
        if found:
            print(found)
        else:
            print(f"which: {cmd}: command not found")
            status = 1
    return status

def builtin_type(args):
    """Describe how each name would be run: alias, shell builtin, or file.

    Names are checked in the order the executor resolves them: aliases are
    expanded first, then builtins are dispatched, then ``$PATH`` is searched
    via ``shutil.which``.

    Returns 0 if every name was resolved, 1 otherwise (including when no
    name is given).
    """
    import shutil

    if len(args) < 2:
        print("type: missing argument")
        return 1

    status = 0
    for cmd in args[1:]:
        if cmd in aliases:
            print(f"{cmd} is aliased to `{aliases[cmd]}'")
        elif cmd in BUILTINS:
            print(f"{cmd} is a shell builtin")
        else:
            found = shutil.which(cmd, path=os.environ.get('PATH', ''))
            if found:
                print(f"{cmd} is {found}")
            else:
                print(f"type: {cmd}: not found")
                status = 1
    return status

def builtin_source(args):
    """Run each non-blank, non-``#`` line of a file through ``execute``.

    The path may start with ``~``. The file is read completely before any
    line runs, so it is not held open while its commands execute.

    Returns:
        The status of the last command run (0 if the file has no commands),
        or 1 if no file was given or it cannot be read (missing, a
        directory, unreadable, or not valid text).
    """
    if len(args) < 2:
        print("source: missing argument")
        return 1
    file = args[1]
    try:
        with open(os.path.expanduser(file), 'r') as f:
            lines = f.readlines()
    except FileNotFoundError:
        print(f"source: {file}: No such file or directory")
        return 1
    except (OSError, UnicodeDecodeError) as e:
        print(f"source: {file}: {e}")
        return 1

    status = 0
    for line in lines:
        line = line.strip()
        if line and not line.startswith('#'):
            status = execute(line)
    return status

def builtin_jobs(args):
    """Report background jobs.

    Background processes are not tracked (``execute_command`` only prints
    their PID), so there is never anything to list. Returns 0.
    """
    message = "jobs: no jobs"
    print(message)
    return 0

def builtin_kill(args):
    """Send a signal to processes: ``kill [-SIG | -s SIG] [--] pid...``.

    The signal defaults to SIGTERM. It may be given as a number (``-9``) or
    as a name with or without the ``SIG`` prefix (``-KILL``, ``-s SIGHUP``).

    The leading option must be parsed as a signal. If ``-9`` were taken as
    the PID, ``os.kill`` would treat the negative number as a process-group
    id and signal group 9. Put ``--`` before a negative PID to target a
    process group on purpose.

    Returns 0 if every PID was signalled, 1 on any usage error or failure.
    """
    def parse_signal(spec):
        # Numeric spec: any signal number the platform supports
        # (0 only checks that the process exists).
        if spec.isdigit():
            number = int(spec)
            return number if 0 <= number < signal.NSIG else None
        name = spec.upper()
        if not name.startswith('SIG'):
            name = 'SIG' + name
        try:
            return signal.Signals[name]
        except KeyError:
            return None

    operands = list(args[1:])
    spec = None
    if operands and operands[0] == '-s':
        if len(operands) < 2:
            print("kill: -s: option requires an argument")
            return 1
        spec, operands = operands[1], operands[2:]
    elif (operands and len(operands[0]) > 1 and operands[0].startswith('-')
          and operands[0] != '--'):
        spec, operands = operands[0][1:], operands[1:]

    sig = signal.SIGTERM
    if spec is not None:
        sig = parse_signal(spec)
        if sig is None:
            print(f"kill: {spec}: invalid signal specification")
            return 1

    if operands and operands[0] == '--':
        operands = operands[1:]
    if not operands:
        print("kill: missing argument")
        return 1

    status = 0
    for operand in operands:
        try:
            pid = int(operand)
        except ValueError:
            print("kill: invalid process id")
            status = 1
            continue
        try:
            os.kill(pid, sig)
        except ProcessLookupError:
            print("kill: no such process")
            status = 1
        except PermissionError:
            print("kill: operation not permitted")
            status = 1
        except OSError as e:
            print(f"kill: {e}")
            status = 1
    return status

def builtin_set(args):
    """Show or change shell settings stored in ``config``.

    Forms:
        ``set``: print every ``key=value``.
        ``set key=value``: assign.
        ``set key``: print that key if present; otherwise ignored.

    Always returns 0.
    """
    if len(args) == 1:
        for key, value in config.items():
            print(f"{key}={value}")
        return 0
    for arg in args[1:]:
        if '=' not in arg:
            if arg in config:
                print(f"{arg}={config[arg]}")
            continue
        key, value = arg.split('=', 1)
        config[key] = value
    return 0

def builtin_help(args):
    """Print the shell version and the names of the builtins; return 0."""
    banner = ("OSH", OSH_VERSION)
    print(*banner)
    print("builtins: cd exit pwd echo export unset history alias which type source jobs kill set help")
    return 0

# Dispatch table used by execute_command and the `type` builtin: maps a
# builtin's name to its implementation. Each is called as fn(args), with
# args[0] being the builtin's name.
BUILTINS = dict(
    cd=builtin_cd,
    exit=builtin_exit,
    pwd=builtin_pwd,
    echo=builtin_echo,
    export=builtin_export,
    unset=builtin_unset,
    history=builtin_history,
    alias=builtin_alias,
    which=builtin_which,
    type=builtin_type,
    source=builtin_source,
    jobs=builtin_jobs,
    kill=builtin_kill,
    set=builtin_set,
    help=builtin_help,
)

# ---------- execution ----------
def run_pipeline(cmds):
    """Run ``cmds`` (a list of argv lists) connected by pipes.

    Each process's stdout feeds the next one's stdin; the last process
    writes to the terminal, and stderr always goes to the terminal. Every
    process is waited for, so none are left behind as zombies.

    Returns:
        The exit status of the last command; 127 if a command is not found;
        126 if one cannot be executed; 2 if the pipeline or any stage is
        empty (e.g. ``ls | | wc``).
    """
    if not cmds or any(not argv for argv in cmds):
        print("osh: syntax error near unexpected token `|'")
        return 2

    procs = []
    prev_stdout = None
    last_index = len(cmds) - 1
    for index, argv in enumerate(cmds):
        try:
            proc = subprocess.Popen(
                argv,
                stdin=prev_stdout,
                # Compare by position, not by value: identical stages
                # (`cat | cat`) must still be piped together.
                stdout=subprocess.PIPE if index < last_index else None,
                stderr=None,  # let stderr go to the terminal
            )
        except OSError as e:
            if isinstance(e, FileNotFoundError):
                print(f"osh: command not found: {argv[0]}")
                status = 127
            else:
                print(f"osh: {argv[0]}: {e.strerror or e}")
                status = 126
            if prev_stdout is not None:
                prev_stdout.close()
            for started in procs:
                started.terminate()
            for started in procs:
                started.wait()
            return status
        if prev_stdout is not None:
            # The child holds its own copy. Closing ours lets the upstream
            # process see a broken pipe if the downstream one exits early.
            prev_stdout.close()
        prev_stdout = proc.stdout
        procs.append(proc)

    status = procs[-1].wait()
    for proc in procs[:-1]:
        proc.wait()
    return status

def execute_command(args, stdin=None, stdout=None, stderr=None, background=False):
    """Run one command: a builtin from ``BUILTINS`` or an external program.

    Args:
        args: argv list; ``args[0]`` names the command.
        stdin, stdout, stderr: file objects, or ``None`` to inherit the
            terminal. They are passed to an external program directly.
            Builtins write with ``print``, so for them ``sys.stdout`` and
            ``sys.stderr`` are temporarily redirected to these objects.
            Without that, ``echo hi > file`` would print to the terminal
            and leave the file empty. Builtins do not read stdin.
        background: start an external program without waiting for it.

    Returns:
        The exit status; 0 for empty ``args`` or a background start; 1 for
        a backgrounded builtin; 127 if the program is not found; 126 if it
        cannot be executed.
    """
    import contextlib

    if not args:
        return 0

    builtin = BUILTINS.get(args[0])
    if builtin is not None:
        if background:
            print("osh: cannot background builtin")
            return 1
        with contextlib.ExitStack() as stack:
            if hasattr(stdout, 'write'):
                stack.enter_context(contextlib.redirect_stdout(stdout))
            if hasattr(stderr, 'write'):
                stack.enter_context(contextlib.redirect_stderr(stderr))
            return builtin(args)

    try:
        proc = subprocess.Popen(args, stdin=stdin, stdout=stdout, stderr=stderr)
    except FileNotFoundError:
        print(f"osh: command not found: {args[0]}")
        return 127
    except OSError as e:
        # e.g. permission denied or exec format error
        print(f"osh: {args[0]}: {e.strerror or e}")
        return 126
    if background:
        print(f"[{proc.pid}]")
        return 0
    return proc.wait()

def expand_history(line, history_file=None):
    """Expand ``!!`` (previous command) and ``!n`` (n-th history entry, 1 = oldest).

    Entries are read from the prompt_toolkit ``FileHistory`` file, which
    stores each entry as ``+``-prefixed lines under a ``# timestamp``
    header.

    prompt_toolkit appends accepted input to the history before returning it
    from ``prompt()``. In interactive use the newest entry is therefore
    normally the line being expanded itself. So when the newest entry equals
    ``line``, ``!!`` skips it and refers to the entry before. This is a
    heuristic: it compares text after alias expansion.

    Both forms are substituted in a single pass, so text pulled in from
    history is not expanded again. References that cannot be resolved are
    left as typed.

    Args:
        line: the command line to expand.
        history_file: file to read; ``None`` means ``HISTORY_FILE``.
    """
    if '!' not in line:
        return line
    import re

    if history_file is None:
        history_file = HISTORY_FILE

    entries = []
    pending = []
    try:
        with open(history_file, 'rb') as f:
            for raw in f:
                text = raw.decode('utf-8', errors='replace')
                if text.startswith('+'):
                    pending.append(text[1:])
                elif pending:
                    # A non-"+" line ends the entry being collected.
                    entries.append(''.join(pending).rstrip('\n'))
                    pending = []
    except OSError:
        return line
    if pending:
        entries.append(''.join(pending).rstrip('\n'))

    recent = entries
    if entries and entries[-1].strip() == line.strip():
        recent = entries[:-1]

    def resolve(match):
        if match.group(0) == '!!':
            return recent[-1] if recent else match.group(0)
        n = int(match.group(1))
        return entries[n - 1] if 1 <= n <= len(entries) else match.group(0)

    return re.sub(r'!!|!(\d+)', resolve, line)

def expand_variables(line, last_status):
    """Apply history expansion, then substitute ``$?``, ``$NAME`` and ``${NAME}``.

    ``$?`` (or ``${?}``) becomes ``last_status``. Names are looked up first
    in ``shell_vars``, which are set by ``NAME=value``, and then in the
    environment. ``os.path.expandvars`` alone only sees the environment, so
    unexported shell variables were never expanded. Unknown names are left
    untouched, as ``os.path.expandvars`` does. Names use the same
    ASCII-word syntax that ``os.path.expandvars`` recognises.
    """
    import re

    line = expand_history(line)

    def substitute(match):
        name = match.group(1)
        if name.startswith('{'):
            name = name[1:-1]
        if name == '?':
            return str(last_status)
        if name in shell_vars:
            return shell_vars[name]
        return os.environ.get(name, match.group(0))

    return re.sub(r'\$(\?|\w+|\{[^}]*\})', substitute, line, flags=re.ASCII)

def expand_aliases(line):
    """Replace the first word of ``line`` with its alias, if it has one.

    Only the raw, unquoted first word is compared. POSIX shells do not
    alias-expand a quoted word such as ``"ll"``. Everything after the first
    word is kept verbatim, so quoting, redirections and operators in the
    arguments survive. Unbalanced quotes are left for the command parser to
    report; this function does not raise on them.
    """
    parts = line.split(None, 1)
    if not parts or parts[0] not in aliases:
        return line
    expansion = aliases[parts[0]]
    return expansion + ' ' + parts[1] if len(parts) > 1 else expansion

def _split_unquoted(text, operators):
    """Split *text* on any of *operators* that occur outside quotes.

    Returns ``(operator_before, piece)`` pairs. The first pair's operator
    is ``None``, and pieces are stripped. *operators* must list longer
    operators before their prefixes (``'||'`` before ``'|'``). A backslash
    escapes the next character outside single quotes.
    """
    pieces = []
    before = None
    buf = []
    quote = None
    i = 0
    while i < len(text):
        ch = text[i]
        if quote is None:
            op = next((o for o in operators if text.startswith(o, i)), None)
            if op is not None:
                pieces.append((before, ''.join(buf).strip()))
                before, buf = op, []
                i += len(op)
                continue
            if ch in ('"', "'"):
                quote = ch
        elif ch == quote:
            quote = None
        if ch == '\\' and quote != "'" and i + 1 < len(text):
            buf.append(text[i:i + 2])
            i += 2
            continue
        buf.append(ch)
        i += 1
    pieces.append((before, ''.join(buf).strip()))
    return pieces


def _parse_redirections(args):
    """Separate ``<``, ``>`` and ``>>`` redirections from a command's words.

    Every redirection is honoured, and the last one for each stream wins.
    Words after a redirection target stay arguments, so
    ``sort < in > out`` works.

    Returns ``(argv, stdin_path, stdout_path, append)``.
    Raises ValueError if a redirection operator has no target.
    """
    argv = []
    stdin_path = stdout_path = None
    append = False
    words = iter(args)
    for word in words:
        if word not in ('<', '>', '>>'):
            argv.append(word)
            continue
        target = next(words, None)
        if target is None:
            raise ValueError("syntax error near unexpected token `newline'")
        if word == '<':
            stdin_path = target
        else:
            stdout_path, append = target, word == '>>'
    return argv, stdin_path, stdout_path, append


def _run_simple_command(text):
    """Run one pipe-free command and return its exit status.

    Assignments: when every word is ``NAME=value``, the words update
    ``shell_vars``. For names already in the environment (inherited or
    exported) they update ``os.environ`` too, so child processes see the
    new value.

    Otherwise a trailing ``&`` requests background execution, and
    ``<``/``>``/``>>`` redirections are applied.

    Raises ValueError for input that cannot be parsed (an unterminated
    quote, or a redirection without a target).
    """
    import contextlib

    args = shlex.split(text)
    if not args:
        return 0

    def is_assignment(word):
        name, sep, _ = word.partition('=')
        return bool(sep) and name.isidentifier()

    if all(is_assignment(word) for word in args):
        for word in args:
            name, _, value = word.partition('=')
            shell_vars[name] = value
            if name in os.environ:
                os.environ[name] = value
        return 0

    background = args[-1] == '&'
    if background:
        args = args[:-1]
    argv, stdin_path, stdout_path, append = _parse_redirections(args)

    try:
        with contextlib.ExitStack() as stack:
            stdin = stack.enter_context(open(stdin_path, 'r')) if stdin_path else None
            stdout = (stack.enter_context(open(stdout_path, 'a' if append else 'w'))
                      if stdout_path else None)
            return execute_command(argv, stdin=stdin, stdout=stdout,
                                   background=background)
    except Exception as e:
        # Boundary for a single command: report the failure (e.g. an
        # unreadable redirection target) and let the next command run.
        print(f"osh: {e}")
        return 1


def execute(line):
    """Execute one line of shell input and return its exit status.

    Processing order:
      1. Alias expansion of the first word, then history and ``$VAR``/``$?``
         expansion over the whole line (``expand_aliases``,
         ``expand_variables``).
      2. Splitting on ``;``, ``&&`` and ``||`` outside quotes. Each command
         runs or is skipped according to the operator *before* it and the
         status of the last command that ran: ``false && x`` skips ``x``,
         and ``true || x`` skips ``x``.
      3. Each command is split on unquoted ``|`` into a pipeline. Otherwise
         it is run as a simple command: assignments, a trailing ``&``, and
         redirections.

    Expansion happens once for the whole line, so ``FOO=1; echo $FOO``
    sees the value FOO had before the line ran.

    Parse errors such as an unterminated quote are reported and give status
    2, instead of propagating out of the shell loop. The global
    ``last_status`` is updated as commands run.
    """
    global last_status
    try:
        expanded = expand_variables(expand_aliases(line.strip()), last_status)
    except ValueError as e:
        print(f"osh: {e}")
        last_status = 2
        return last_status

    last_status = 0
    for connector, text in _split_unquoted(expanded, ('&&', '||', ';')):
        if not text:
            continue
        if connector == '&&' and last_status != 0:
            continue
        if connector == '||' and last_status == 0:
            continue
        try:
            stages = [stage for _, stage in _split_unquoted(text, ('|',))]
            if len(stages) > 1:
                argvs = [shlex.split(stage) for stage in stages]
                if not all(argvs):
                    raise ValueError("syntax error near unexpected token `|'")
                last_status = run_pipeline(argvs)
            else:
                last_status = _run_simple_command(text)
        except ValueError as e:
            print(f"osh: {e}")
            last_status = 2
        except OSError as e:
            # A pipeline stage that could not be started.
            print(f"osh: {e}")
            last_status = 126
    return last_status

# ---------- config ----------
def load_config():
    """Apply ``~/.oshrc``.

    ``OSH_*=value`` lines are stored in ``config``, with key and value
    stripped of surrounding spaces. Every other non-blank, non-``#`` line
    is run with ``execute``.

    A missing file is ignored. A file that cannot be read or decoded is
    reported and skipped, so the shell still starts.
    """
    config_file = os.path.expanduser("~/.oshrc")
    try:
        with open(config_file, 'r') as f:
            lines = f.readlines()
    except FileNotFoundError:
        return
    except (OSError, UnicodeDecodeError) as e:
        print(f"osh: cannot read {config_file}: {e}")
        return

    for line in lines:
        line = line.strip()
        if not line or line.startswith('#'):
            continue
        if line.startswith('OSH_') and '=' in line:
            key, val = line.split('=', 1)
            config[key.strip()] = val.strip()
        else:
            execute(line)

# ---------- key bindings ----------
kb = KeyBindings()


@kb.add('c-c')
def copy_selection(event):
    """Ctrl-C: copy the selected text, or abort the current input line.

    With an active selection, the text is copied to prompt_toolkit's own
    clipboard and also to the X11 clipboard when ``xclip`` is available.

    This binding overrides prompt_toolkit's default Ctrl-C. So without a
    selection it does what that default does: it makes the prompt raise
    KeyboardInterrupt, which ``main`` catches to start a fresh line.
    """
    buffer = event.current_buffer
    if buffer.selection_state is None:
        event.app.exit(exception=KeyboardInterrupt, style='class:aborting')
        return
    # copy_selection() returns the selected text and leaves selection mode.
    data = buffer.copy_selection()
    event.app.clipboard.set_data(data)
    try:
        subprocess.run(['xclip', '-selection', 'clipboard'],
                       input=data.text, text=True, check=True)
    except (subprocess.CalledProcessError, FileNotFoundError):
        # No usable system clipboard; the in-app copy above still works.
        pass

@kb.add('c-v')
def paste_clipboard(event):
    """Ctrl-V: insert the X11 clipboard contents (read via ``xclip``) at the cursor.

    If ``xclip`` is missing or fails, fall back to prompt_toolkit's own
    clipboard, so text cut or copied inside the prompt can still be pasted.
    """
    try:
        result = subprocess.run(['xclip', '-selection', 'clipboard', '-o'],
                                capture_output=True, text=True, check=True)
        text = result.stdout
    except (subprocess.CalledProcessError, FileNotFoundError):
        text = event.app.clipboard.get_data().text
    if text:
        event.current_buffer.insert_text(text)

# ---------- main loop ----------
last_status = 0  # status of the most recent command line; execute uses it for $?


def main():
    """Run the interactive read-eval loop until EOF (Ctrl-D).

    Loads ``~/.oshrc``, then repeatedly prompts and executes the line. For
    lines that take longer than a second, it prints the elapsed time.

    Errors from the config file, from building the prompt, or from executing
    one line are reported and the loop continues. A single bad command
    cannot terminate the shell; only ``exit`` (SystemExit) or EOF does.
    """
    global last_status
    try:
        load_config()
    except Exception as e:
        print(f"osh: error while loading ~/.oshrc: {e}")
    print(f"Welcome to OSH {OSH_VERSION}")

    session = PromptSession(
        history=FileHistory(HISTORY_FILE),
        auto_suggest=AutoSuggestFromHistory(),
        completer=OSHCompleter(),
        lexer=OSHLexer(),
        style=style,
        key_bindings=kb,
    )

    while True:
        try:
            message = get_prompt()
        except OSError:
            # e.g. the working directory was deleted
            message = "osh$ "
        try:
            line = session.prompt(message)
        except EOFError:
            print()
            break
        except KeyboardInterrupt:
            continue

        if not line.strip():
            continue

        start = time.monotonic()
        try:
            last_status = execute(line)
        except Exception as e:
            # Top-level boundary: report and keep the shell alive.
            print(f"osh: {type(e).__name__}: {e}")
            last_status = 1
        elapsed = time.monotonic() - start

        if elapsed > 1.0:
            print(f"[{elapsed:.2f}s]")

if __name__ == "__main__":
    if len(sys.argv) > 1:
        # Script mode: run each non-blank, non-comment line. Then exit with
        # the status of the last command that ran (0 if none), as sh does.
        script_file = sys.argv[1]
        try:
            with open(script_file, 'r') as f:
                script_lines = f.readlines()
        except FileNotFoundError:
            print(f"osh: {script_file}: No such file or directory")
            sys.exit(1)
        except (OSError, UnicodeDecodeError) as e:
            print(f"osh: {script_file}: {e}")
            sys.exit(1)

        status = 0
        for line in script_lines:
            line = line.strip()
            if line and not line.startswith('#'):
                status = execute(line)
        if status < 0:
            status = 128 - status  # killed by signal -status: report 128+N like sh
        sys.exit(status & 0xFF)
    else:
        main()
