""" music_separate_and_classify.py Pipeline: - Separate stems with Spleeter (4stems) - Extract pitch (librosa.pyin) from vocal stem - Cluster voiced frames into up to 5 clusters -> map to Soprano1, Soprano2, Alto, Bass1, Bass2 - Save each vocal cluster as a separate WAV and print instrument heuristics for the other stems. Usage: python music_separate_and_classify.py input_mix.wav output_folder/ Notes: - For best results use 44.1 kHz stereo WAV input. - This is a heuristic baseline. Improve by swapping Spleeter for Demucs and training an instrument classifier for the instrument step. """ import os import sys import numpy as np import soundfile as sf import librosa import librosa.display from sklearn.cluster import AgglomerativeClustering from spleeter.separator import Separator from scipy.signal import medfilt # ---------- Parameters ---------- SR = 44100 N_CLUSTERS = 5 # Soprano1, Soprano2, Alto, Bass1, Bass2 FRAME_LENGTH = 2048 HOP_LENGTH = 512 MIN_F0 = 65.0 # ~ low male F0 (Hz) MAX_F0 = 2000.0 # upper bound for singing (Hz) # ------------------------------- def separate_stems(input_path, out_dir): """Use Spleeter to separate into 4 stems: vocals, drums, bass, other""" os.makedirs(out_dir, exist_ok=True) print("Running Spleeter separation (4stems)...") separator = Separator('spleeter:4stems') # requires spleeter installed separator.separate_to_file(input_path, out_dir) # Spleeter writes a folder named like / with WAVs inside # We'll try to find the produced files base = os.path.splitext(os.path.basename(input_path))[0] sep_folder = os.path.join(out_dir, base) stems = {} mapping = { 'vocals': 'vocals.wav', 'drums': 'drums.wav', 'bass': 'bass.wav', 'other': 'other.wav' } for k, fname in mapping.items(): p = os.path.join(sep_folder, fname) if os.path.exists(p): stems[k] = p else: print(f"Warning: expected {p} but not found.") return stems def load_mono(path, sr=SR): y, _ = librosa.load(path, sr=sr, mono=True) return y def extract_f0_and_features(y, sr=SR): """Return f0 (Hz) per frame, voiced mask, and MFCC features per frame""" # librosa.pyin yields f0 estimate per frame (None for unvoiced) f0, voiced_flag, voiced_prob = librosa.pyin( y, fmin=MIN_F0, fmax=MAX_F0, sr=sr, frame_length=FRAME_LENGTH, hop_length=HOP_LENGTH ) # Replace NaN with 0 for clustering convenience but keep voiced_mask f0_clean = np.where(np.isnan(f0), 0.0, f0) # Get MFCCs per frame aligned with pyin frames S = np.abs(librosa.stft(y, n_fft=FRAME_LENGTH, hop_length=HOP_LENGTH)) mfcc = librosa.feature.mfcc(S=librosa.power_to_db(S**2), sr=sr, n_mfcc=13) # Transpose frames -> shape (n_frames, n_mfcc) mfcc_t = mfcc.T[:len(f0_clean)] return f0_clean, voiced_flag, mfcc_t def cluster_vocal_frames(f0, voiced, mfcc, n_clusters=N_CLUSTERS): """ Cluster only voiced frames using features [log(f0), mfcc]. Returns an array cluster_id per original frame (unvoiced frames = -1). """ voiced_idx = np.where(voiced)[0] if len(voiced_idx) == 0: print("No voiced frames detected.") return np.full(len(voiced), -1, dtype=int) # feature vector: log(f0) + mfcc logf0 = np.log1p(f0[voiced_idx]) feat = np.hstack((logf0.reshape(-1,1), mfcc[voiced_idx])) # simple agglomerative clustering with distance threshold or fixed clusters n_clusters = min(n_clusters, len(voiced_idx)) if n_clusters <= 1: labels = np.zeros(len(voiced_idx), dtype=int) else: clust = AgglomerativeClustering(n_clusters=n_clusters, linkage='average') labels = clust.fit_predict(feat) # Map back to full frame length frame_labels = np.full(len(voiced), -1, dtype=int) frame_labels[voiced_idx] = labels return frame_labels def reconstruct_clusters_to_audio(original_y, frame_labels, sr=SR, hop_length=HOP_LENGTH, frame_length=FRAME_LENGTH): """ Reconstruct audio for each cluster by applying a time-varying mask in STFT domain. Returns dict cluster_id -> audio_array """ S_complex = librosa.stft(original_y, n_fft=frame_length, hop_length=hop_length) magnitude, phase = np.abs(S_complex), np.angle(S_complex) n_frames = magnitude.shape[1] n_clusters = int(frame_labels.max()) + 1 if frame_labels.max() >= 0 else 0 # create soft masks per cluster based on voiced frames => use harmonic mask # We'll create frame-wise cluster masks (1 for frames belonging to cluster, else 0) frame_mask = np.zeros((n_clusters, n_frames)) for f in range(n_frames): lbl = frame_labels[f] if lbl >= 0: frame_mask[lbl, f] = 1.0 # Project masks to spectrogram time frames (they already align with STFT frames) cluster_audio = {} for c in range(n_clusters): mask = np.tile(frame_mask[c], (magnitude.shape[0], 1)) S_c = magnitude * mask * np.exp(1j * phase) y_c = librosa.istft(S_c, hop_length=hop_length, length=len(original_y)) # optional median filtering to smooth rapid switching y_c = medfilt(y_c, kernel_size=3) cluster_audio[c] = y_c return cluster_audio def map_clusters_to_voice_parts(cluster_audio, sr=SR): """ Map clusters to voice part labels by median pitch. Compute median pitch via librosa.pyin on each cluster. """ medians = [] for c, y in cluster_audio.items(): f0, voiced, _ = extract_f0_and_features(y, sr=sr) # median of positive F0s pos = f0[f0>0] med = np.median(pos) if len(pos)>0 else 0 medians.append((c, med)) # sort clusters by descending median frequency medians_sorted = sorted(medians, key=lambda x: x[1], reverse=True) labels = {} voice_order = ["Soprano 1", "Soprano 2", "Alto", "Bass 1", "Bass 2"] for idx, (cluster_id, median_hz) in enumerate(medians_sorted): label = voice_order[idx] if idx < len(voice_order) else f"Voice_{idx+1}" labels[cluster_id] = (label, median_hz) return labels def instrument_heuristics(path, sr=SR): """Very rough instrument tagging using spectral heuristics.""" y = load_mono(path, sr=sr) S = np.abs(librosa.stft(y, n_fft=FRAME_LENGTH, hop_length=HOP_LENGTH)) # spectral centroid and rolloff centroid = librosa.feature.spectral_centroid(S=S, sr=sr).mean() rolloff = librosa.feature.spectral_rolloff(S=S, sr=sr).mean() # percussive vs harmonic energy y_h, y_p = librosa.effects.hpss(y) percussive_energy = np.sum(np.abs(y_p)) harmonic_energy = np.sum(np.abs(y_h)) percussive_ratio = percussive_energy / (harmonic_energy + 1e-9) low_energy = np.mean(librosa.feature.rms(S=S)[0][:10]) # low freq rms proxy tags = [] # drums heuristic if percussive_ratio > 0.6 and centroid > 2000: tags.append("Drum/Percussion") # bass heuristic if centroid < 800 or low_energy > 0.01: tags.append("Bass / Low-frequency instrument") # harmonic instrument guess if harmonic_energy > percussive_energy: tags.append("Harmonic (guitar/piano/strings)") if not tags: tags.append("Unknown / mixed") return { "centroid": centroid, "rolloff": rolloff, "percussive_ratio": percussive_ratio, "tags": tags } def save_wav(y, path, sr=SR): sf.write(path, y, sr) print(f"Saved: {path}") def main(input_path, out_dir): stems = separate_stems(input_path, out_dir) # Load vocal stem mono if 'vocals' not in stems: print("No vocals stem found. Exiting.") return vocals_path = stems['vocals'] print("Loading vocal stem:", vocals_path) y_vocal = load_mono(vocals_path, sr=SR) f0, voiced, mfcc = extract_f0_and_features(y_vocal, sr=SR) frame_labels = cluster_vocal_frames(f0, voiced, mfcc, n_clusters=N_CLUSTERS) print("Frame-wise cluster ids (sample):", frame_labels[:50]) cluster_audio = reconstruct_clusters_to_audio(y_vocal, frame_labels, sr=SR) mapping = map_clusters_to_voice_parts(cluster_audio, sr=SR) # Save each cluster audio and print mapping base = os.path.splitext(os.path.basename(input_path))[0] out_base = os.path.join(out_dir, base + "_vocal_parts") os.makedirs(out_base, exist_ok=True) for c, y in cluster_audio.items(): label, median_hz = mapping.get(c, (f"Voice_{c}", 0)) fname = f"{label.replace(' ','_')}_cluster{c}.wav" save_wav(y, os.path.join(out_base, fname), sr=SR) print(f"Cluster {c} -> {label}, median F0: {median_hz:.1f} Hz") # Instrument heuristics for other stems print("\nInstrument heuristics for non-vocal stems:") for s in ['drums', 'bass', 'other']: if s in stems: h = instrument_heuristics(stems[s], sr=SR) print(f"Stem: {s} -> tags: {h['tags']}, centroid={h['centroid']:.1f}, percussive_ratio={h['percussive_ratio']:.2f}") else: print(f"Stem {s} not found.") if __name__ == "__main__": if len(sys.argv) < 3: print("Usage: python music_separate_and_classify.py ") else: main(sys.argv[1], sys.argv[2])