import tkinter as tk
from tkinter import ttk, filedialog, scrolledtext, messagebox
import subprocess
import threading
import os
import queue
import sys
from pathlib import Path
import time
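
# MLX Model Fine-tuning Tool
# --------------------------
# A Tkinter GUI that wraps the `mlx_lm.lora` and `mlx_lm.fuse` command-line
# entry points to fine-tune a local MLX model (LoRA, DoRA, or full fine-tuning),
# optionally evaluate it on a test split, and fuse the trained adapters into a
# standalone model. Requires an Apple Silicon Mac with the `mlx` and `mlx-lm`
# packages installed (see check_environment below).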

class MLXFineTuningGUI:
    def __init__(self, root):
        self.root = root
        self.root.title("MLX Model Fine-tuning Tool")
        self.root.geometry("800x900")

        # Check environment
        self.check_environment()

        # Create log queue
        self.log_queue = queue.Queue()

        # Create main frame
        self.main_frame = ttk.Frame(root, padding="10")
        self.main_frame.grid(row=0, column=0, sticky=(tk.W, tk.E, tk.N, tk.S))

        # Configure root window grid
        root.grid_rowconfigure(0, weight=1)
        root.grid_columnconfigure(0, weight=1)

        # Configure main frame grid
        self.main_frame.grid_columnconfigure(1, weight=1)

        # Default values dictionary
        self.default_values = {
            'batch_size': '2',
            'num_layers': '8',
            'iterations': '1000',
            'fine_tune_type': 'lora',
            'val_batches': '4',
            'learning_rate': '2e-4',
            'steps_per_report': '5',
            'steps_per_eval': '25',
            'save_every': '50',
            'test_batches': '4',
            'max_seq_length': '256',
            'seed': '42',
            'grad_accum_steps': '4',
            'use_float16': True,
            'grad_checkpoint': True
        }
# Model selection
ttk.Label(self.main_frame, text="Model Directory*:").grid(row=0, column=0, sticky=tk.W, pady=5)
self.model_path = tk.StringVar()
ttk.Entry(self.main_frame, textvariable=self.model_path, width=50).grid(row=0, column=1, padx=5, sticky=(tk.W, tk.E))
ttk.Button(self.main_frame, text="Browse", command=self.browse_model).grid(row=0, column=2)
# Data selection
ttk.Label(self.main_frame, text="Data Directory*:").grid(row=1, column=0, sticky=tk.W, pady=5)
self.data_path = tk.StringVar()
ttk.Entry(self.main_frame, textvariable=self.data_path, width=50).grid(row=1, column=1, padx=5, sticky=(tk.W, tk.E))
ttk.Button(self.main_frame, text="Browse", command=self.browse_data).grid(row=1, column=2)
# Training parameters frame
params_frame = ttk.LabelFrame(self.main_frame, text="Training Parameters", padding="5")
params_frame.grid(row=2, column=0, columnspan=3, sticky=(tk.W, tk.E), pady=10)
# Create parameter grid
row = 0
# Basic parameters
ttk.Label(params_frame, text="Batch Size*:").grid(row=row, column=0, sticky=tk.W)
self.batch_size = tk.StringVar(value=self.default_values['batch_size'])
ttk.Entry(params_frame, textvariable=self.batch_size, width=10).grid(row=row, column=1, padx=5)
ttk.Label(params_frame, text="Num Layers*:").grid(row=row, column=2, sticky=tk.W, padx=10)
self.num_layers = tk.StringVar(value=self.default_values['num_layers'])
ttk.Entry(params_frame, textvariable=self.num_layers, width=10).grid(row=row, column=3, padx=5)
ttk.Label(params_frame, text="Iterations*:").grid(row=row, column=4, sticky=tk.W, padx=10)
self.iterations = tk.StringVar(value=self.default_values['iterations'])
ttk.Entry(params_frame, textvariable=self.iterations, width=10).grid(row=row, column=5, padx=5)
# Fine-tuning type
row += 1
ttk.Label(params_frame, text="Fine-tune Type:").grid(row=row, column=0, sticky=tk.W)
self.fine_tune_type = tk.StringVar(value=self.default_values['fine_tune_type'])
fine_tune_combo = ttk.Combobox(params_frame, textvariable=self.fine_tune_type, width=10)
fine_tune_combo['values'] = ('lora', 'dora', 'full')
fine_tune_combo.grid(row=row, column=1, padx=5)
ttk.Label(params_frame, text="Val Batches:").grid(row=row, column=2, sticky=tk.W, padx=10)
self.val_batches = tk.StringVar(value=self.default_values['val_batches'])
ttk.Entry(params_frame, textvariable=self.val_batches, width=10).grid(row=row, column=3, padx=5)
ttk.Label(params_frame, text="Learning Rate:").grid(row=row, column=4, sticky=tk.W, padx=10)
self.learning_rate = tk.StringVar(value=self.default_values['learning_rate'])
ttk.Entry(params_frame, textvariable=self.learning_rate, width=10).grid(row=row, column=5, padx=5)
# Additional parameters
row += 1
ttk.Label(params_frame, text="Steps per Report:").grid(row=row, column=0, sticky=tk.W)
self.steps_per_report = tk.StringVar(value=self.default_values['steps_per_report'])
ttk.Entry(params_frame, textvariable=self.steps_per_report, width=10).grid(row=row, column=1, padx=5)
ttk.Label(params_frame, text="Steps per Eval:").grid(row=row, column=2, sticky=tk.W, padx=10)
self.steps_per_eval = tk.StringVar(value=self.default_values['steps_per_eval'])
ttk.Entry(params_frame, textvariable=self.steps_per_eval, width=10).grid(row=row, column=3, padx=5)
ttk.Label(params_frame, text="Save Every:").grid(row=row, column=4, sticky=tk.W, padx=10)
self.save_every = tk.StringVar(value=self.default_values['save_every'])
ttk.Entry(params_frame, textvariable=self.save_every, width=10).grid(row=row, column=5, padx=5)
row += 1
ttk.Label(params_frame, text="Test Batches:").grid(row=row, column=0, sticky=tk.W)
self.test_batches = tk.StringVar(value=self.default_values['test_batches'])
ttk.Entry(params_frame, textvariable=self.test_batches, width=10).grid(row=row, column=1, padx=5)
ttk.Label(params_frame, text="Max Seq Length:").grid(row=row, column=2, sticky=tk.W, padx=10)
self.max_seq_length = tk.StringVar(value=self.default_values['max_seq_length'])
ttk.Entry(params_frame, textvariable=self.max_seq_length, width=10).grid(row=row, column=3, padx=5)
ttk.Label(params_frame, text="Seed:").grid(row=row, column=4, sticky=tk.W, padx=10)
self.seed = tk.StringVar(value=self.default_values['seed'])
ttk.Entry(params_frame, textvariable=self.seed, width=10).grid(row=row, column=5, padx=5)
# Memory optimization options
row += 1
ttk.Label(params_frame, text="Memory Optimization", font=('', 10, 'bold')).grid(row=row, column=0, sticky=tk.W, pady=(10, 5))
row += 1
self.use_float16 = tk.BooleanVar(value=self.default_values['use_float16'])
ttk.Checkbutton(params_frame, text="Use Float16", variable=self.use_float16).grid(
row=row, column=0, columnspan=2, sticky=tk.W)
ttk.Label(params_frame, text="Gradient Accum Steps:").grid(row=row, column=2, sticky=tk.W, padx=10)
self.grad_accum_steps = tk.StringVar(value=self.default_values['grad_accum_steps'])
ttk.Entry(params_frame, textvariable=self.grad_accum_steps, width=10).grid(row=row, column=3, padx=5)
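        # Note: the Float16 and gradient-accumulation settings are collected here
        # but are not currently forwarded to the mlx_lm.lora command built in
        # run_finetuning; only --grad-checkpoint is passed through.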

        # Advanced options
        row += 1
        ttk.Label(params_frame, text="Advanced Options:", font=('', 10, 'bold')).grid(row=row, column=0, sticky=tk.W, pady=(10, 5))
        row += 1
        self.grad_checkpoint = tk.BooleanVar(value=self.default_values['grad_checkpoint'])
        ttk.Checkbutton(params_frame, text="Use Gradient Checkpoint", variable=self.grad_checkpoint).grid(
            row=row, column=0, columnspan=2, sticky=tk.W)

        # Output directory
        ttk.Label(self.main_frame, text="Output Directory*:").grid(row=4, column=0, sticky=tk.W, pady=5)
        self.output_path = tk.StringVar()
        ttk.Entry(self.main_frame, textvariable=self.output_path, width=50).grid(row=4, column=1, padx=5, sticky=(tk.W, tk.E))
        ttk.Button(self.main_frame, text="Browse", command=self.browse_output).grid(row=4, column=2)
# Progress bar
self.progress = ttk.Progressbar(self.main_frame, length=300, mode='determinate')
self.progress.grid(row=5, column=0, columnspan=3, pady=10, sticky=(tk.W, tk.E))
# Progress label
self.progress_label = ttk.Label(self.main_frame, text="0%")
self.progress_label.grid(row=5, column=2, pady=10)

        # Log area
        log_frame = ttk.LabelFrame(self.main_frame, text="Training Log", padding="5")
        log_frame.grid(row=6, column=0, columnspan=3, sticky=(tk.W, tk.E, tk.N, tk.S), pady=5)

        # Configure log frame grid
        log_frame.grid_rowconfigure(0, weight=1)
        log_frame.grid_columnconfigure(0, weight=1)

        self.log_area = scrolledtext.ScrolledText(log_frame, height=15, width=80)
        self.log_area.grid(row=0, column=0, sticky=(tk.W, tk.E, tk.N, tk.S))

        # Configure log tags
        self.log_area.tag_config('error', foreground='red')
        self.log_area.tag_config('command', foreground='blue')
        self.log_area.tag_config('success', foreground='green')
        self.log_area.tag_config('info', foreground='black')

        # Control buttons
        button_frame = ttk.Frame(self.main_frame)
        button_frame.grid(row=7, column=0, columnspan=3, pady=10)
        self.start_button = ttk.Button(button_frame, text="Start Fine-tuning", command=self.start_finetuning)
        self.start_button.grid(row=0, column=0, padx=5)
        self.stop_button = ttk.Button(button_frame, text="Stop", command=self.stop_finetuning, state=tk.DISABLED)
        self.stop_button.grid(row=0, column=1, padx=5)
        self.clear_button = ttk.Button(button_frame, text="Clear Log", command=self.clear_log)
        self.clear_button.grid(row=0, column=2, padx=5)

        # Initialize process variables
        self.process = None
        self.training_thread = None
        self.is_running = False

        # Start log observer
        self.root.after(100, self.check_log_queue)

    def ensure_directory(self, path):
        """Ensure directory exists, create if it doesn't"""
        Path(path).mkdir(parents=True, exist_ok=True)
        return path

    def get_adapter_path(self):
        """Get the path for adapter files"""
        return os.path.join(self.output_path.get(), "adapters")

    def get_finetuned_model_path(self):
        """Get the path for the finetuned model"""
        return os.path.join(self.output_path.get(), "finetuned_model")

    def run_finetuning(self):
        """Execute fine-tuning process"""
        try:
            # Create necessary directories
            adapter_path = self.ensure_directory(self.get_adapter_path())
            finetuned_model_path = self.ensure_directory(self.get_finetuned_model_path())

            # Build training command
            python_executable = sys.executable
            train_command = [
                python_executable,
                "-m",
                "mlx_lm.lora",
                "--train",
                "--model", self.model_path.get(),
                "--batch-size", self.batch_size.get(),
                "--num-layers", self.num_layers.get(),
                "--iters", self.iterations.get(),
                "--data", self.data_path.get(),
                "--fine-tune-type", self.fine_tune_type.get(),
                "--adapter-path", adapter_path
            ]
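            # The command is built as an argument list (no shell), invoking the
            # mlx_lm.lora CLI with the current interpreter; the flags above map
            # directly onto that CLI's training options.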

            # Add optional parameters if they have non-default values
            if self.val_batches.get() != "0":
                train_command.extend(["--val-batches", self.val_batches.get()])
            if self.learning_rate.get() != "0":
                train_command.extend(["--learning-rate", self.learning_rate.get()])
            if self.steps_per_report.get() != "0":
                train_command.extend(["--steps-per-report", self.steps_per_report.get()])
            if self.steps_per_eval.get() != "0":
                train_command.extend(["--steps-per-eval", self.steps_per_eval.get()])
            if self.save_every.get() != "0":
                train_command.extend(["--save-every", self.save_every.get()])
            if self.test_batches.get() != "0":
                train_command.extend(["--test-batches", self.test_batches.get()])
            if self.max_seq_length.get() != "0":
                train_command.extend(["--max-seq-length", self.max_seq_length.get()])
            if self.seed.get() != "0":
                train_command.extend(["--seed", self.seed.get()])
            if self.grad_checkpoint.get():
                train_command.append("--grad-checkpoint")

            # Log the command
            self.log_message(f"Executing command: {' '.join(train_command)}", 'command')

            # Set environment variables
            env = os.environ.copy()
            env["PYTHONPATH"] = os.pathsep.join(sys.path)

            # Create process and set pipes
            self.process = subprocess.Popen(
                train_command,
                stdout=subprocess.PIPE,
                stderr=subprocess.PIPE,
                text=True,
                bufsize=1,
                env=env
            )
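            # stdout/stderr are piped and read on separate daemon threads so the
            # blocking readline() calls cannot stall this training thread; lines
            # are handed back through the output/error queues polled below.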

            # Create output queues
            output_queue = queue.Queue()
            error_queue = queue.Queue()

            # Start output reader threads
            stdout_thread = threading.Thread(target=self.read_output,
                                             args=(self.process.stdout, output_queue),
                                             daemon=True)
            stderr_thread = threading.Thread(target=self.read_output,
                                             args=(self.process.stderr, error_queue),
                                             daemon=True)
            stdout_thread.start()
            stderr_thread.start()

            current_iter = 0
            total_iters = int(self.iterations.get())

            # Process output
            while self.process.poll() is None and self.is_running:
                # Handle standard output
                while True:
                    try:
                        output = output_queue.get_nowait()
                        if output:
                            output = output.strip()
                            self.log_message(output, 'info')
                            # Update progress bar from lines such as "Iter 100: ..."
                            if "Iter" in output:
                                try:
                                    iter_str = output.split("Iter")[1].split(":")[0].strip()
                                    current_iter = int(iter_str)
                                    self.update_progress(current_iter, total_iters)
                                except (IndexError, ValueError):
                                    pass
                    except queue.Empty:
                        break

                # Handle error output
                while True:
                    try:
                        error = error_queue.get_nowait()
                        if error:
                            self.log_message(f"Error: {error.strip()}", 'error')
                    except queue.Empty:
                        break

                time.sleep(0.1)

            # Wait for output threads to finish
            stdout_thread.join(timeout=5)
            stderr_thread.join(timeout=5)

            if not self.is_running:
                self.log_message("Training process was terminated by user.", 'info')
                return

            if self.process.returncode != 0:
                self.log_message(f"Training process terminated abnormally with code: {self.process.returncode}", 'error')
                return

            # Add evaluation phase
            if self.should_run_test():
                self.run_test_phase()

            # Add model fusion phase
            self.run_fusion_phase()

        except Exception as e:
            self.log_message(f"Error during fine-tuning: {str(e)}", 'error')
            import traceback
            self.log_message(f"Error details: {traceback.format_exc()}", 'error')
        finally:
            self.is_running = False
            self.start_button.config(state=tk.NORMAL)
            self.stop_button.config(state=tk.DISABLED)
            self.progress['value'] = 100
            self.progress_label['text'] = "100%"

    def should_run_test(self):
        """Determine if test phase should be run"""
        return self.test_batches.get() != "0"

    def run_test_phase(self):
        """Run test phase"""
        self.log_message("\nStarting model evaluation...", 'info')

        adapter_path = self.get_adapter_path()
        if not os.path.exists(adapter_path):
            self.log_message(f"Error: Adapter path not found: {adapter_path}", 'error')
            return

        test_command = [
            sys.executable,
            "-m",
            "mlx_lm.lora",
            "--model", self.model_path.get(),
            "--data", self.data_path.get(),
            "--adapter-path", adapter_path,
            "--test"
        ]
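        # With --test, mlx_lm.lora evaluates the model plus the trained adapters
        # on the test split from the data directory instead of training.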

        # Add optional test parameters
        if self.test_batches.get() != "0":
            test_command.extend(["--test-batches", self.test_batches.get()])
        if self.max_seq_length.get() != "0":
            test_command.extend(["--max-seq-length", self.max_seq_length.get()])

        self.log_message(f"Executing evaluation command: {' '.join(test_command)}", 'command')

        # Execute evaluation command
        test_process = subprocess.Popen(
            test_command,
            stdout=subprocess.PIPE,
            stderr=subprocess.PIPE,
            text=True,
            env=os.environ.copy()
        )

        # Handle evaluation output
        stdout, stderr = test_process.communicate()
        if stdout:
            self.log_message(stdout.strip(), 'info')
        if stderr:
            self.log_message(f"Evaluation error: {stderr.strip()}", 'error')

        if test_process.returncode == 0:
            self.log_message("Model evaluation completed!", 'success')
        else:
            self.log_message("Model evaluation failed!", 'error')

    def run_fusion_phase(self):
        """Run model fusion phase"""
        self.log_message("\nStarting model fusion...", 'info')

        adapter_path = self.get_adapter_path()
        finetuned_model_path = self.get_finetuned_model_path()

        if not os.path.exists(adapter_path):
            self.log_message(f"Error: Adapter path not found: {adapter_path}", 'error')
            return

        fuse_command = [
            sys.executable,
            "-m",
            "mlx_lm.fuse",
            "--model", self.model_path.get(),
            "--adapter-path", adapter_path,
            "--save-path", finetuned_model_path
        ]
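        # mlx_lm.fuse merges the trained adapter weights into the base model and
        # writes a standalone, ready-to-use model to --save-path.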
self.log_message(f"Executing command: {' '.join(fuse_command)}", 'command')
# Execute fusion command
fuse_process = subprocess.Popen(
fuse_command,
stdout=subprocess.PIPE,
stderr=subprocess.PIPE,
text=True,
env=os.environ.copy()
)
# Handle fusion output
stdout, stderr = fuse_process.communicate()
if stdout:
self.log_message(stdout.strip(), 'info')
if stderr:
self.log_message(f"Fusion error: {stderr.strip()}", 'error')
if fuse_process.returncode == 0:
self.log_message("Model fusion completed successfully!", 'success')
self.log_message(f"Fine-tuned model saved to: {finetuned_model_path}", 'success')
else:
self.log_message("Model fusion failed!", 'error')

    def read_output(self, pipe, out_queue):
        """Read lines from a subprocess pipe and forward them to a queue."""
        try:
            while True:
                line = pipe.readline()
                if not line:  # EOF: the pipe was closed
                    break
                out_queue.put(line)
        except Exception as e:
            out_queue.put(f"Error reading output: {str(e)}")
        finally:
            pipe.close()

    def check_environment(self):
        """Check if environment meets requirements"""
        try:
            import mlx
            import mlx_lm
        except ImportError:
            messagebox.showerror("Error", "Please install MLX and MLX-LM first:\n pip install mlx mlx-lm")
            sys.exit(1)

        # Check if running on an Apple Silicon Mac
        if not (sys.platform == "darwin" and "arm" in os.uname().machine):
            messagebox.showerror("Error", "This program requires an Apple Silicon Mac")
            sys.exit(1)

    def browse_model(self):
        directory = filedialog.askdirectory(title="Select Model Directory")
        if directory:
            self.model_path.set(directory)

    def browse_data(self):
        directory = filedialog.askdirectory(title="Select Data Directory")
        if directory:
            self.data_path.set(directory)

    def browse_output(self):
        directory = filedialog.askdirectory(title="Select Output Directory")
        if directory:
            self.output_path.set(directory)

    def clear_log(self):
        """Clear log area"""
        self.log_area.delete("1.0", tk.END)

    def check_log_queue(self):
        """Check log queue and update log area"""
        while True:
            try:
                message, tag = self.log_queue.get_nowait()
                self.log_area.insert(tk.END, message + '\n', tag)
                self.log_area.see(tk.END)
                self.log_area.update_idletasks()
            except queue.Empty:
                break
        self.root.after(100, self.check_log_queue)

    def log_message(self, message, tag='info'):
        """Add tagged log message"""
        self.log_queue.put((message, tag))
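
    # Logging pattern: worker threads never insert into the text widget directly.
    # They call log_message(), which puts (message, tag) tuples on log_queue; the
    # Tk main loop drains the queue every 100 ms in check_log_queue (scheduled
    # via root.after), keeping log-widget updates on the main thread.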

    def verify_data_files(self, data_path):
        """Verify existence and requirements of data files"""
        required_files = ['train.jsonl', 'valid.jsonl', 'test.jsonl']
        missing_files = []
        for file in required_files:
            if not os.path.exists(os.path.join(data_path, file)):
                missing_files.append(file)
        return missing_files
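
    # The data directory is expected to contain train.jsonl, valid.jsonl and
    # test.jsonl. The exact per-line JSON schema (e.g. a {"text": ...} record or
    # a chat-style format) is whatever your installed mlx_lm version accepts;
    # consult its documentation for the supported formats.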

    def start_finetuning(self):
        """Start fine-tuning process"""
        if not all([self.model_path.get(), self.data_path.get(), self.output_path.get()]):
            self.log_message("Error: Please select all required directories.", 'error')
            messagebox.showerror("Error", "Please select all required directories")
            return

        # Verify data files
        missing_files = self.verify_data_files(self.data_path.get())
        if missing_files:
            self.log_message(f"Error: Missing required files in data directory: {', '.join(missing_files)}", 'error')
            messagebox.showerror("Error", f"Missing required files in data directory:\n{', '.join(missing_files)}")
            return

        try:
            # Create output directories
            self.ensure_directory(self.get_adapter_path())
            self.ensure_directory(self.get_finetuned_model_path())
        except Exception as e:
            self.log_message(f"Error creating directories: {str(e)}", 'error')
            messagebox.showerror("Error", f"Failed to create output directories: {str(e)}")
            return

        self.is_running = True
        self.start_button.config(state=tk.DISABLED)
        self.stop_button.config(state=tk.NORMAL)
        self.progress['value'] = 0
        self.progress_label['text'] = "0%"

        # Start training in a separate daemon thread so the GUI stays responsive
        self.training_thread = threading.Thread(target=self.run_finetuning, daemon=True)
        self.training_thread.start()

    def stop_finetuning(self):
        """Stop fine-tuning process"""
        if self.process:
            self.is_running = False
            self.process.terminate()
            self.log_message("Stopping fine-tuning process...", 'info')

    def update_progress(self, current_iter, total_iters):
        """Update progress bar and label"""
        if total_iters <= 0:
            return
        progress = (current_iter / total_iters) * 100
        self.progress['value'] = progress
        self.progress_label['text'] = f"{progress:.1f}%"
        self.root.update_idletasks()


def main():
    """Main function"""
    try:
        # Use a themed window if ttkthemes is available
        from ttkthemes import ThemedTk
        root = ThemedTk(theme="arc")
    except ImportError:
        root = tk.Tk()

    # Register the standard Quit handler on macOS
    if sys.platform == "darwin":
        root.createcommand('tk::mac::Quit', root.quit)

    # Set minimum window size
    root.minsize(800, 900)

    # Create application
    app = MLXFineTuningGUI(root)

    # Start main loop
    root.mainloop()


if __name__ == "__main__":
    main()