1、FFT 部分(M5StickC Plus 固件,Arduino C++)

#include <M5StickCPlus.h>

#pragma GCC optimize ("O3")
#include <driver/i2s.h>
#include <math.h>
#include <string.h>
#include <stdio.h>
#include "dywapitchtrack.h"
#include "fix_fft.h"

// --- Hardware ----------------------------------------------------------------
// GPIO driving the PDM microphone clock (routed to the I2S WS signal in i2sInit()).
#define PIN_CLK  0
// GPIO receiving the PDM microphone data.
#define PIN_DATA 34
// Samples per I2S read and FFT length.
#define SAMPLES 1024 // Must be a power of 2
// Bytes per i2s_read(): SAMPLES 16-bit samples.
#define READ_LEN (2 * SAMPLES)
// Display width of the M5StickC Plus in the rotation selected in setup().
#define TFT_WIDTH 240
// Display height of the M5StickC Plus in the rotation selected in setup().
#define TFT_HEIGHT 135
// --- Spectrum bars -------------------------------------------------------------
// Number of frequency bands shown (see getBand()).
#define BANDS 8
// Horizontal pixels allotted to each band, including padding.
#define BANDS_WIDTH ( TFT_WIDTH / BANDS )
// Blank pixels between neighbouring bars.
#define BANDS_PADDING 8
#define BAR_WIDTH ( BANDS_WIDTH - BANDS_PADDING )
// FFT bins whose squared magnitude is not above this value are ignored.
#define NOISE_FLOOR 1
// Bar height = ((FIX_LOG2<15-RSHIFT>(band power) + AMPLIFIER) * MAGNIFY) >> RSHIFT2,
// computed in showSpectrumBars(); bands whose log value is <= -AMPLIFIER get height 0.
#define AMPLIFIER (1 << 2)
#define MAGNIFY 3
#define RSHIFT 13
#define RSHIFT2 1
// --- Oscilloscope / tuner ------------------------------------------------------
// Smallest peak-to-peak range scaled to full height; quieter signals are drawn at this fixed scale.
#define OSC_NOISEFLOOR 100
// Samples per oscilloscope / pitch-tracking window (size comes from dywapitchtrack.h).
#define OSC_SAMPLES DYWAPT_SAMPLESIZE
// Number of SAMPLES-sized I2S reads that make up one window.
#define OSC_SKIPCOUNT (OSC_SAMPLES/SAMPLES)
// Extra windows to skip between oscilloscope redraws (0 = redraw every window).
#define OSC_EXTRASKIP 0

// Per-buffer capacity of vTemp. Spectrum mode writes SAMPLES samples followed by
// SAMPLES zeros, so this relies on OSC_SAMPLES >= 2*SAMPLES.
// NOTE(review): confirm DYWAPT_SAMPLESIZE in dywapitchtrack.h satisfies that.
//#define MAXBUFSIZE (2*SAMPLES)
#define MAXBUFSIZE OSC_SAMPLES

// Per-band drawing state for the spectrum-bar mode.
struct eqBand {
  const char *freqname;  // label printed above the bar
  int peak;              // current peak-hold height in pixels; decays by 2 per frame
  int lastpeak;          // peak-hold height last drawn (erased before redrawing)
  uint16_t lastval;      // bar height last drawn, in pixels
};

// Display modes, cycled with button A in loop().
enum : uint8_t {
  ModeSpectrumBars,
  ModeOscilloscope,
  ModeTuner,
  ModeCount,   // number of modes, not a mode itself
};

// Current mode; changed by loop(), read by looptask().
static uint8_t runmode = 0;

// Hand-off flags between loop() (producer) and looptask() (renderer):
//   semaphore - set by loop() when a filled buffer is ready, cleared by looptask() after rendering.
//   needinit  - set by loop() on a mode change; looptask() then calls initMode().
// NOTE(review): volatile gives no inter-core ordering for the non-volatile buffers
// these flags guard; FreeRTOS primitives or atomics would be more robust.
static volatile bool semaphore = false;
static volatile bool needinit = true;

static eqBand audiospectrum[BANDS] = {
  // freqname,peak,lastpeak,lastval,
  // frequency label, initial peak (unlisted fields start at 0)
  { ".1k", 0 },
  { ".2k", 0 },
  { ".5k", 0 },
  { " 1k", 0 },
  { " 2k", 0 },
  { " 4k", 0 },
  { " 8k", 0 },
  { "16k", 0 }
};
 
// Double-buffered sample windows (DC removed), filled by loop() and read by looptask().
// In spectrum mode [0, SAMPLES) holds the samples and [SAMPLES, 2*SAMPLES) zeros;
// showSpectrumBars() runs the FFT in place over both halves.
static int vTemp[2][MAXBUFSIZE];
// Index of the buffer loop() is currently filling; looptask() reads curbuf^1.
static uint8_t curbuf = 0;
static uint16_t colormap[TFT_HEIGHT];//color palette for the band meter(pre-fill in setup)

static dywapitchtracker pitchTracker;
// Off-screen frame used by the oscilloscope and tuner modes, shown with pushSprite().
static TFT_eSprite sprite(&M5.Lcd);

// Which SAMPLES-sized slot of the current oscilloscope/tuner window the next read fills.
static int bufposcount = 0;

// Note names indexed by note % 12, starting at C.
static const char *notestr[12] = {
  "C ", "C#", "D ", "D#", "E ", "F ", "F#", "G ", "G#", "A ", "A#", "B "
};

// Double-buffered raw 16-bit I2S samples, indexed like vTemp.
static uint16_t oscbuf[2][OSC_SAMPLES];
#if OSC_EXTRASKIP > 0
// Windows skipped since the last oscilloscope redraw.
static int skipcount = 0;
#endif

// Configure I2S0 as a PDM receiver for the microphone: master, RX, 16-bit samples
// at DYWAPT_SAMPLERATE, mono. Only the clock (WS) and data-in pins are routed.
// NOTE(review): the esp_err_t results of the three driver calls are ignored, so a
// failure here leaves the input silent with no diagnostic.
void i2sInit(){
   i2s_config_t i2s_config = {
    .mode = (i2s_mode_t)(I2S_MODE_MASTER | I2S_MODE_RX | I2S_MODE_PDM),
    .sample_rate =  DYWAPT_SAMPLERATE,
    .bits_per_sample = I2S_BITS_PER_SAMPLE_16BIT, // 16-bit samples (an older note here said "fixed at 12bit, stereo, MSB"; NOTE(review): appears stale for PDM input)
    .channel_format = I2S_CHANNEL_FMT_ALL_RIGHT,
    .communication_format = I2S_COMM_FORMAT_I2S,
    .intr_alloc_flags = ESP_INTR_FLAG_LEVEL1,
    .dma_buf_count = 2,
    .dma_buf_len = 128,
   };
   // NOTE(review): pin_config is not zero-initialised; if the installed ESP-IDF's
   // i2s_pin_config_t has further fields (e.g. mck_io_num) they stay indeterminate.
   i2s_pin_config_t pin_config;
   pin_config.bck_io_num   = I2S_PIN_NO_CHANGE;
   pin_config.ws_io_num    = PIN_CLK;   // PDM clock
   pin_config.data_out_num = I2S_PIN_NO_CHANGE;
   pin_config.data_in_num  = PIN_DATA; 
   i2s_driver_install(I2S_NUM_0, &i2s_config, 0, NULL);
   i2s_set_pin(I2S_NUM_0, &pin_config);
   i2s_set_clk(I2S_NUM_0, DYWAPT_SAMPLERATE, I2S_BITS_PER_SAMPLE_16BIT, I2S_CHANNEL_MONO);
}
 
// Fill colormap[] with a red->green gradient for the spectrum bars: row y gets
// red proportional to (TFT_HEIGHT - y) and green proportional to y, scaled so the
// stronger channel is 255.
static void buildColormap() {
  for (int row = 0; row < TFT_HEIGHT; ++row) {
    float red = TFT_HEIGHT - row;
    float green = row;
    float larger = (green < red) ? red : green;
    float scale = 255. / larger;
    colormap[row] = M5.Lcd.color565((uint8_t)(red * scale), (uint8_t)(green * scale), 0); // Modified by KKQ-KKQ
  }
}

// Arduino entry point: bring up the board, display, sprite and microphone, build
// the bar palette, then start the rendering task on the other CPU core.
void setup() {
  M5.begin();
  setCpuFrequencyMhz(80);

  M5.Lcd.setRotation(1);
  M5.Lcd.fillScreen(BLACK);
  sprite.createSprite(TFT_WIDTH, TFT_HEIGHT);
  sprite.setTextSize(1);

  i2sInit();
  buildColormap();

  // Rendering runs on the core that is not executing setup().
  int otherCore = 1 - xPortGetCoreID();
  xTaskCreatePinnedToCore(looptask, "calctask", 32768, NULL, 1, NULL, otherCore);
}

// Clear the LCD and prepare per-mode drawing state; called by looptask() after
// loop() requests it, which happens on start-up and on every mode change.
void initMode() {
  M5.Lcd.fillRect(0, 0,
                  TFT_WIDTH, TFT_HEIGHT, BLACK);
  switch (runmode) {
    case ModeSpectrumBars:
      // Band labels along the top edge; bars are drawn straight onto the LCD.
      M5.Lcd.setTextSize(1);
      M5.Lcd.setTextColor(LIGHTGREY);
      for (byte band = 0; band < BANDS; band++) {
        M5.Lcd.setCursor(BANDS_WIDTH*band + 2, 0);
        M5.Lcd.print(audiospectrum[band].freqname);
      }
      break;

    case ModeOscilloscope:
      sprite.setTextColor(GREEN);
      dywapitch_inittracking(&pitchTracker);
      break;

    case ModeTuner:
      sprite.setTextSize(1);
      // showTuner() also feeds pitchTracker; start from a fresh state instead of
      // relying on whatever the oscilloscope mode left behind.
      dywapitch_inittracking(&pitchTracker);
      break;

    default:
      break;
  }
}


//好嘞,这里是实际的频谱的渲染程序部分
void showSpectrumBars(){
  int *vTemp_ = vTemp[curbuf^1];

  //对mic的数据应用了一个什么窗口,然后计算了一下fft,这里才开始fft
  Fixed15FFT::apply_window(vTemp_);
  Fixed15FFT::calc_fft(vTemp_, vTemp_ + SAMPLES);

  int values[BANDS] = {};
  for (int i = 2; i < (SAMPLES/2); i++){ 
    // Don't use sample 0 and only first SAMPLES/2 are usable. 
    // Each array element represents a frequency and its value the amplitude.
    int ampsq = vTemp_[i] * vTemp_[i] + vTemp_[i + SAMPLES] * vTemp_[i + SAMPLES];
    if (ampsq > NOISE_FLOOR) {
      byte bandNum = getBand(i);
      if(bandNum != 8 && ampsq > values[bandNum]) {
        values[bandNum] = ampsq;
      }
    }
  }
  for (byte band = 0; band < BANDS; band++) {
    int log2 = FIX_LOG2<15-RSHIFT>(values[band]);
    if (log2 > -AMPLIFIER) {
      displayBand(band, ((log2 + AMPLIFIER) * MAGNIFY) >> RSHIFT2);
    } else {
      displayBand(band, 0);
    }
  }
  //long vnow = millis();
  for (byte band = 0; band < BANDS; band++) {
    if (audiospectrum[band].peak > 0) {
      audiospectrum[band].peak -= 2;
    }
    if(audiospectrum[band].peak <= 0) {
      audiospectrum[band].peak = 0;
    }
    // only draw if peak changed
    if(audiospectrum[band].lastpeak != audiospectrum[band].peak) {

      // delete last peak
      uint16_t hpos = BANDS_WIDTH*band + (BANDS_PADDING/2);
      M5.Lcd.drawFastHLine(hpos,TFT_HEIGHT-audiospectrum[band].lastpeak,BAR_WIDTH,BLACK);
      audiospectrum[band].lastpeak = audiospectrum[band].peak;
      uint16_t ypos = TFT_HEIGHT - audiospectrum[band].peak;
      M5.Lcd.drawFastHLine(hpos, ypos,
                           BAR_WIDTH, colormap[ypos]);
    }
  }
  // struct eqBand {
  //   const char *freqname;
  //   int peak;
  //   int lastpeak;
  //   uint16_t lastval;
  // };
  //试着打印一下
  // Serial.print(audiospectrum[0].peak);
  // Serial.print(","); 
  // Serial.print(audiospectrum[1].peak);
  // Serial.print(","); 
  // Serial.print(audiospectrum[2].peak);
  // Serial.print(","); 
  // Serial.print(audiospectrum[3].peak);
  // Serial.print(","); 
  // Serial.print(audiospectrum[4].peak);
  // Serial.print(","); 
  // Serial.print(audiospectrum[5].peak);
  // Serial.print(","); 
  // Serial.print(audiospectrum[6].peak);
  // Serial.print(","); 
  // Serial.print(audiospectrum[7].peak);
  // Serial.println();
  int dy10=audiospectrum[1].lastpeak-audiospectrum[0].lastpeak;
  int dy21=audiospectrum[2].lastpeak-audiospectrum[1].lastpeak;
  int dy32=audiospectrum[3].lastpeak-audiospectrum[2].lastpeak;
  int dy43=audiospectrum[4].lastpeak-audiospectrum[3].lastpeak;
  int dy54=audiospectrum[5].lastpeak-audiospectrum[4].lastpeak;
  int dy65=audiospectrum[6].lastpeak-audiospectrum[5].lastpeak;
  int dy76=audiospectrum[7].lastpeak-audiospectrum[6].lastpeak;
  float dddx = 0.455;
  //计算出来了7段斜率
  float slope01 = dy10 / dddx;
        slope01 = fabs(slope01) < 5? 0 : slope01;
  float slope12 = dy21 / dddx;
        slope12 = fabs(slope12) < 5? 0 : slope12;
  float slope23 = dy32 / dddx;
        slope23 = fabs(slope23) < 5? 0 : slope23;
  float slope34 = dy43 / dddx;
        slope34 = fabs(slope34) < 5? 0 : slope34;
  float slope45 = dy54 / dddx;
        slope45 = fabs(slope45) < 5? 0 : slope45;
  float slope56 = dy65 / dddx;
        slope56 = fabs(slope56) < 5? 0 : slope56;
  float slope67 = dy76 / dddx;
        slope67 = fabs(slope67) < 5? 0 : slope67;

  Serial.print(slope01);
  Serial.print(",");
  Serial.print(slope12);
  Serial.print(",");
  Serial.print(slope23);
  Serial.print(",");
  Serial.print(slope34);
  Serial.print(",");
  Serial.print(slope45);
  Serial.print(",");
  Serial.print(slope56);
  Serial.print(",");
  Serial.print(slope67);
  Serial.println();

}

// Draw band `band` as a bar `dsize` pixels tall (clamped to [0, TFT_HEIGHT-10] to
// keep the label row free). Rows left over from a taller previous bar are blanked,
// and the band's peak-hold value is raised if the bar exceeds it.
void displayBand(int band, int dsize){
  uint16_t hpos = BANDS_WIDTH*band + (BANDS_PADDING/2);
  if (dsize < 0) dsize = 0;
  if (dsize > TFT_HEIGHT-10) {
    dsize = TFT_HEIGHT-10; // leave some space for text
  }
  if (dsize < audiospectrum[band].lastval) {
    // Bar shrank: blank the rows between the new and the old top.
    M5.Lcd.fillRect(hpos, TFT_HEIGHT-audiospectrum[band].lastval,
                    BAR_WIDTH, audiospectrum[band].lastval - dsize, BLACK);
  }
  // The bar is a stack of horizontal lines every 4 px, measured up from the bottom.
  // s == 0 would be row TFT_HEIGHT: off-screen and one past the end of colormap[].
  for (int s = 4; s <= dsize; s += 4) {
    uint16_t ypos = TFT_HEIGHT - s;
    M5.Lcd.drawFastHLine(hpos, ypos, BAR_WIDTH, colormap[ypos]);
  }
  if (dsize > audiospectrum[band].peak) {
    audiospectrum[band].peak = dsize;
  }
  audiospectrum[band].lastval = dsize;
}

// Map an FFT bin index to an octave-wide display band: band k covers bins
// [2^(k+1), 2^(k+2)), so bins 2..3 -> 0, 4..7 -> 1, ..., 256..511 -> 7.
// Any bin outside [2, 512) yields 8, meaning "not displayed".
byte getBand(int i) {
  if (i < 2 || i >= 512) {
    return 8;
  }
  byte band = 0;
  for (int upper = 4; i >= upper; upper <<= 1) {
    ++band;
  }
  return band;
}

// Number of samples the oscilloscope should display for a signal of frequency f:
// 2 * DYWAPT_SAMPLERATE / f (two periods), halved until it fits in OSC_SAMPLES.
// f == 0 (no pitch found) returns SAMPLES/2.
float calcNumSamples(float f) {
  if (f == 0.0) {
    return SAMPLES / 2;
  }
  float span = (float)(DYWAPT_SAMPLERATE * 2) / f;
  while (span > (float)OSC_SAMPLES) {
    span *= 0.5;
  }
  return span;
}

// Overlay the detected frequency on the sprite's top row, plus the nearest note
// name/octave and the deviation from it in cents. Nothing is drawn for freq <= 0.
// fnote is the fractional note number on a scale where 69 corresponds to 440 Hz.
void showFreq(float freq) {
  if (freq <= 0) {
    return;
  }
  char strbuf[16];
  snprintf(strbuf, sizeof strbuf, "%8.2fHz", freq);
  sprite.drawString(strbuf, 0, 0, 1);

  float fnote = log2(freq)*12 - 36.376316562f;
  int note = fnote + 0.5f;   // round to the nearest note
  if (note < 0) {
    return;
  }
  sprite.setCursor(TFT_WIDTH/2 - 4, 0);
  sprite.print(notestr[note % 12]);
  sprite.print(note / 12 - 1);
  float cent = (fnote - note) * 100;
  snprintf(strbuf, sizeof strbuf, "%.1fcents", cent);
  sprite.drawRightString(strbuf, TFT_WIDTH, 0, 1);
}

// Oscilloscope mode: estimate the pitch of the latest window, draw about two
// periods of the raw waveform (see calcNumSamples()) scaled to the screen height,
// and overlay the frequency/note via showFreq().
void showOscilloscope()
{
  uint16_t *oscbuf_ = oscbuf[curbuf^1];
  int *vTemp_ = vTemp[curbuf ^ 1];
  float freq = dywapitch_computepitch(&pitchTracker, vTemp_);
#if OSC_EXTRASKIP > 0
  if (skipcount < OSC_EXTRASKIP) {
    ++skipcount;
    return;
  }
  skipcount = 0;
#endif
  uint16_t s = calcNumSamples(freq);
  float mx = (float)TFT_WIDTH / s;   // x pixels per sample
  float my;                          // y pixels per sample unit
  uint16_t maxV = 0;
  uint16_t minV = 65535;
  uint16_t offset = 0;
  // Range of the first s samples. Drawing starts at the minimum (when s samples
  // still fit after it), presumably to keep the trace steady between frames.
  for (uint16_t i = 0; i < s; ++i) {
    if (maxV < oscbuf_[i]) maxV = oscbuf_[i];
    if (minV > oscbuf_[i]) {
      minV = oscbuf_[i];
      if (i + s <= OSC_SAMPLES) offset = i;
    }
  }
  if (maxV - minV > OSC_NOISEFLOOR) {
    my = (float)(TFT_HEIGHT-10) / (maxV - minV);
  }
  else {
    // Quiet signal: use a fixed OSC_NOISEFLOOR-tall window centred on it. Computed
    // in int and clamped at 0, because a negative value would wrap in uint16_t.
    my = (float)(TFT_HEIGHT-10) / OSC_NOISEFLOOR;
    int lower = (((int)maxV + (int)minV) >> 1) - OSC_NOISEFLOOR/2;
    minV = (lower < 0) ? 0 : (uint16_t)lower;
  }
  sprite.fillSprite(BLACK);
  // y is kept signed: samples beyond the scanned range may map above the top edge,
  // and converting a negative float to uint16_t is undefined.
  int y = (int)(TFT_HEIGHT - (oscbuf_[offset] - minV) * my);
  for (uint16_t i = 1; i < s; ++i) {
    int y2 = (int)(TFT_HEIGHT - (oscbuf_[offset + i] - minV) * my);
    sprite.drawLine((uint16_t)((i-1) * mx), y,
                    (uint16_t)(i * mx), y2, LIGHTGREY);
    y = y2;
  }
  showFreq(freq);
  sprite.pushSprite(0,0);
}

// Tuner mode: estimate the pitch of the latest window and draw the nearest note
// name plus a marker offset from the centre line by the deviation in cents; the
// background turns GREEN within +/-2 cents. With no pitch, an empty gauge is shown.
void showTuner() {
  int *vTemp_ = vTemp[curbuf ^ 1];
  float freq = dywapitch_computepitch(&pitchTracker, vTemp_);
  // fnote: fractional note number (69 corresponds to 440 Hz); note: nearest note, -1 if none.
  float fnote = 0;
  int note = -1;
  if (freq > 0) {
    fnote = log2(freq)*12 - 36.376316562;
    note = fnote + 0.5;
  }
  if (note >= 0) {
    float cent = (fnote - note) * 100;
    // fabsf makes the float magnitude explicit; plain abs() may resolve to the int overload.
    uint32_t bgcolor = (fabsf(cent) < 2.f) ? GREEN : DARKGREY;
    uint32_t fgcolor = BLACK;
    sprite.fillSprite(bgcolor);
    sprite.setTextColor(fgcolor);
    sprite.drawRect(2, 36, TFT_WIDTH-3, TFT_HEIGHT-40, fgcolor);
    sprite.drawRect(TFT_WIDTH/2 + 1, 36, 1, TFT_HEIGHT - 40, fgcolor);
    sprite.fillCircle(((float)TFT_WIDTH/2 + 1) + cent * ((float)(TFT_WIDTH-3)/100), (TFT_HEIGHT+34)/2, 5, fgcolor);
    char strbuf[8];
    snprintf(strbuf, sizeof strbuf, "%s%d", notestr[note % 12], note / 12 - 1);
    sprite.drawCentreString(strbuf, TFT_WIDTH/2, 3, 4);
  }
  else {
    sprite.fillSprite(DARKGREY);
    sprite.drawRect(2, 36, TFT_WIDTH-3, TFT_HEIGHT-40, BLACK);
    sprite.drawLine(TFT_WIDTH/2 + 1, 36, TFT_WIDTH/2 + 1, TFT_HEIGHT - 4, BLACK);
  }
  sprite.pushSprite(0,0);
}

void looptask(void *) {
  while (1) {
    if (needinit) {
      initMode();
      needinit = false;
    }
    if (semaphore) {
      switch(runmode) {
        case ModeSpectrumBars:
          //这里是实际的进入函数
          showSpectrumBars();
          break;

        case ModeOscilloscope:
          showOscilloscope();
          break;

        case ModeTuner:
          showTuner();
          break;
      }
      semaphore = false;
    }
    else {
      vTaskDelay(10);
    }
  }
}

void loop() {
  M5.update();
  if (M5.BtnA.wasReleased()) {
    ++runmode;
    if (runmode >= ModeCount) runmode = 0;

    bufposcount = 0;
    needinit = true;;
  }
  uint16_t i,j;
  j = bufposcount * SAMPLES;
  uint16_t *adcBuffer = &oscbuf[curbuf][j];
  size_t bytesread;
  i2s_read(I2S_NUM_0,(char*)adcBuffer,READ_LEN,&bytesread,portMAX_DELAY);
  int32_t dc = 0;
  for (int i = 0; i < SAMPLES; ++i) {
    dc += adcBuffer[i];
  }
  dc /= SAMPLES;

  switch(runmode) {
    case ModeSpectrumBars:
      for (int i = 0; i < SAMPLES; ++i) {
        
        vTemp[curbuf][i] = (int)adcBuffer[i] - dc;
        vTemp[curbuf][i + SAMPLES] = 0;
      }
      curbuf ^= 1;
      semaphore = true;
      break;
    case ModeOscilloscope:
      for (i = 0; i < SAMPLES; ++i) {
        vTemp[curbuf][i + j] = (int)adcBuffer[i] - dc;
      }
      if (++bufposcount >= OSC_SKIPCOUNT) {
        bufposcount = 0;
        curbuf ^= 1;
        semaphore = true;
      }
      break;
    case ModeTuner:
      j = bufposcount * SAMPLES;
      for (i = 0; i < SAMPLES; ++i) {
        vTemp[curbuf][i + j] = (int)adcBuffer[i] - dc;
      }
      if (++bufposcount >= OSC_SKIPCOUNT) {
        bufposcount = 0;
        curbuf ^= 1;
        semaphore = true;
      }
      break;
  }
}

 https://github.com/KKQ-KKQ/m5stickc-audiospectrum/tree/master 

固件基于上面这个项目,但做了修改:在 showSpectrumBars() 末尾加入了下面这段代码,计算相邻频段峰值之间的斜率并通过串口输出:


  // Excerpt: code appended to the end of showSpectrumBars(). It prints, as one
  // comma-separated line on Serial, the slope between each pair of adjacent bands'
  // drawn peak heights; slopes whose magnitude is below 5 are printed as 0.
  int dy10=audiospectrum[1].lastpeak-audiospectrum[0].lastpeak;
  int dy21=audiospectrum[2].lastpeak-audiospectrum[1].lastpeak;
  int dy32=audiospectrum[3].lastpeak-audiospectrum[2].lastpeak;
  int dy43=audiospectrum[4].lastpeak-audiospectrum[3].lastpeak;
  int dy54=audiospectrum[5].lastpeak-audiospectrum[4].lastpeak;
  int dy65=audiospectrum[6].lastpeak-audiospectrum[5].lastpeak;
  int dy76=audiospectrum[7].lastpeak-audiospectrum[6].lastpeak;
  float dddx = 0.455;  // empirical horizontal distance between neighbouring bands
  // Compute the 7 slopes between adjacent bands
  float slope01 = dy10 / dddx;
        slope01 = fabs(slope01) < 5? 0 : slope01;
  float slope12 = dy21 / dddx;
        slope12 = fabs(slope12) < 5? 0 : slope12;
  float slope23 = dy32 / dddx;
        slope23 = fabs(slope23) < 5? 0 : slope23;
  float slope34 = dy43 / dddx;
        slope34 = fabs(slope34) < 5? 0 : slope34;
  float slope45 = dy54 / dddx;
        slope45 = fabs(slope45) < 5? 0 : slope45;
  float slope56 = dy65 / dddx;
        slope56 = fabs(slope56) < 5? 0 : slope56;
  float slope67 = dy76 / dddx;
        slope67 = fabs(slope67) < 5? 0 : slope67;

  Serial.print(slope01);
  Serial.print(",");
  Serial.print(slope12);
  Serial.print(",");
  Serial.print(slope23);
  Serial.print(",");
  Serial.print(slope34);
  Serial.print(",");
  Serial.print(slope45);
  Serial.print(",");
  Serial.print(slope56);
  Serial.print(",");
  Serial.print(slope67);
  Serial.println();


2、PC 端:读取串口数据并转发到 MQTT(同时也负责保存 CSV 训练数据):

import serial
import pandas as pd
import paho.mqtt.client as mqtt
import json

try:
    # Open the serial port the M5StickC Plus prints its slope lines to.
    ser = serial.Serial('COM5', 115200)
except serial.SerialException as e:
    print(f"Error opening COM5: {e}")
    # Nothing can be read without the port; carrying on would leave `ser`
    # undefined and make every read in the main loop fail.
    raise SystemExit(1)

# Counter used by the (commented-out) CSV export in the main loop.
export_csv_filename = 0

def on_connect(client, userdata, flags, rc):
    """paho-mqtt connect callback: report whether the broker accepted the connection.

    ``rc`` is the CONNACK result code; 0 means success.
    """
    if rc != 0:
        print("Connection failed with code", rc)
        return
    print("Connected to MQTT broker")



# paho-mqtt 2.x requires the callback API version to be chosen explicitly
# (mqtt.Client() with no arguments raises ValueError there). VERSION1 keeps the
# (client, userdata, flags, rc) signature that on_connect uses; 1.x has no such enum.
if hasattr(mqtt, "CallbackAPIVersion"):
    client = mqtt.Client(mqtt.CallbackAPIVersion.VERSION1)
else:
    client = mqtt.Client()
client.on_connect = on_connect

client.connect("192.168.50.232", 1883, 60)
# Background network thread: processes CONNACK (so on_connect runs), sends
# keepalive pings and flushes publish() calls while the main loop blocks on serial.
client.loop_start()

COLUMNS = ["slope01", "slope12", "slope23", "slope34", "slope45", "slope56", "slope67"]
ROWS_PER_BATCH = 120  # rows per published window; the consumer expects exactly this many


def _parse_line(raw):
    """Turn one serial line into a {column: float} record, or None if it is not one.

    Comma-separated fields that are not valid floats are ignored; the line is
    rejected when fewer than len(COLUMNS) numeric fields remain. Extra fields are
    dropped.
    """
    text = raw.decode('utf-8', errors='replace').strip()
    numbers = []
    for field in text.split(','):
        try:
            numbers.append(float(field))
        except ValueError:
            pass
    if len(numbers) < len(COLUMNS):
        return None
    return dict(zip(COLUMNS, numbers))


try:
    while True:
        # Collect exactly ROWS_PER_BATCH valid records. Malformed lines are skipped
        # rather than counted, so every batch has the same number of rows.
        rows = []
        while len(rows) < ROWS_PER_BATCH:
            data = ser.readline()
            if not data:
                continue
            record = _parse_line(data)
            if record is None:
                print(f"skipping malformed serial line: {data!r}")
                continue
            rows.append(record)
        # Build the frame once instead of concatenating row by row.
        df = pd.DataFrame(rows, columns=COLUMNS)
        print("打印df啊")
        print(df)
        # export_csv_filename = export_csv_filename +1
        # df.to_csv('traindata_'+ str(export_csv_filename) +'.csv', index=False)  # do not write the index
        # Serialise as JSON ({column: {row_index: value}}) for the inference script.
        serialized_df = json.dumps(df.to_dict())
        client.publish("lemon_ken_mic", serialized_df)
finally:
    # Close the serial port (reached on Ctrl+C or on a serial I/O error).
    ser.close()
    print("======================关闭串口====================")

# python -m venv venv
# .\venv\Scripts\Activate.ps1
# pip install pyserial
# pip install paho-mqtt
# pip install pandas
# python mqtt_send.py



3、训练以及评估部分:

import torch
import pandas as pd

# 定义神经网络类
class AudioClassifier(torch.nn.Module):
    """Fully connected two-class classifier over one flattened window of slope features.

    Layers: in_features -> 128 -> 64 -> 32 -> 2 with ReLU after each hidden layer.
    ``forward`` returns raw logits (no softmax).

    Args:
        batch_size: unused; kept so existing ``AudioClassifier(n, features)`` calls work.
        in_features: length of each flattened input vector.
    """

    def __init__(self, batch_size, in_features):
        super().__init__()
        # Attribute names form the state_dict keys saved to disk; keep them stable.
        self.layer1 = torch.nn.Linear(in_features, 128)
        self.additional_layer = torch.nn.Linear(128, 64)
        self.layer2 = torch.nn.Linear(64, 32)
        self.layer3 = torch.nn.Linear(32, 2)  # two output classes

    def forward(self, x):
        for hidden in (self.layer1, self.additional_layer, self.layer2):
            x = torch.relu(hidden(x))
        return self.layer3(x)

# 读取数据并处理
def load_data(file_paths):
    """Load one labelled sample per CSV file.

    The label comes from the file name: names containing "c0" are class 0, names
    containing "c1" are class 1 (only the base name is inspected).

    Note: ``skiprows=1`` skips the header line, after which pandas takes the first
    *data* row as the header, so each sample holds every row except the first.
    The 119-row shape used elsewhere depends on this.

    Returns:
        (data, labels): ``data`` is a tensor of shape (num_files, rows, cols);
        ``labels`` is an int64 tensor of shape (num_files,).

    Raises:
        ValueError: a file name contains neither "c0" nor "c1".
    """
    import os

    data = []
    labels = []
    for path in file_paths:
        name = os.path.basename(path)
        if "c0" in name:
            label = 0
        elif "c1" in name:
            label = 1
        else:
            # Appending data without a label would silently misalign data and labels.
            raise ValueError(f"cannot infer class label from file name: {path!r}")
        df = pd.read_csv(path, skiprows=1)
        data.append(torch.as_tensor(df.values))
        labels.append(label)
    if not data:
        return torch.tensor(data), torch.tensor(labels)
    return torch.stack(data), torch.tensor(labels)

# 示例用法
# Training recordings: files with "c0" in the name are class 0, "c1" class 1.
train_file_paths = [
    "traindata_c0_1.csv", "traindata_c0_2.csv", "traindata_c0_3.csv", "traindata_c0_4.csv",
    "traindata_c0_5.csv", "traindata_c0_6.csv", "traindata_c0_11.csv", "traindata_c0_12.csv",
    "traindata_c0_13.csv", "traindata_c0_14.csv", "traindata_c0_15.csv", "traindata_c0_16.csv",
    "traindata_c0_17.csv", "traindata_c0_18.csv", "traindata_c0_19.csv", "traindata_c0_20.csv",
    "traindata_c0_21.csv", "traindata_c0_22.csv",
    "traindata_c1_1.csv", "traindata_c1_2.csv", "traindata_c1_3.csv", "traindata_c1_4.csv",
    "traindata_c1_5.csv", "traindata_c1_6.csv", "traindata_c1_11.csv", "traindata_c1_12.csv",
    "traindata_c1_13.csv", "traindata_c1_14.csv", "traindata_c1_15.csv", "traindata_c1_16.csv",
    "traindata_c1_17.csv", "traindata_c1_18.csv", "traindata_c1_19.csv", "traindata_c1_20.csv",
]
train_data, train_labels = load_data(train_file_paths)

print("train_data.dtype:"+str(train_data.dtype))
print("train_labels.dtype:"+str(train_labels.dtype))

# torch.nn.Linear expects a 2-D input [batch_size, in_features], so each sample
# (num_of_rows x num_of_cols slope values, e.g. 119 x 7) is flattened into one
# vector (e.g. 833 values); the batch then has shape
# [num_of_samples, num_of_rows * num_of_cols].
num_of_samples = train_data.size(0)
num_of_rows = train_data[0].shape[0]
num_of_cols = train_data[0].shape[1]
print("num_of_samples:")
print(num_of_samples)
print("num_of_rows:")
print(num_of_rows)
print("num_of_cols:")
print(num_of_cols)
model = AudioClassifier(num_of_samples,num_of_rows*num_of_cols)

# Flatten with reshape (permute only reorders dimensions, it cannot merge them).
# The feature count comes from the data so it always equals the model's in_features.
new_train_data_tensor = train_data.reshape(num_of_samples, num_of_rows * num_of_cols)
print("new_train_data_tensor shape:")
print(new_train_data_tensor.shape)  # print the new tensor's shape

# Loss and optimizer. The model emits two raw logits per sample, which is what
# CrossEntropyLoss (log-softmax + NLL) expects together with int64 class labels.
loss_func = torch.nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters())


print("=========tow input shapes=============")
print(new_train_data_tensor.shape, train_labels.shape)
print("=========tow input shapes=============")


# Training loop: full-batch updates for 200 epochs.
model.train()
train_inputs = new_train_data_tensor.to(torch.float32)  # loop-invariant conversion
for epoch in range(200):
    outputs = model(train_inputs)
    loss = loss_func(outputs, train_labels)
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

# `outputs` holds the logits from the last epoch's forward pass.
print("==========outputs.shape:============")
print(outputs.shape)

# Example logits recorded from an earlier training run (one row per training
# sample; columns are class 0 and class 1):
# ==========outputs:============
# tensor([[ 10.3342,  -3.6685],
#         [ 13.9501,  -0.4912],
#         [ 15.0824,  -1.4321],
#         [ 16.4281,  -1.1477],
#         [ 11.7916,  -4.5559],
#         [ 15.0455,  -5.4270],
#         [ 10.3306,  -8.4917],
#         [  9.7162,  -8.0242],
#         [ 11.2048,  -7.3689],
#         [  6.1282,  -6.2711],
#         [ -3.8303,   7.0840],
#         [ -4.3532,  17.0647],
#         [-15.7246,  44.2592],
#         [ -7.8866,  12.5852],
#         [ -8.0347,  10.8943],
#         [ -1.7064,   9.8779],
#         [ -2.2770,  11.1365],
#         [ -3.4305,  18.9242],
#         [ -3.7132,  28.6930],
#         [-17.9371,  42.6237]], grad_fn=<AddmmBackward0>)

# These are logits (unnormalised scores), not probabilities: the larger value in a
# row is the predicted class and negative values are normal. Softmax over dim=1
# turns each row into class probabilities that sum to 1.
print("==========probabilities:============")
probabilities = outputs.softmax(dim=1)
print(probabilities)

# Softmax of the example logits above:
# tensor([[1.0000e+00, 4.4273e-09],
#         [1.0000e+00, 1.5222e-12],
#         [1.0000e+00, 3.5589e-12],
#         [1.0000e+00, 3.0917e-11],
#         [1.0000e+00, 5.3411e-10],
#         [1.0000e+00, 1.3765e-08],
#         [1.0000e+00, 1.4835e-08],
#         [1.0000e+00, 4.2492e-07],
#         [1.0000e+00, 5.1067e-07],
#         [9.9997e-01, 2.8726e-05],
#         [5.4728e-06, 9.9999e-01],
#         [5.5860e-08, 1.0000e+00],
#         [3.0678e-19, 1.0000e+00],
#         [9.8549e-08, 1.0000e+00],
#         [9.0871e-08, 1.0000e+00],
#         [4.0230e-06, 1.0000e+00],
#         [4.5366e-06, 1.0000e+00],
#         [9.9386e-07, 1.0000e+00],
#         [5.5348e-12, 1.0000e+00],
#         [1.3492e-18, 1.0000e+00]], grad_fn=<SoftmaxBackward0>)

#====================================================================================================

# Model evaluation on held-out recordings.
eval_file_paths = [
    "traindata_c0_7.csv", "traindata_c0_8.csv", "traindata_c0_9.csv", "traindata_c0_10.csv",
    "traindata_c1_7.csv", "traindata_c1_8.csv", "traindata_c1_9.csv", "traindata_c1_10.csv",
    "traindata_c1_21.csv", "traindata_c1_22.csv", "traindata_c1_23.csv",
]
eval_data, eval_labels = load_data(eval_file_paths)

# Flatten each sample into one feature vector exactly as for training, and make
# sure the feature count matches what the model was built for.
num_of_eval_samples = eval_data.size(0)
new_eval_data_tensor = eval_data.flatten(start_dim=1)
if new_eval_data_tensor.shape[1] != model.layer1.in_features:
    raise ValueError(
        f"eval samples have {new_eval_data_tensor.shape[1]} features, "
        f"model expects {model.layer1.in_features}"
    )
print("new_eval_data_tensor shape:")
print(new_eval_data_tensor.shape)  # print the new tensor's shape


# Evaluate the model.
model.eval()
with torch.no_grad():
    test_outputs = model(new_eval_data_tensor.to(torch.float32))
    predicted_labels = torch.argmax(test_outputs, dim=1)
    accuracy = (predicted_labels == eval_labels).sum().item() / eval_labels.size(0)
    print("Accuracy:", accuracy)

print("==========评估阶段的predicted_labels:============")
print(predicted_labels)

print("==========评估阶段的eval_labels:============")
print(eval_labels)

print("==========评估阶段的probabilities:============")
probabilities = torch.softmax(test_outputs, dim=1)
print(probabilities)

# Save the trained weights. The file name 'odel_weights.pth' (missing "m") is what
# the inference script loads; rename both places together if changing it.
torch.save(model.state_dict(), 'odel_weights.pth')


4、推理部分

import torch
import pandas as pd
import paho.mqtt.client as mqtt
import json
import tkinter as tk

# 定义神经网络类
class AudioClassifier(torch.nn.Module):
    """Same network as in the training script; must match it to load the saved weights.

    Layers: in_features -> 128 -> 64 -> 32 -> 2 with ReLU after each hidden layer;
    ``forward`` returns raw logits.

    Args:
        batch_size: unused; kept for call-site compatibility.
        in_features: length of each flattened input vector.
    """

    def __init__(self, batch_size, in_features):
        super().__init__()
        # Attribute names are the state_dict keys in the saved weight file.
        self.layer1 = torch.nn.Linear(in_features, 128)
        self.additional_layer = torch.nn.Linear(128, 64)
        self.layer2 = torch.nn.Linear(64, 32)
        self.layer3 = torch.nn.Linear(32, 2)  # two output classes

    def forward(self, x):
        hidden_layers = (self.layer1, self.additional_layer, self.layer2)
        for layer in hidden_layers:
            x = torch.relu(layer(x))
        return self.layer3(x)
    
# Rebuild the trained architecture (119 rows x 7 slope columns, flattened) and load
# the weights written by the training script.
# NOTE(review): torch.load unpickles the file; only load weight files you trust.
model = AudioClassifier(1, 119 * 7)
state_dict = torch.load('odel_weights.pth')
model.load_state_dict(state_dict)
model.eval()  # inference mode

# # 输入数据
# input_data = torch.tensor([[1.0, 2.0, 3.0, 4.0, 5.0]])

# # 使用模型进行预测
# output = model(input_data)

# 对输出进行处理或查看结果
# print(output)

def update_circle(probabilities):
    """Colour the indicator circle red when probabilities[0] > 0.8, light blue otherwise.

    on_message passes the softmax of the model's two logits, so index 0 is class 0.
    """
    print("====update_circle我拿到了一个啥:====")
    print(probabilities)  # the probabilities received
    colour = 'red' if probabilities[0] > 0.8 else 'lightblue'
    canvas.itemconfig(circle, fill=colour)

        
# Minimal Tk window: one circle whose fill colour update_circle() changes.
root = tk.Tk()
canvas = tk.Canvas(master=root, width=400, height=400)
canvas.pack()
circle = canvas.create_oval((100, 100, 300, 300), fill='lightblue')


def on_message(client, userdata, msg):
    """Classify one window of slope rows received over MQTT and update the indicator.

    The payload is the JSON the PC-side sender publishes (``DataFrame.to_dict()``,
    i.e. ``{column: {row_index: value}}``). Payloads that cannot be decoded or
    have the wrong number of values are reported and skipped; an exception escaping
    this callback may stop paho's network loop.
    """
    try:
        received_data = json.loads(msg.payload.decode())
        df = pd.DataFrame(received_data)
    except (ValueError, UnicodeDecodeError) as e:
        print(f"skipping undecodable message: {e}")
        return
    # df.values is the bare 2-D value array (no header or index).
    print(df.values)
    # Training (load_data with read_csv(skiprows=1)) never sees the first row of a
    # window, so drop the first row here too, not the last, to keep feature
    # positions aligned with what the model learned.
    values119 = df.values[1:]
    expected = model.layer1.in_features
    if values119.size != expected:
        print(f"skipping message: got values of shape {values119.shape}, need {expected}")
        return
    tensor = torch.tensor(values119)
    print("====reviced tensor shape:====")
    print(tensor.shape)  # print the received tensor's shape
    new_predict_tensor = tensor.reshape(expected)
    with torch.no_grad():
        output = model(new_predict_tensor.to(torch.float32))
    print(output)
    print("====output tensor shape:====")
    print(output.shape)  # print the output tensor's shape
    print("==========probabilities:============")
    probabilities = torch.softmax(output, dim=0)
    print(probabilities)
    update_circle(probabilities)

def on_connect(client, userdata, flags, rc):
    """Subscribe on every (re)connect so the subscription survives automatic reconnects."""
    if rc == 0:
        client.subscribe("lemon_ken_mic")
    else:
        print("Connection failed with code", rc)


# paho-mqtt 2.x requires an explicit callback API version (mqtt.Client() with no
# arguments raises ValueError there); VERSION1 keeps the callback signatures used here.
if hasattr(mqtt, "CallbackAPIVersion"):
    client = mqtt.Client(mqtt.CallbackAPIVersion.VERSION1)
else:
    client = mqtt.Client()
client.on_connect = on_connect
client.on_message = on_message

client.connect("192.168.50.232", 1883, 65535)
client.loop_start()

root.mainloop()

# python -m venv venv
# .\venv\Scripts\Activate.ps1
# pip install torch pandas paho-mqtt   (tkinter ships with the standard CPython installer)
# python main.py


 https://github.com/lemonhall/esp32_pytorch_audio_nn/tree/main