#include "matrix.h"
#include "he_logic.h"
#include "quantum.h"
#include "analog.h"

// Key-state bitmap owned by this file: one matrix_row_t per matrix row.
// The scan routine fills rows 0-2 or 3-5 depending on which half is running.
static matrix_row_t matrix[MATRIX_ROWS];

// Pins wired to the multiplexer select inputs; entry i is driven with bit i
// of the requested channel number.
static const pin_t mux_sel_pins[] = {
    MUX_SEL0_PIN,
    MUX_SEL1_PIN,
    MUX_SEL2_PIN,
};

// Pins on which each multiplexer's output is sampled, indexed by mux number.
// (Per the original note: on RP2040, ADC 0, 1, 2 correspond to GP26, 27, 28.)
static const pin_t mux_val_pins[] = {
    MUX_VAL0_PIN,
    MUX_VAL1_PIN,
    MUX_VAL2_PIN,
};

// Put the low three bits of `channel` onto the multiplexer select lines
// (bit n is written to mux_sel_pins[n]), then pause for 2 us so the mux
// output has time to settle before it is sampled.
static void select_mux_channel(uint8_t channel) {
    for (uint8_t bit = 0; bit < 3; bit++) {
        writePin(mux_sel_pins[bit], (channel & (1 << bit)));
    }

    wait_us(2);
}

// One-time setup for the custom matrix: initialises the he_logic module,
// configures the mux select pins as outputs driven low, sets the ADC
// reference, and zeroes the file-local key matrix.
void matrix_init_custom(void) {
    he_logic_init();

    // Every select line becomes an output and starts low (i.e. channel 0).
    const uint8_t sel_count = sizeof(mux_sel_pins) / sizeof(mux_sel_pins[0]);
    for (uint8_t idx = 0; idx < sel_count; idx++) {
        setPinOutput(mux_sel_pins[idx]);
        writePinLow(mux_sel_pins[idx]);
    }

    // NOTE(review): analogReference()/ADC_REF_POWER are expected to come from
    // analog.h - confirm they are available for the target MCU.
    analogReference(ADC_REF_POWER);
    // The ADC input pins are not configured here; presumably the analog
    // driver takes care of that - TODO confirm.

    // Start with every key reported as released.
    for (uint8_t row = 0; row < MATRIX_ROWS; row++) {
        matrix[row] = 0;
    }
}

// Scan every hall-effect switch on this half and refresh the key matrix.
//
// For each mux channel (columns 0-5) the select lines are set once, then the
// three mux outputs are sampled in turn; mux N supplies row N on the left half
// and row N + 3 on the right half. Each reading is handed to
// he_logic_process_switch(), and the resulting is_pressed flag is mirrored
// into the file-local matrix.
//
// current_matrix: caller's row buffer; all MATRIX_ROWS rows are overwritten,
//                 but only when at least one key changed state.
// Returns true if any key changed state during this scan, false otherwise.
//
// NOTE(review): rows 3-5 are written on the right half - confirm the caller
// expects absolute row indices rather than per-half indices starting at 0.
bool matrix_scan_custom(matrix_row_t current_matrix[]) {
    const uint8_t first_row = is_keyboard_left() ? 0 : 3;
    bool any_change = false;

    for (uint8_t channel = 0; channel < 6; channel++) {
        // All three muxes share the select lines, so switch once per column.
        select_mux_channel(channel);

        const matrix_row_t key_bit = (1 << channel);

        for (uint8_t mux = 0; mux < 3; mux++) {
            const uint8_t r = first_row + mux;
            const uint16_t reading = analogReadPin(mux_val_pins[mux]);

            he_logic_process_switch(r, channel, reading);

            const bool pressed_now = he_switches[r][channel].is_pressed;
            const bool pressed_before = (matrix[r] & key_bit) != 0;
            if (pressed_now == pressed_before) {
                continue;
            }

            // The stored bit disagrees with the switch state, so flipping it
            // brings the matrix back in sync.
            matrix[r] ^= key_bit;
            any_change = true;
        }
    }

    if (!any_change) {
        return false;
    }

    // Publish the updated state to the caller's buffer.
    for (uint8_t row = 0; row < MATRIX_ROWS; row++) {
        current_matrix[row] = matrix[row];
    }
    return true;
}
