Administrator
发布于 2025-02-01 / 7 阅读 / 0 评论 / 0 点赞

模型微调 mlx

import tkinter as tk
from tkinter import ttk, filedialog, scrolledtext, messagebox
import subprocess
import threading
import os
import json
from datetime import datetime
import queue
import sys
from pathlib import Path
import shutil
import time


class MLXFineTuningGUI:
    def __init__(self, root):
        """Build the fine-tuning window and initialise the run state.

        Args:
            root: The Tk root window that hosts every widget created here.
        """
        self.root = root
        self.root.title("MLX Model Fine-tuning Tool")
        self.root.geometry("800x900")

        # Shows a message box and exits when the environment is unsupported.
        self.check_environment()

        # (message, tag) tuples are queued here by log_message and drained
        # into the log widget by check_log_queue.
        self.log_queue = queue.Queue()

        # Main frame fills the whole root window.
        self.main_frame = ttk.Frame(root, padding="10")
        self.main_frame.grid(row=0, column=0, sticky=(tk.W, tk.E, tk.N, tk.S))

        root.grid_rowconfigure(0, weight=1)
        root.grid_columnconfigure(0, weight=1)

        # The entry column absorbs extra width; the log row (row 6) absorbs
        # extra height so the log grows when the window is enlarged.
        self.main_frame.grid_columnconfigure(1, weight=1)
        self.main_frame.grid_rowconfigure(6, weight=1)

        # Initial values shown in the form.
        self.default_values = {
            'batch_size': '2',
            'num_layers': '8',
            'iterations': '1000',
            'fine_tune_type': 'lora',
            'val_batches': '4',
            'learning_rate': '2e-4',
            'steps_per_report': '5',
            'steps_per_eval': '25',
            'save_every': '50',
            'test_batches': '4',
            'max_seq_length': '256',
            'seed': '42',
            'grad_accum_steps': '4',
            'use_float16': True,
            'grad_checkpoint': True
        }

        # Model and data directory pickers
        self.model_path = self._add_path_row(0, "Model Directory*:", self.browse_model)
        self.data_path = self._add_path_row(1, "Data Directory*:", self.browse_data)

        # Training parameters frame
        params_frame = ttk.LabelFrame(self.main_frame, text="Training Parameters", padding="5")
        params_frame.grid(row=2, column=0, columnspan=3, sticky=(tk.W, tk.E), pady=10)

        # Basic parameters
        row = 0
        self.batch_size = self._add_param_entry(params_frame, "Batch Size*:", 'batch_size', row, 0)
        self.num_layers = self._add_param_entry(params_frame, "Num Layers*:", 'num_layers', row, 2)
        self.iterations = self._add_param_entry(params_frame, "Iterations*:", 'iterations', row, 4)

        # Fine-tuning type: read-only so only the listed values can be sent
        # to --fine-tune-type.
        row += 1
        ttk.Label(params_frame, text="Fine-tune Type:").grid(row=row, column=0, sticky=tk.W)
        self.fine_tune_type = tk.StringVar(value=self.default_values['fine_tune_type'])
        fine_tune_combo = ttk.Combobox(params_frame, textvariable=self.fine_tune_type,
                                       width=10, state="readonly")
        fine_tune_combo['values'] = ('lora', 'dora', 'full')
        fine_tune_combo.grid(row=row, column=1, padx=5)

        self.val_batches = self._add_param_entry(params_frame, "Val Batches:", 'val_batches', row, 2)
        self.learning_rate = self._add_param_entry(params_frame, "Learning Rate:", 'learning_rate', row, 4)

        # Reporting / checkpoint cadence
        row += 1
        self.steps_per_report = self._add_param_entry(
            params_frame, "Steps per Report:", 'steps_per_report', row, 0)
        self.steps_per_eval = self._add_param_entry(
            params_frame, "Steps per Eval:", 'steps_per_eval', row, 2)
        self.save_every = self._add_param_entry(params_frame, "Save Every:", 'save_every', row, 4)

        row += 1
        self.test_batches = self._add_param_entry(params_frame, "Test Batches:", 'test_batches', row, 0)
        self.max_seq_length = self._add_param_entry(
            params_frame, "Max Seq Length:", 'max_seq_length', row, 2)
        self.seed = self._add_param_entry(params_frame, "Seed:", 'seed', row, 4)

        # Memory optimization options
        # NOTE(review): use_float16 and grad_accum_steps are stored here, but
        # the training command assembled in run_finetuning does not forward
        # them - confirm whether they should be wired up.
        row += 1
        ttk.Label(params_frame, text="Memory Optimization", font=('', 10, 'bold')).grid(
            row=row, column=0, sticky=tk.W, pady=(10, 5))

        row += 1
        self.use_float16 = tk.BooleanVar(value=self.default_values['use_float16'])
        ttk.Checkbutton(params_frame, text="Use Float16", variable=self.use_float16).grid(
            row=row, column=0, columnspan=2, sticky=tk.W)
        self.grad_accum_steps = self._add_param_entry(
            params_frame, "Gradient Accum Steps:", 'grad_accum_steps', row, 2)

        # Advanced options
        row += 1
        ttk.Label(params_frame, text="Advanced Options:", font=('', 10, 'bold')).grid(
            row=row, column=0, sticky=tk.W, pady=(10, 5))

        row += 1
        self.grad_checkpoint = tk.BooleanVar(value=self.default_values['grad_checkpoint'])
        ttk.Checkbutton(params_frame, text="Use Gradient Checkpoint", variable=self.grad_checkpoint).grid(
            row=row, column=0, columnspan=2, sticky=tk.W)

        # Output directory picker
        self.output_path = self._add_path_row(4, "Output Directory*:", self.browse_output)

        # Progress bar occupies columns 0-1 so the percentage label in
        # column 2 sits beside it instead of being drawn on top of it.
        self.progress = ttk.Progressbar(self.main_frame, length=300, mode='determinate')
        self.progress.grid(row=5, column=0, columnspan=2, pady=10, sticky=(tk.W, tk.E))

        self.progress_label = ttk.Label(self.main_frame, text="0%")
        self.progress_label.grid(row=5, column=2, pady=10)

        # Log area
        log_frame = ttk.LabelFrame(self.main_frame, text="Training Log", padding="5")
        log_frame.grid(row=6, column=0, columnspan=3, sticky=(tk.W, tk.E, tk.N, tk.S), pady=5)
        log_frame.grid_rowconfigure(0, weight=1)
        log_frame.grid_columnconfigure(0, weight=1)

        self.log_area = scrolledtext.ScrolledText(log_frame, height=15, width=80)
        self.log_area.grid(row=0, column=0, sticky=(tk.W, tk.E, tk.N, tk.S))

        # One text tag per message category used by log_message callers.
        self.log_area.tag_config('error', foreground='red')
        self.log_area.tag_config('command', foreground='blue')
        self.log_area.tag_config('success', foreground='green')
        self.log_area.tag_config('info', foreground='black')

        # Control buttons
        button_frame = ttk.Frame(self.main_frame)
        button_frame.grid(row=7, column=0, columnspan=3, pady=10)

        self.start_button = ttk.Button(button_frame, text="Start Fine-tuning", command=self.start_finetuning)
        self.start_button.grid(row=0, column=0, padx=5)

        self.stop_button = ttk.Button(button_frame, text="Stop", command=self.stop_finetuning, state=tk.DISABLED)
        self.stop_button.grid(row=0, column=1, padx=5)

        self.clear_button = ttk.Button(button_frame, text="Clear Log", command=self.clear_log)
        self.clear_button.grid(row=0, column=2, padx=5)

        # Run state
        self.process = None
        self.training_thread = None
        self.is_running = False

        # Start polling the log queue from the Tk event loop.
        self.root.after(100, self.check_log_queue)

    def _add_path_row(self, row, label, browse_command):
        """Add a label / entry / Browse button row to the main frame.

        Returns the ``tk.StringVar`` bound to the entry.
        """
        ttk.Label(self.main_frame, text=label).grid(row=row, column=0, sticky=tk.W, pady=5)
        path_var = tk.StringVar()
        ttk.Entry(self.main_frame, textvariable=path_var, width=50).grid(
            row=row, column=1, padx=5, sticky=(tk.W, tk.E))
        ttk.Button(self.main_frame, text="Browse", command=browse_command).grid(row=row, column=2)
        return path_var

    def _add_param_entry(self, parent, label, key, row, column):
        """Add a label / entry pair to the parameter grid.

        The label goes in ``column`` and the entry in ``column + 1``; the
        entry starts with ``self.default_values[key]``. Returns the bound
        ``tk.StringVar``.
        """
        # Labels after the first column get extra spacing from the previous entry.
        label_padx = 0 if column == 0 else 10
        ttk.Label(parent, text=label).grid(row=row, column=column, sticky=tk.W, padx=label_padx)
        value_var = tk.StringVar(value=self.default_values[key])
        ttk.Entry(parent, textvariable=value_var, width=10).grid(row=row, column=column + 1, padx=5)
        return value_var

    def ensure_directory(self, path):
        """Ensure directory exists, create if it doesn't"""
        Path(path).mkdir(parents=True, exist_ok=True)
        return path

    def get_adapter_path(self):
        """Get the path for adapter files"""
        return os.path.join(self.output_path.get(), "adapters")

    def get_finetuned_model_path(self):
        """Get the path for the finetuned model"""
        return os.path.join(self.output_path.get(), "finetuned_model")

    def run_finetuning(self):
        """Run training, then the optional test phase and adapter fusion.

        Used as the target of the thread created in ``start_finetuning``.
        Launches ``python -m mlx_lm.lora --train`` as a subprocess, relays its
        stdout/stderr to the log, and updates the progress bar from lines
        containing ``Iter <n>:``. On a zero exit code it runs the test phase
        (when enabled) and then the fusion phase.

        All errors are reported through the log; nothing is raised to the
        caller. The start/stop buttons and ``is_running`` are always reset on
        exit, and a child process that is still alive is terminated.
        """
        succeeded = False
        try:
            # Validate the iteration count before launching anything, so bad
            # input cannot leave an orphaned training process behind.
            raw_iters = self.iterations.get().strip()
            try:
                total_iters = int(raw_iters)
            except ValueError:
                self.log_message(f"Error: Iterations must be a whole number, got {raw_iters!r}", 'error')
                return
            if total_iters <= 0:
                self.log_message("Error: Iterations must be greater than zero.", 'error')
                return

            # Create necessary directories
            adapter_path = self.ensure_directory(self.get_adapter_path())
            self.ensure_directory(self.get_finetuned_model_path())

            train_command = self._build_train_command(adapter_path)
            self.log_message(f"Executing command: {' '.join(train_command)}", 'command')

            # Child inherits our environment plus the current import path.
            env = os.environ.copy()
            env["PYTHONPATH"] = os.pathsep.join(sys.path)
            # Ask the child interpreter not to buffer its output so log lines
            # arrive as they are produced rather than in large blocks.
            env.setdefault("PYTHONUNBUFFERED", "1")

            self.process = subprocess.Popen(
                train_command,
                stdout=subprocess.PIPE,
                stderr=subprocess.PIPE,
                text=True,
                bufsize=1,
                env=env
            )

            # One reader thread per pipe, so a full stderr pipe can never
            # block the child while we are waiting on stdout (or vice versa).
            output_queue = queue.Queue()
            error_queue = queue.Queue()
            stdout_thread = threading.Thread(target=self.read_output,
                                             args=(self.process.stdout, output_queue),
                                             daemon=True)
            stderr_thread = threading.Thread(target=self.read_output,
                                             args=(self.process.stderr, error_queue),
                                             daemon=True)
            stdout_thread.start()
            stderr_thread.start()

            # Relay output until the child exits or the user presses Stop.
            while self.process.poll() is None and self.is_running:
                self._drain_training_output(output_queue, error_queue, total_iters)
                time.sleep(0.1)

            # Wait for output threads to finish
            stdout_thread.join(timeout=5)
            stderr_thread.join(timeout=5)

            # Lines queued between the last drain and process exit (typically
            # the final summary or the traceback of a crash) must not be lost.
            self._drain_training_output(output_queue, error_queue, total_iters)

            if not self.is_running:
                self.log_message("Training process was terminated by user.", 'info')
                return

            if self.process.returncode != 0:
                self.log_message(f"Training process terminated abnormally with code: {self.process.returncode}", 'error')
                return

            succeeded = True

            # Add evaluation phase
            if self.should_run_test():
                self.run_test_phase()

            # Add model fusion phase
            self.run_fusion_phase()

        except Exception as e:
            self.log_message(f"Error during fine-tuning: {str(e)}", 'error')
            import traceback
            self.log_message(f"Error details: {traceback.format_exc()}", 'error')

        finally:
            # Reap a child that is still alive (user stop or unexpected
            # error) so it does not keep running after this method returns.
            child = self.process
            if child is not None and child.poll() is None:
                child.terminate()
                try:
                    child.wait(timeout=5)
                except subprocess.TimeoutExpired:
                    child.kill()

            self.is_running = False
            self.start_button.config(state=tk.NORMAL)
            self.stop_button.config(state=tk.DISABLED)
            # Only a successful training run is shown as complete; a failed
            # or stopped run keeps the last progress value that was reported.
            if succeeded:
                self.progress['value'] = 100
                self.progress_label['text'] = "100%"

    def _build_train_command(self, adapter_path):
        """Assemble the ``mlx_lm.lora --train`` argument list from the form values."""
        command = [
            sys.executable,
            "-m",
            "mlx_lm.lora",
            "--train",
            "--model", self.model_path.get(),
            "--batch-size", self.batch_size.get(),
            "--num-layers", self.num_layers.get(),
            "--iters", self.iterations.get(),
            "--data", self.data_path.get(),
            "--fine-tune-type", self.fine_tune_type.get(),
            "--adapter-path", adapter_path
        ]

        optional_flags = (
            ("--val-batches", self.val_batches),
            ("--learning-rate", self.learning_rate),
            ("--steps-per-report", self.steps_per_report),
            ("--steps-per-eval", self.steps_per_eval),
            ("--save-every", self.save_every),
            ("--test-batches", self.test_batches),
            ("--max-seq-length", self.max_seq_length),
            ("--seed", self.seed),
        )
        for flag, variable in optional_flags:
            value = variable.get().strip()
            # "0" means "leave the flag off"; a blank field is treated the
            # same way instead of being passed as an empty argument.
            if value and value != "0":
                command.extend([flag, value])

        if self.grad_checkpoint.get():
            command.append("--grad-checkpoint")
        return command

    def _drain_training_output(self, output_queue, error_queue, total_iters):
        """Log everything currently queued from the child and update progress."""
        while True:
            try:
                line = output_queue.get_nowait()
            except queue.Empty:
                break
            if not line:
                continue
            line = line.strip()
            self.log_message(line, 'info')
            if "Iter" in line:
                try:
                    current_iter = int(line.split("Iter")[1].split(":")[0].strip())
                except ValueError:
                    # "Iter" appeared in a line that is not a progress report.
                    continue
                self.update_progress(current_iter, total_iters)

        while True:
            try:
                line = error_queue.get_nowait()
            except queue.Empty:
                break
            if line:
                self.log_message(f"Error: {line.strip()}", 'error')

    def should_run_test(self):
        """Determine if test phase should be run"""
        return self.test_batches.get() != "0"

    def run_test_phase(self):
        """Run test phase"""
        self.log_message("\nStarting model evaluation...", 'info')

        adapter_path = self.get_adapter_path()
        if not os.path.exists(adapter_path):
            self.log_message(f"Error: Adapter path not found: {adapter_path}", 'error')
            return

        test_command = [
            sys.executable,
            "-m",
            "mlx_lm.lora",
            "--model", self.model_path.get(),
            "--data", self.data_path.get(),
            "--adapter-path", adapter_path,
            "--test"
        ]

        # Add optional test parameters
        if self.test_batches.get() != "0":
            test_command.extend(["--test-batches", self.test_batches.get()])
        if self.max_seq_length.get() != "0":
            test_command.extend(["--max-seq-length", self.max_seq_length.get()])

        self.log_message(f"Executing evaluation command: {' '.join(test_command)}", 'command')

        # Execute evaluation command
        test_process = subprocess.Popen(
            test_command,
            stdout=subprocess.PIPE,
            stderr=subprocess.PIPE,
            text=True,
            env=os.environ.copy()
        )

        # Handle evaluation output
        stdout, stderr = test_process.communicate()

        if stdout:
            self.log_message(stdout.strip(), 'info')
        if stderr:
            self.log_message(f"Evaluation error: {stderr.strip()}", 'error')

        if test_process.returncode == 0:
            self.log_message("Model evaluation completed!", 'success')
        else:
            self.log_message("Model evaluation failed!", 'error')

    def run_fusion_phase(self):
        """Run model fusion phase"""
        self.log_message("\nStarting model fusion...", 'info')

        adapter_path = self.get_adapter_path()
        finetuned_model_path = self.get_finetuned_model_path()

        if not os.path.exists(adapter_path):
            self.log_message(f"Error: Adapter path not found: {adapter_path}", 'error')
            return

        fuse_command = [
            sys.executable,
            "-m",
            "mlx_lm.fuse",
            "--model", self.model_path.get(),
            "--adapter-path", adapter_path,
            "--save-path", finetuned_model_path
        ]

        self.log_message(f"Executing command: {' '.join(fuse_command)}", 'command')

        # Execute fusion command
        fuse_process = subprocess.Popen(
            fuse_command,
            stdout=subprocess.PIPE,
            stderr=subprocess.PIPE,
            text=True,
            env=os.environ.copy()
        )

        # Handle fusion output
        stdout, stderr = fuse_process.communicate()

        if stdout:
            self.log_message(stdout.strip(), 'info')
        if stderr:
            self.log_message(f"Fusion error: {stderr.strip()}", 'error')

        if fuse_process.returncode == 0:
            self.log_message("Model fusion completed successfully!", 'success')
            self.log_message(f"Fine-tuned model saved to: {finetuned_model_path}", 'success')
        else:
            self.log_message("Model fusion failed!", 'error')

    def read_output(self, pipe, queue):
        """Helper function to read output"""
        try:
            while True:
                line = pipe.readline()
                if not line and self.process.poll() is not None:
                    break
                if line:
                    queue.put(line)
        except Exception as e:
            queue.put(f"Error reading output: {str(e)}")
        finally:
            pipe.close()

    def check_environment(self):
        """Exit with an error dialog unless MLX is importable on an Apple Silicon Mac."""
        # Both packages must import; the imported names themselves are unused.
        try:
            import mlx
            import mlx_lm
        except ImportError:
            messagebox.showerror("Error", "Please install MLX and MLX-LM first:\n pip install mlx mlx-lm")
            sys.exit(1)

        # os.uname() is only evaluated once the platform is known to be macOS.
        on_apple_silicon = sys.platform == "darwin" and "arm" in os.uname().machine
        if on_apple_silicon:
            return

        messagebox.showerror("Error", "This program can only run on Apple Silicon Mac")
        sys.exit(1)

    def browse_model(self):
        """Ask for the model directory; keep the current value if the dialog is cancelled."""
        chosen = filedialog.askdirectory(title="Select Model Directory")
        if not chosen:
            return
        self.model_path.set(chosen)

    def browse_data(self):
        """Ask for the data directory; keep the current value if the dialog is cancelled."""
        chosen = filedialog.askdirectory(title="Select Data Directory")
        if not chosen:
            return
        self.data_path.set(chosen)

    def browse_output(self):
        """Ask for the output directory; keep the current value if the dialog is cancelled."""
        chosen = filedialog.askdirectory(title="Select Output Directory")
        if not chosen:
            return
        self.output_path.set(chosen)

    def clear_log(self):
        """Remove all text from the log widget."""
        first_index = 1.0
        self.log_area.delete(first_index, tk.END)

    def check_log_queue(self):
        """Move all queued log entries into the log widget, then reschedule itself in 100 ms."""
        pending = self.log_queue
        area = self.log_area
        try:
            # Runs until get_nowait() reports the queue is empty.
            while True:
                text, tag_name = pending.get_nowait()
                area.insert(tk.END, text + '\n', tag_name)
                area.see(tk.END)
                area.update_idletasks()
        except queue.Empty:
            pass
        self.root.after(100, self.check_log_queue)

    def log_message(self, message, tag='info'):
        """Add tagged log message"""
        self.log_queue.put((message, tag))

    def verify_data_files(self, data_path):
        """Verify existence and requirements of data files"""
        required_files = ['train.jsonl', 'valid.jsonl', 'test.jsonl']
        missing_files = []
        for file in required_files:
            if not os.path.exists(os.path.join(data_path, file)):
                missing_files.append(file)
        return missing_files

    def start_finetuning(self):
        """Validate the form and start the fine-tuning worker thread.

        Every validation failure is logged and shown in an error dialog, and
        the method returns without starting anything. Checks, in order: the
        three directories are set, the required numeric fields are valid,
        the dataset files exist, and the output directories can be created.
        """
        if not all([self.model_path.get(), self.data_path.get(), self.output_path.get()]):
            self.log_message("Error: Please select all required directories.", 'error')
            messagebox.showerror("Error", "Please select all required directories")
            return

        # The fields marked with "*" in the form must hold whole numbers;
        # rejecting bad input here avoids launching a doomed subprocess.
        required_numbers = (
            ("Batch Size", self.batch_size),
            ("Num Layers", self.num_layers),
            ("Iterations", self.iterations),
        )
        parsed = {}
        for label, variable in required_numbers:
            raw_value = variable.get().strip()
            try:
                parsed[label] = int(raw_value)
            except ValueError:
                problem = f"{label} must be a whole number (got {raw_value!r})"
                self.log_message(f"Error: {problem}", 'error')
                messagebox.showerror("Error", problem)
                return

        # Num Layers is deliberately not range-checked here.
        # NOTE(review): confirm which values the training tool accepts for it.
        for label in ("Batch Size", "Iterations"):
            if parsed[label] <= 0:
                problem = f"{label} must be greater than zero"
                self.log_message(f"Error: {problem}", 'error')
                messagebox.showerror("Error", problem)
                return

        # Verify data files
        missing_files = self.verify_data_files(self.data_path.get())
        if missing_files:
            self.log_message(f"Error: Missing required files in data directory: {', '.join(missing_files)}", 'error')
            messagebox.showerror("Error", f"Missing required files in data directory:\n{', '.join(missing_files)}")
            return

        try:
            # Create output directories
            self.ensure_directory(self.get_adapter_path())
            self.ensure_directory(self.get_finetuned_model_path())
        except Exception as e:
            self.log_message(f"Error creating directories: {str(e)}", 'error')
            messagebox.showerror("Error", f"Failed to create output directories: {str(e)}")
            return

        self.is_running = True
        self.start_button.config(state=tk.DISABLED)
        self.stop_button.config(state=tk.NORMAL)
        self.progress['value'] = 0
        self.progress_label['text'] = "0%"

        # Training runs in a separate thread so the Tk event loop stays responsive.
        self.training_thread = threading.Thread(target=self.run_finetuning)
        self.training_thread.start()

    def stop_finetuning(self):
        """Stop fine-tuning process"""
        if self.process:
            self.is_running = False
            self.process.terminate()
            self.log_message("Stopping fine-tuning process...", 'info')

    def update_progress(self, current_iter, total_iters):
        """Update progress bar and label"""
        progress = (current_iter / total_iters) * 100
        self.progress['value'] = progress
        self.progress_label['text'] = f"{progress:.1f}%"
        self.root.update_idletasks()


def main():
    """Create the root window, build the GUI and run the Tk event loop."""
    # Use a themed root window when the optional ttkthemes package is
    # available; otherwise fall back to a plain Tk root.
    try:
        from ttkthemes import ThemedTk
        root = ThemedTk(theme="arc")
    except ImportError:
        root = tk.Tk()

    # On macOS, bind the tk::mac::Quit command to root.quit.
    running_on_macos = sys.platform == "darwin"
    if running_on_macos:
        root.createcommand('tk::mac::Quit', root.quit)

    # The window cannot be shrunk below its initial 800x900 layout.
    root.minsize(800, 900)

    # Build the application inside the root window.
    app = MLXFineTuningGUI(root)

    # Blocks until the main loop exits.
    root.mainloop()


# Launch the GUI only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()

评论