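/* This program implements the Krumhansl-Schmuckler key-finding algorithm. It reads input from
the file named as its only command-line argument (or from standard input if no argument is
given), in the form of "Note" statements (Note [ontime] [offtime] [pitch]) or "TPCNote"
statements (TPCNote [ontime] [offtime] [pitch] [TPC]); the two kinds may not be mixed in one
file. It does not segment the piece at all, but simply creates a single input vector (the total
duration of each pitch class) for the entire piece and chooses the best-matching key. The amount
of output is set by the global variable `verbosity`: with verbosity=0, it simply outputs the
best key; with verbosity=2, it also outputs the correlation value for every key, and other
diagnostic information. */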

#include <ctype.h>
#include <math.h>
#include <stdio.h>
#include <stdlib.h>
#include <string.h>

/* Input stream: the file named on the command line, or stdin. */
FILE *in_file;

/* Number of notes read by main(), and the sum of their durations. */
int numnotes;
int total_duration;

/* Major and minor key profiles, indexed by semitones above the tonic
   (index 0 = tonic). prepare_profiles() mean-centres them in place, which is
   why they are not const. */
double major_profile[] = {6.35, 2.23, 3.48, 2.33, 4.38, 4.09, 2.52, 5.19, 2.39, 3.66, 2.29, 2.88};
double minor_profile[] = {6.33, 2.68, 3.52, 5.38, 2.60, 3.53, 2.54, 4.75, 3.98, 2.69, 3.34, 3.17};

/* Output level: 0 prints only the key name; 2 or more adds diagnostics. */
int verbosity = 0;

/* One note event as parsed from the input. */
typedef struct note_struct {
  int ontime;
  int offtime;
  int duration;   /* offtime - ontime */
  int pitch;
  int tpc;        /* TPC field of TPCNote lines; not used by the key finding */
  int npc;        /* pitch class, pitch mod 12 */
} blah;

/* All notes of the piece; main() refuses input with more than this many. */
struct note_struct note[10000];

double average_dur;          /* total binned duration / 12, set by tally_notes() */
double key_profile[24][12];  /* rows 0..11 major keys, 12..23 minor keys */
double input_prof[12];       /* total duration per pitch class */
double key_score[24];        /* correlation of each key with input_prof */

/* Letter names; not referenced elsewhere in this file. */
char letter[] = {'C', 'D', 'E', 'F', 'G', 'A', 'B'};

/* Prototypes, so main() does not rely on implicit function declarations. */
int print_keyname(int f);
int tally_notes(void);
int prepare_profiles(void);
int generate_npc_profiles(void);
int match_profiles(void);

/*
 * Entry point. Reads Note or TPCNote statements from the file named by the
 * single command-line argument (stdin if there is none), then runs the
 * key-finding pipeline and prints the best-matching key.
 *
 * Input handling:
 *   - blank lines and lines whose first non-blank character is '%' are skipped;
 *   - lines whose first word is neither "Note" nor "TPCNote" are ignored;
 *   - Notes and TPCNotes may not be mixed in one file;
 *   - a Note/TPCNote line with missing or extra fields is rejected.
 * Every error prints a message to stdout and exits with status 1.
 */
int main(int argc, char *argv[])
{
    char line[256];
    char noteword[10];   /* first word of the line; always read with %9s */
    char junk[10];       /* detects trailing fields; always read with %9s */
    const int max_notes = (int)(sizeof note / sizeof note[0]);
    int z = 0, i, pitch;
    int tpc_found = 0, npc_found = 0;

    if (argc == 2) {
        in_file = fopen(argv[1], "r");
        if (in_file == NULL) {
            printf("I can't open that file\n");
            exit(1);
        }
    }
    else in_file = stdin;

    /* Read Note / TPCNote statements; every other statement is ignored. */
    while (fgets(line, sizeof(line), in_file) != NULL) {
        /* An over-long line does not fit in the buffer: discard the rest of
           it so its tail is not parsed as a separate statement. */
        if (strchr(line, '\n') == NULL) {
            int c;
            while ((c = getc(in_file)) != EOF && c != '\n')
                ;
        }

        /* isspace() needs a value representable as unsigned char. */
        for (i = 0; isspace((unsigned char)line[i]); i++)
            ;
        if (line[i] == '%') continue;                          /* comment */
        if (sscanf(line, "%9s", noteword) != 1) continue;      /* blank line */

        if (strcmp(noteword, "TPCNote") != 0 && strcmp(noteword, "Note") != 0)
            continue;

        /* note[] has fixed size; refuse to write past its end. */
        if (z >= max_notes) {
            printf("Error: Too many notes in input (limit is %d)\n", max_notes);
            exit(1);
        }

        if (strcmp(noteword, "TPCNote") == 0) {
            if (npc_found == 1) {
                printf("Error: Can't combine Notes and TPCNotes in a single input file\n");
                exit(1);
            }
            /* Exactly 5 conversions expected; a 6th (junk) means extra fields. */
            if (sscanf(line, "%9s %d %d %d %d %9s", noteword, &note[z].ontime,
                       &note[z].offtime, &note[z].pitch, &note[z].tpc, junk) != 5) {
                printf("Bad input\n");
                exit(1);
            }
            pitch = note[z].pitch;
            tpc_found = 1;
        }
        else {
            if (tpc_found == 1) {
                printf("Error: Can't combine Notes and TPCNotes in a single input file\n");
                exit(1);
            }
            /* Exactly 4 conversions expected; a 5th (junk) means extra fields. */
            if (sscanf(line, "%9s %d %d %d %9s", noteword, &note[z].ontime,
                       &note[z].offtime, &pitch, junk) != 4) {
                printf("Bad input\n");
                exit(1);
            }
            note[z].pitch = pitch;
            npc_found = 1;
        }

        /* C's % may yield a negative result; fold into 0..11 so that
           negative pitches are still tallied. */
        note[z].npc = ((pitch % 12) + 12) % 12;
        note[z].duration = note[z].offtime - note[z].ontime;
        total_duration += note[z].duration;
        z++;
    }

    if (in_file != stdin) fclose(in_file);

    numnotes = z;

    if (z == 0) {
        printf("Error: No notes in input.\n");
        exit(1);
    }

    tally_notes();
    prepare_profiles();
    generate_npc_profiles();
    match_profiles();
    return 0;
}



/*
 * Print the name of key index f followed by a newline. Indices 0..11 are the
 * major keys C, Db, D, Eb, E, F, F#, G, Ab, A, Bb, B; indices 12..23 are the
 * same tonics in minor, printed with a trailing "m" (e.g. 13 -> "Dbm").
 * An index outside 0..23 prints "B" (plus "m" when f >= 12).
 */
int print_keyname(int f)
{
    static const char *const tonic_names[12] = {
        "C", "Db", "D", "Eb", "E", "F", "F#", "G", "Ab", "A", "Bb", "B"
    };
    const char *tonic = "B";
    const char *mode = "";

    if (f >= 0 && f < 24)
        tonic = tonic_names[f % 12];
    if (f >= 12)
        mode = "m";

    printf("%s%s\n", tonic, mode);
    return 0;
}


/*
 * Build the input vector for the whole piece: input_prof[pc] becomes the
 * total duration of all notes whose pitch class (note[].npc) is pc, for
 * pc = 0..11. average_dur is set to the mean of those 12 bins, which
 * match_profiles() subtracts to centre the input vector.
 *
 * A note whose npc lies outside 0..11 is skipped entirely (neither binned nor
 * added to the total), so average_dur is always the true mean of input_prof.
 * At verbosity >= 2 the totals and the 12 bins are printed.
 */
int tally_notes(void)
{
    int pc, n;
    double total_dur = 0.0;

    for (pc = 0; pc < 12; pc++)
        input_prof[pc] = 0.0;

    for (n = 0; n < numnotes; n++) {
        int npc = note[n].npc;

        if (npc < 0 || npc >= 12)
            continue;
        input_prof[npc] += (double)note[n].duration;
        total_dur += (double)note[n].duration;
    }

    average_dur = total_dur / 12.0;

    if (verbosity >= 2) {
        printf("total dur = %6.3f, average dur = %6.3f\n", total_dur, average_dur);
        /* input_prof holds doubles, so print it with a floating-point conversion. */
        for (pc = 0; pc < 12; pc++)
            printf("%6.3f ", input_prof[pc]);
        printf("\n");
    }
    return 0;
}

prepare_profiles() {

  /* Sum all the profile values, take the mean, and subtract that from each value */
    
    double total;
    double average;
    int i;

    total = 0.0;
    for(i=0; i<12; i++) {
	total += major_profile[i];
    }
    average = total / 12.0;
    for(i=0; i<12; i++) major_profile[i]=major_profile[i]-average;
    
    total = 0.0;
    for(i=0; i<12; i++) {
	total += minor_profile[i];
    }
    average = total / 12.0;
    for(i=0; i<12; i++) minor_profile[i]=minor_profile[i]-average;
    
    if(verbosity>=2) {
	printf("Adjusted major profile: ");
	for(i=0; i<12; i++) printf("%6.3f ", major_profile[i]);
	printf("\n");
	printf("Adjusted minor profile: ");
	for(i=0; i<12; i++) printf("%6.3f ", minor_profile[i]);
	printf("\n");
    }
}


/*
 * Fill key_profile[24][12] from the (already centred) major and minor
 * profiles. Rows 0..11 are major keys and rows 12..23 minor keys, with tonic
 * pitch class (row % 12). Each row is its base profile rotated so that base
 * index 0 falls on the tonic:
 *
 *   key_profile[row][pc] = base[(pc - tonic) mod 12]
 */
int generate_npc_profiles(void)
{
    int row, pc;

    for (row = 0; row < 24; row++) {
        const double *base = (row < 12) ? major_profile : minor_profile;
        const int tonic = row % 12;

        for (pc = 0; pc < 12; pc++)
            key_profile[row][pc] = base[(pc - tonic + 12) % 12];
    }
    return 0;
}

match_profiles() {

    /* Here we generate the "key scores" */

    int key, pc, best_key, i;
    double major_sumsq, minor_sumsq, input_sumsq;
    double best_score;
    int best, second_best;
    
    for(key=0; key<24; key++) {
	key_score[key]= 0.0;
    }
    
    major_sumsq = 0.0;
    minor_sumsq = 0.0;
    for(i=0; i<12; i++) major_sumsq += major_profile[i]*major_profile[i];
    for(i=0; i<12; i++) minor_sumsq += minor_profile[i]*minor_profile[i];
    if(verbosity>=2) printf("major_sumsq = %6.3f, minor_sumsq = %6.3f\n", major_sumsq, minor_sumsq);
    
    input_sumsq=0.0;

    for(i=0; i<12; i++) {
	input_sumsq += pow((input_prof[i]-average_dur), 2.0);		
    }
    if(verbosity==3) printf("average_dur = %6.3f; input_sumsq = %6.3f\n", average_dur, input_sumsq); 
	
    best_key=0;
    
    for (key=0; key<24; key++) {
	key_score[key]=0.0;      

	for (pc=0; pc<12; pc++) {
	    
	    /* 
	       We assume the K-S algorithm.
	       Input profile values represent total duration of each pc. Key-profiles have been normalized linearly 
	       around the average key-profile value. We normalize the input values similarly by taking
	       (input_prof[pc]-average_dur). Then we multiply each normalized KP value by
	       the normalized input value, and sum these products; this gives us the numerator of the 
	       correlation expression (as commented below). We've summed the squares of the normalized 
	       key-profile value (major_sumsq and minor_sumsq above) and the normalized input values 
	       (input_sumsq above), so this allows us to calculate the denominator also.
	       
	    */
	    
	    /* calculate numerator */
	    key_score[key] += key_profile[key][pc] * (input_prof[pc]-average_dur);		
	    
	}
	
	/* printf("sqrt(major_sumsq * input_sumsq) = %6.3f\n", sqrt(major_sumsq * input_sumsq)); */
	/* calculate denominator */
	if(key<12) key_score[key] = key_score[key] / sqrt(major_sumsq * input_sumsq);
	else key_score[key] = key_score[key] / sqrt(minor_sumsq * input_sumsq);

	if(verbosity==2) printf("Score for key %d = %6.3f\n", key, key_score[key]);

	if (key_score[key] > key_score[best_key]) best_key = key;
    }
    
    print_keyname(best_key);

    if(verbosity>=2) {
	printf("The best key is ");
	print_keyname(best_key);
	printf("with score %6.3f\n", key_score[best_key]);
    }

}



