import tkinter as tk
import numpy as np
import sounddevice as sd
import threading
from math import log2
import mido
import sys

# Béranger Seguin & Nicolas Guès, janvier 2025
# Merci également à Alexis Thibault
# Version du 7 février 2025

#TODO: sauvegarde/chargement d'un graphe (+ options)
#TODO: support des fichiers scl

################
## PARAMÈTRES ##
################

# Output device (sounddevice's default) and its default output sample rate (Hz).
DEVICE = sd.default.device
SAMPLE_RATE = sd.query_devices(DEVICE, 'output')['default_samplerate']
# Exponent of the amplitude decay with frequency: a partial at ratio h gets
# weight h ** (-AMPLITUDE_LAW). The original note says "1 = pink noise".
# NOTE(review): an amplitude ∝ 1/f spectrum is closer to brown noise (pink noise
# is amplitude ∝ 1/sqrt(f)); confirm the intended meaning.
AMPLITUDE_LAW = 1.4
TUNING = 440  # frequency (Hz) of MIDI note 69; also the base of the displayed frequencies
DISTANCE_EXPONENT = 0  # exponent in edge weight = edge_length ** DISTANCE_EXPONENT (0 = length-independent)
EIGEN_EXPONENT = 0.5  # exponent in the formula frequency = eigenvalue ** EIGEN_EXPONENT

TIME_KEEP_NOTES = 1  # seconds a released note is kept (and rendered) before being dropped

def freq_to_name(frequency):
    """Return the nearest quarter-tone note name for ``frequency`` (Hz).

    Names are measured from A = 440 Hz and carry no octave number; a "+" or
    "-" suffix means a quarter tone above or below the named note.
    """
    names = (
        "A", "A+", "Bb", "B-", "B", "C-",
        "C", "C+", "C#", "D-", "D", "D+",
        "Eb", "E-", "E", "E+", "F", "F+",
        "F#", "G-", "G", "G+", "Ab", "A-",
    )
    quartertones = round(24 * log2(frequency / 440))
    return names[quartertones % len(names)]

def midicode_to_freq(midi):
    """Return the equal-tempered frequency (Hz) of MIDI note number ``midi``.

    MIDI note 69 is mapped to TUNING Hz; each semitone multiplies by 2 ** (1/12).
    """
    semitones_from_reference = midi - 69
    return TUNING * 2 ** (semitones_from_reference / 12)

###################
## MIDI LISTENER ##
###################

def process_midi_message(message):
    """Forward one MIDI message to the global ``audio_player``.

    A note_on with velocity > 0 starts the note with volume velocity / 127.
    A note_off, or a note_on with velocity 0, releases it.
    Every other message type is ignored.
    """
    is_note_on = message.type == 'note_on'
    if is_note_on and message.velocity > 0:
        audio_player.add_note(message.note, message.velocity / 127)
        return
    if message.type == 'note_off' or (is_note_on and message.velocity == 0):
        audio_player.kill_code(message.note)

class MidiListener:
    """Let the user pick a MIDI input port on the console and forward its messages.

    Attributes:
        selected_port: name of the chosen mido input port, or None when no
            MIDI input is available.
    """

    def __init__(self):
        input_ports = mido.get_input_names()
        if not input_ports:
            print("Pas de contrôleur MIDI détecté.")
            self.selected_port = None
            return
        for i, port in enumerate(input_ports):
            print(f"{i}: {port}")
        self.selected_port = input_ports[self._ask_port_index(len(input_ports))]

    @staticmethod
    def _ask_port_index(count):
        """Prompt until the user types an integer in ``range(count)`` and return it."""
        while True:
            answer = input("Numéro du contrôleur MIDI: ")
            try:
                index = int(answer)
            except ValueError:
                continue  # not a number: ask again instead of crashing at startup
            if 0 <= index < count:
                return index

    def listen(self):
        """Block, passing every incoming message to ``process_midi_message``.

        Returns immediately when no port was selected. Without this check,
        ``mido.open_input(None)`` would be called even though no MIDI input
        was found.
        """
        if self.selected_port is None:
            return
        with mido.open_input(self.selected_port) as input_port:
            try:
                for message in input_port:
                    process_midi_message(message)
            except KeyboardInterrupt:
                print("\nStopped MIDI listener.")

#########
## SON ##
#########

# Global sample counter: number of samples already written to the output stream.
# Advanced by DronePlayer.audio_callback; divided by SAMPLE_RATE to get the
# current time in seconds.
running_time = 0

class Note:
    """One sounding note: a set of harmonic ratios played at a MIDI pitch.

    Attributes:
        spectrum: frequency ratios of the partials, relative to the base frequency.
        note_code: MIDI note number, converted to Hz with ``midicode_to_freq``.
        time_born: time (seconds, from ``running_time``) the note started.
        time_dead: time (seconds) the note was released, or None while held.
        volume: gain applied to the envelope (the MIDI handler passes velocity / 127).
    """

    # Envelope parameters, used in ENVELOPE_MODE only.
    ATTACKSPEED = 10   # slope of the linear attack (level per second)
    SUSTAIN = 0.3      # level held while the note is down
    RELEASESPEED = 20  # rate of the exponential release (per second)

    def __init__(self, spectrum, note_code, time_born, volume, time_dead=None):
        self.spectrum = spectrum
        self.note_code = note_code
        self.time_born = time_born
        self.time_dead = time_dead
        self.volume = volume

    def envelope(self, time):
        """Return the gain envelope at ``time`` (seconds; scalar or array).

        DRONE_MODE: constant 1.
        ENVELOPE_MODE: while held, a linear attack clipped to [0, SUSTAIN]; once
        released, an exponential decay starting from the level actually reached
        at release time. The result is scaled by ``volume`` and capped at 1.

        Raises:
            ValueError: if the global ``mode`` is neither of the two modes.
        """
        if mode == DRONE_MODE:
            return 1
        if mode != ENVELOPE_MODE:
            raise ValueError(f"Unknown playback mode: {mode!r}")

        if self.time_dead is None:
            # The lower clip keeps samples slightly before time_born silent
            # instead of negative.
            envelope = np.clip(self.ATTACKSPEED * (time - self.time_born), 0, self.SUSTAIN)
        else:
            # A short tap can be released before the attack reaches SUSTAIN.
            # Start the release from the level reached then, so the release
            # does not jump up to SUSTAIN first (audible click).
            held_for = self.time_dead - self.time_born
            volume_when_dead = min(max(self.ATTACKSPEED * held_for, 0), self.SUSTAIN)
            envelope = volume_when_dead * np.exp(-self.RELEASESPEED * (time - self.time_dead))
        return np.minimum(self.volume * envelope, 1)

    def generate_sound(self, t):
        """Render the note at the sample times ``t`` (seconds).

        Each ratio h in ``spectrum`` gives a sine at base_freq * h with weight
        h ** (-AMPLITUDE_LAW). The mix is normalised by the sum of the weights
        and multiplied by the envelope.

        Returns:
            (sound, kill_me_next): samples with the same shape as ``t``, and
            whether the note has been released for more than TIME_KEEP_NOTES
            seconds and can be discarded.
        """
        base_freq = midicode_to_freq(self.note_code)
        sound = np.zeros_like(t)
        sum_weights = 0
        for harmonic in self.spectrum:
            weight = harmonic ** (-AMPLITUDE_LAW)
            sum_weights += weight
            sound += weight * np.sin(2 * np.pi * base_freq * harmonic * t)

        if sum_weights > 1e-06:
            sound /= sum_weights

        sound *= self.envelope(t)

        kill_me_next = (
            self.time_dead is not None
            and (running_time / SAMPLE_RATE - self.time_dead) > TIME_KEEP_NOTES
        )
        return sound, kill_me_next

class DronePlayer:
    """Mixes the active notes into a sounddevice output stream.

    Attributes:
        playing: set (by update_spectrum) once playback has been requested.
        notes: Note objects rendered by the audio callback.
        spectrum: harmonic ratios given to newly added notes.
        stream: the sd.OutputStream opened by play(), or None.
    """

    def __init__(self):
        self.playing = False
        self.notes = []
        self.spectrum = []
        self.stream = None

    def play(self):
        """Open and start a mono float32 output stream driven by ``audio_callback``.

        Returns as soon as the stream is started. sounddevice/PortAudio then
        invokes the callback from its own thread until the stream is stopped.
        """
        self.stream = sd.OutputStream(
            device=DEVICE,
            channels=1,
            callback=self.audio_callback,
            dtype='float32',
            samplerate=SAMPLE_RATE,
            finished_callback=self.stream_finished,
        )
        self.stream.start()

    def audio_callback(self, outdata, frames, time, status):
        """sounddevice callback: write ``frames`` samples of all active notes.

        Notes that report they are finished are removed after rendering, and
        the global ``running_time`` sample counter advances by ``frames``.
        """
        global running_time
        if status:
            print(status, file=sys.stderr)

        # Sample times in seconds, as a column matching outdata's (frames, 1) shape.
        t = ((running_time + np.arange(frames)) / SAMPLE_RATE).reshape(-1, 1)

        if mode == DRONE_MODE:
            # Drone: one permanent note (MIDI 69) on the current spectrum.
            self.notes = [Note(self.spectrum, 69, 0, 0.9)]

        sound = np.zeros_like(t)
        finished = []
        # Iterate over a snapshot. Removing from the list being iterated would
        # skip the next note, and other threads may append to it meanwhile.
        for note in list(self.notes):
            note_sound, done = note.generate_sound(t)
            sound += note_sound
            if done:
                finished.append(note)
        for note in finished:
            try:
                self.notes.remove(note)
            except ValueError:
                pass  # already gone, e.g. the list was replaced by toggle_mode
        # Several overlapping notes can sum beyond the valid [-1, 1] range.
        outdata[:] = np.clip(sound, -1.0, 1.0)
        running_time += frames

    def add_note(self, note_code, velocity):
        """Start ``note_code`` with volume ``velocity``, first releasing any held copy."""
        self.kill_code(note_code)
        self.notes.append(Note(
            self.spectrum,
            note_code,
            running_time / SAMPLE_RATE,
            velocity,
        ))

    def kill_code(self, note_code=None):
        """Release held notes with MIDI code ``note_code`` (every held note if None)."""
        for note in self.notes:
            if note.time_dead is None and (note_code is None or note.note_code == note_code):
                note.time_dead = running_time / SAMPLE_RATE

    def new_spectrum(self, new_spectrum):
        """Switch every note (held or releasing) to ``new_spectrum``.

        An empty spectrum releases all held notes instead.
        """
        self.spectrum = new_spectrum
        if len(new_spectrum) == 0:
            self.kill_code()
            return
        # Mutate each note in place. This keeps the list's order and identity
        # intact while the audio thread may be iterating it.
        for note in self.notes:
            note.spectrum = new_spectrum

    def stop(self):
        """Stop and close the output stream (if any) so that playback can be restarted."""
        if self.stream is not None:
            self.stream.stop()
            self.stream.close()
            self.stream = None
        self.playing = False
        sd.stop()

    def stream_finished(self):
        """Called by sounddevice when the stream finishes; nothing to clean up."""
        pass

def remove_duplicates(numbers, threshold=1e-10, rtol=1e-05):
    """Return ``numbers`` as a 1-D float array with near-duplicates dropped.

    The first occurrence of each cluster is kept, in input order. Two values
    a (candidate) and b (already kept) are duplicates when
    ``np.isclose(a, b, rtol=rtol, atol=threshold)``, i.e.
    |a - b| <= threshold + rtol * |b|.

    Note: ``rtol`` defaults to NumPy's own default (1e-05), which dominates
    ``threshold`` for values of order 1. Pass ``rtol=0`` for a purely absolute
    comparison.

    Returns an empty float array when ``numbers`` is empty.
    """
    kept = []
    for num in numbers:
        num = float(num)
        if not np.any(np.isclose(num, kept, rtol=rtol, atol=threshold)):
            kept.append(num)
    return np.array(kept)


###########
## GRAPH ##
###########

class Graph:
    """Undirected graph whose vertices and edges are drawn on the global ``canvas``.

    Attributes:
        vertex_positions: vertex id -> (x, y) canvas coordinates.
        vertex_objects: vertex id -> canvas oval item.
        edges: (i, j) with i < j -> canvas line item.
        current_vertex_id: id the next add_vertex call will hand out.
    """

    def __init__(self):
        self.vertex_positions = {}
        self.vertex_objects = {}
        self.edges = {}
        self.current_vertex_id = 0

    def add_vertex(self, x, y):
        """Draw a new vertex centred at (x, y) and return its id."""
        new_id = self.current_vertex_id
        self.vertex_positions[new_id] = (x, y)
        self.vertex_objects[new_id] = canvas.create_oval(
            x - RADIUS_DOTS, y - RADIUS_DOTS,
            x + RADIUS_DOTS, y + RADIUS_DOTS,
            fill='black', outline=''
        )
        self.current_vertex_id += 1
        return new_id

    def remove_vertex(self, i):
        """Erase vertex ``i`` and every edge incident to it."""
        canvas.delete(self.vertex_objects.pop(i))
        self.vertex_positions.pop(i)
        for edge in [key for key in self.edges if i in key]:
            self.remove_edge_by_key(edge)

    def move_vertex(self, i, x, y):
        """Move vertex ``i`` to (x, y) and redraw its incident edges."""
        canvas.coords(self.vertex_objects[i], x - RADIUS_DOTS, y - RADIUS_DOTS, x + RADIUS_DOTS, y + RADIUS_DOTS)
        self.vertex_positions[i] = (x, y)
        for (u, v), line in self.edges.items():
            if i in (u, v):
                ux, uy = self.vertex_positions[u]
                vx, vy = self.vertex_positions[v]
                canvas.coords(line, ux, uy, vx, vy)

    def remove_all(self):
        """Erase every vertex and edge and restart vertex ids at 0."""
        for line in self.edges.values():
            canvas.delete(line)
        for oval in self.vertex_objects.values():
            canvas.delete(oval)
        self.vertex_objects.clear()
        self.vertex_positions.clear()
        self.edges.clear()
        self.current_vertex_id = 0

    def is_edge(self, i1, i2):
        """Return whether vertices ``i1`` and ``i2`` are joined (order-independent)."""
        return (min(i1, i2), max(i1, i2)) in self.edges

    def remove_edge_by_key(self, key):
        """Erase the edge stored under the normalised key ``(i, j)``, i < j."""
        canvas.delete(self.edges.pop(key))

    def remove_edge(self, i1, i2):
        """Erase the edge between ``i1`` and ``i2`` (raises KeyError if absent)."""
        self.remove_edge_by_key((min(i1, i2), max(i1, i2)))

    def add_edge(self, i1, i2):
        """Draw an edge between ``i1`` and ``i2``; does nothing if it already exists.

        Re-adding an existing edge would otherwise leave its first canvas line
        on screen with no reference to delete it later.
        """
        key = (min(i1, i2), max(i1, i2))
        if key in self.edges:
            return
        x1, y1 = self.vertex_positions[key[0]]
        x2, y2 = self.vertex_positions[key[1]]
        self.edges[key] = canvas.create_line(
            x1, y1,
            x2, y2,
            fill='black', width=EDGE_WIDTH
        )

    def flip_edge(self, i1, i2):
        """Remove the edge between ``i1`` and ``i2`` if present, otherwise add it."""
        if self.is_edge(i1, i2):
            self.remove_edge(i1, i2)
        else:
            self.add_edge(i1, i2)

    def laplacian_matrix(self, distance_exponent=None):
        """Return the weighted graph Laplacian as an (n, n) float array.

        Rows and columns follow the iteration order of ``vertex_objects``. Each
        edge has weight ``edge_length ** distance_exponent``;
        ``distance_exponent`` defaults to the global DISTANCE_EXPONENT.
        """
        if distance_exponent is None:
            distance_exponent = DISTANCE_EXPONENT
        index = {vertex: k for k, vertex in enumerate(self.vertex_objects)}
        laplacian = np.zeros((len(index), len(index)))
        for u, v in self.edges:
            (ux, uy) = self.vertex_positions[u]
            (vx, vy) = self.vertex_positions[v]
            edge_length = np.sqrt((vx - ux) ** 2 + (vy - uy) ** 2)
            if distance_exponent < 0:
                # Vertices dragged onto each other give length 0, and 0 ** negative
                # is inf, which makes eigvalsh fail. Treat them as 1 unit apart.
                edge_length = max(edge_length, 1.0)
            weight = edge_length ** distance_exponent
            a, b = index[u], index[v]
            laplacian[a, b] -= weight
            laplacian[b, a] -= weight
            laplacian[a, a] += weight
            laplacian[b, b] += weight
        return laplacian
        

def update_spectrum():
    """Recompute the graph's spectrum and push it to the player and the display.

    The positive, de-duplicated Laplacian eigenvalues are raised to
    EIGEN_EXPONENT and divided by the smallest one, giving harmonic ratios that
    start at 1. The first time a non-empty spectrum appears, playback is
    started on a daemon thread. With no positive eigenvalue, the player gets an
    empty spectrum (which releases held notes) and the display is cleared.
    """
    eigenvalues = np.linalg.eigvalsh(graph.laplacian_matrix())
    positive = remove_duplicates(eigenvalues[eigenvalues > 1e-10], threshold=1e-7)
    harmonics = positive ** EIGEN_EXPONENT

    if len(harmonics) == 0:
        audio_player.new_spectrum([])
        update_frequencies_display([])
        return

    ratios = harmonics / np.min(harmonics)
    audio_player.new_spectrum(ratios)
    update_frequencies_display(ratios * TUNING)

    if not audio_player.playing:
        audio_player.playing = True
        threading.Thread(target=audio_player.play, daemon=True).start()


########
## UI ##
########

# Synthesizer shared by the MIDI thread, the Tk callbacks and the audio callback.
audio_player = DronePlayer()
# Asks on the console which MIDI port to use (before the window opens).
midi_listener = MidiListener()
# Only start the listener thread when a port was actually chosen. With no
# MIDI input, listen() would be handed a None port.
if midi_listener.selected_port is not None:
    threading.Thread(target=midi_listener.listen, daemon=True).start()

# Main window.
root = tk.Tk()
root.title('Faisons chanter les graphes !')
root.configure(bg='white')

screen_width = root.winfo_screenwidth()
screen_height = root.winfo_screenheight()

# Drawing area for the graph, sized to the screen and filling the window.
canvas = tk.Canvas(root, width=screen_width, height=screen_height, bg='white', highlightthickness=0)
canvas.pack(side=tk.RIGHT, fill=tk.BOTH, expand=True)

# Editing state shared by the mouse callbacks.
graph = Graph()
first_dot = None  # vertex selected (drawn red) as the first endpoint of a new edge
dragging_dot = None  # vertex currently pressed / being dragged
has_been_dragged = False  # whether the current press has moved that vertex

FREQUENCIES_TEXT = "Fréquences (Hz):"
# Playback modes; MODE_TEXT holds the mode button label for each.
DRONE_MODE = 0
ENVELOPE_MODE = 1
MODE_TEXT = ["Bourdon", "Enveloppe"]
RADIUS_DOTS = 14  # drawn vertex radius (canvas units)
RADIUS_DETECTION_CLICK = 20  # click distance that counts as hitting a vertex
EDGE_WIDTH = 4  # drawn edge width

mode = ENVELOPE_MODE

# Automatic-edge modes for newly created vertices; AUTOEDGE_TEXT holds the labels.
AUTOEDGE_OFF = 0
AUTOEDGE_ON = 1  # connect to every vertex created in this mode ("Complet")
AUTOEDGE_CHAIN = 2  # connect to the previously created vertex only ("Chaîne")
AUTOEDGE_FLEUR = 3  # connect to the first vertex created in this mode, a fixed centre ("Fleur")
AUTOEDGE_TEXT = ["Off", "Complet", "Chaîne", "Fleur"]
autoedge = 0
autoedge_list = []  # vertices that the next new vertex gets connected to

# Changer de mode (bourdon/enveloppe)
def toggle_mode():
    """Cycle to the next playback mode (drone / envelope) and relabel its button.

    All current notes are dropped so the new mode starts from silence.
    """
    global mode
    audio_player.notes = []
    mode_count = len(MODE_TEXT)
    mode = (mode + 1) % mode_count
    new_label = MODE_TEXT[mode]
    mode_button.config(text=new_label)

def flip_autoedge():
    """Advance to the next automatic-edge mode and forget the remembered vertices."""
    global autoedge
    autoedge = (autoedge + 1) % len(AUTOEDGE_TEXT)
    autoedge_list.clear()
    label = "Arêtes automatiques " + AUTOEDGE_TEXT[autoedge]
    autoedge_button.config(text=label)

# Détecter le clic sur un point
def clicsurpoint(x, y):
    """Return the id of the first vertex closer than RADIUS_DETECTION_CLICK to (x, y), else None."""
    return next(
        (
            vertex
            for vertex, (vx, vy) in graph.vertex_positions.items()
            if np.sqrt((x - vx) ** 2 + (y - vy) ** 2) < RADIUS_DETECTION_CLICK
        ),
        None,
    )

#Clic gauche (ajout/déplacement de points ou ajout d'arêtes)
def on_canvas_click(event):
    """Handle a left click on the canvas.

    - On a vertex while another vertex is selected (red): toggle the edge
      between them (if distinct), refresh the spectrum, clear the selection.
    - On a vertex with nothing selected: start a potential drag;
      on_mouse_release turns a click without movement into a selection.
    - On empty space: create a vertex and, depending on ``autoedge``, connect
      it to the vertices remembered in ``autoedge_list``.
    """
    global first_dot, dragging_dot, has_been_dragged
    x, y = event.x, event.y
    i = clicsurpoint(x, y)
    if i is not None:
        if first_dot is not None:
            canvas.itemconfig(graph.vertex_objects[first_dot], fill='black')
            if first_dot != i:
                graph.flip_edge(first_dot, i)
                update_spectrum()
            first_dot = None
        else:
            dragging_dot = i
            has_been_dragged = False  # must reset the global, not a local
        return

    if dragging_dot is not None:
        return
    new_id = graph.add_vertex(x, y)
    # NOTE: maybe the edge should be created automatically when a first_dot is already selected?
    if autoedge == AUTOEDGE_OFF:
        return
    # Forget remembered vertices that were deleted since (right click).
    # Otherwise add_edge fails looking up their position.
    autoedge_list[:] = [v for v in autoedge_list if v in graph.vertex_positions]
    for vtx in autoedge_list:
        graph.add_edge(new_id, vtx)
    # FLEUR keeps a single centre, CHAIN keeps only the latest vertex,
    # ON (complete) keeps every vertex.
    if autoedge != AUTOEDGE_FLEUR or len(autoedge_list) == 0:
        if autoedge == AUTOEDGE_CHAIN:
            autoedge_list.clear()
        autoedge_list.append(new_id)
    update_spectrum()

#Clic droit (suppression de points)
def on_right_click(event):
    """Delete the vertex under the cursor with its edges, then refresh the spectrum.

    The vertex is also dropped from the selection, the drag state and
    ``autoedge_list``, so no later action refers to the deleted id.
    """
    global first_dot, dragging_dot
    i = clicsurpoint(event.x, event.y)
    if i is None:
        return
    if first_dot == i:
        first_dot = None
    if dragging_dot == i:
        dragging_dot = None
    # Automatic edges would otherwise try to connect to a vertex that no longer exists.
    autoedge_list[:] = [v for v in autoedge_list if v != i]
    graph.remove_vertex(i)
    update_spectrum()

# Drag-and-drop (mouvement)
def on_mouse_drag(event):
    """Move the pressed vertex to the cursor and mark the press as a drag.

    Edge weights only depend on lengths when DISTANCE_EXPONENT != 0, so the
    spectrum is recomputed only in that case.
    """
    global has_been_dragged
    if dragging_dot is None:
        return
    graph.move_vertex(dragging_dot, event.x, event.y)
    has_been_dragged = True
    if DISTANCE_EXPONENT != 0:
        update_spectrum()

# Drag-and-drop (relâcher)
def on_mouse_release(event):
    """End a press on a vertex: a press without movement selects it (drawn red)."""
    global dragging_dot, first_dot, has_been_dragged
    if dragging_dot is None:
        return
    clicked_without_moving = not has_been_dragged and first_dot is None
    if clicked_without_moving:
        first_dot = dragging_dot
        canvas.itemconfig(graph.vertex_objects[first_dot], fill='red')
    dragging_dot = None
    has_been_dragged = False

def update_amplitude(val):
    """Slider callback: set AMPLITUDE_LAW (partial weight = h ** -AMPLITUDE_LAW) and refresh."""
    global AMPLITUDE_LAW
    new_law = float(val)
    AMPLITUDE_LAW = new_law
    update_spectrum()

def update_tuning(val):
    """Slider callback: set TUNING (frequency of MIDI note 69, in Hz) and refresh."""
    global TUNING
    new_tuning = float(val)
    TUNING = new_tuning
    update_spectrum()

def update_distance_exponent(val):
    """Slider callback: set DISTANCE_EXPONENT (edge weight = length ** exponent) and refresh."""
    global DISTANCE_EXPONENT
    new_exponent = float(val)
    DISTANCE_EXPONENT = new_exponent
    update_spectrum()

def update_eigen_exponent(val):
    """Slider callback: set EIGEN_EXPONENT (harmonic = eigenvalue ** exponent) and refresh."""
    global EIGEN_EXPONENT
    new_exponent = float(val)
    EIGEN_EXPONENT = new_exponent
    update_spectrum()

def update_frequencies_display(frequencies):
    """Show the frequencies (Hz, two decimals) and their note names in the text item."""
    rows = [
        FREQUENCIES_TEXT,
        ", ".join(f"{f:.2f}" for f in frequencies),
        ", ".join(freq_to_name(f) for f in frequencies),
    ]
    canvas.itemconfig(frequencies_text_id, text="\n".join(rows))

def reset_graph():
    """Erase the whole graph, refresh the (now empty) spectrum and clear the editing state."""
    global first_dot, dragging_dot
    graph.remove_all()
    update_spectrum()
    first_dot = dragging_dot = None
    autoedge_list.clear()

# Left-hand control panel. ``currenty`` is the y coordinate of the next row;
# every widget is positioned with place() only. Calling pack() on a widget
# that is then placed has no lasting effect, because place() takes it over.
currenty = 10

mode_button = tk.Button(root, text=MODE_TEXT[mode], width=10, height=2, command=toggle_mode, bg="darkblue", fg="white")
mode_button.place(x=10, y=currenty)

reset_button = tk.Button(root, text="Tout effacer", width=10, height=2, command=reset_graph, bg="darkblue", fg="white")
reset_button.place(x=150, y=currenty)
currenty += 60

# Initial label uses the same "<prefix> <mode>" format that flip_autoedge writes.
autoedge_button = tk.Button(root, text="Arêtes automatiques " + AUTOEDGE_TEXT[autoedge], width=25, height=2, command=flip_autoedge, bg="darkblue", fg="white")
autoedge_button.place(x=10, y=currenty)
currenty += 60

# Parameter sliders: each command updates its global parameter and recomputes the spectrum.
amplitude_slider = tk.Scale(root, from_=1, to=3, resolution=0.1, orient=tk.VERTICAL,
                            label="Décroissance des harmoniques", command=update_amplitude)
amplitude_slider.set(AMPLITUDE_LAW)
amplitude_slider.place(x=10, y=currenty)
currenty += 110

tuning_slider = tk.Scale(root, from_=50, to=1000, resolution=1, orient=tk.VERTICAL,
                         label="Accordage", command=update_tuning)
tuning_slider.set(TUNING)
tuning_slider.place(x=10, y=currenty)
currenty += 110

distance_exponent_slider = tk.Scale(root, from_=-2, to=2, resolution=0.1, orient=tk.VERTICAL,
                                    label="Dépendence en la distance", command=update_distance_exponent)
distance_exponent_slider.set(DISTANCE_EXPONENT)
distance_exponent_slider.place(x=10, y=currenty)
currenty += 110

eigen_exponent_slider = tk.Scale(root, from_=0.5, to=2, resolution=0.1, orient=tk.VERTICAL,
                                 label="Relation valeur propre > fréquence", command=update_eigen_exponent)
eigen_exponent_slider.set(EIGEN_EXPONENT)
eigen_exponent_slider.place(x=10, y=currenty)
currenty += 110

# Usage hints; the text's bottom-left corner ('sw' anchor) sits at
# (screen_width - 300, screen_height - 100).
canvas.create_text(
    screen_width - 300,
    screen_height - 100,
    text=f"""
Clic gauche : créer/déplacer point/arête
Clic droit : supprimer un point \n
Échap : quitter  
    """,
    anchor='sw',
)

# Text item (bottom-left) that update_frequencies_display rewrites. The initial
# content has the same layout: the header (which already ends with a colon),
# then two empty lines for the frequencies and the note names.
frequencies_text_id = canvas.create_text(
    20,
    screen_height - 100,
    text=f"{FREQUENCIES_TEXT}\n\n",
    anchor='sw',
)

# Mouse bindings: left button creates / selects / drags, right button deletes.
# NOTE(review): Tk on macOS may report the right button as <Button-2>; confirm if macOS matters.
canvas.bind('<Button-1>', on_canvas_click)
canvas.bind('<Button-3>', on_right_click)
canvas.bind('<B1-Motion>', on_mouse_drag)
canvas.bind('<ButtonRelease-1>', on_mouse_release)
# Escape destroys the main window, which ends root.mainloop().
root.bind('<Escape>', lambda e: root.destroy())


#######################
## BOUCLE PRINCIPALE ##
#######################

# Hand control to Tk; returns once the main window is destroyed.
root.mainloop()

