Merge "Implement sf_open(SFM_WRITE) and sf_writef_short"
diff --git a/audio_utils/include/audio_utils/sndfile.h b/audio_utils/include/audio_utils/sndfile.h
index 1d12ae5..c652654 100644
--- a/audio_utils/include/audio_utils/sndfile.h
+++ b/audio_utils/include/audio_utils/sndfile.h
@@ -42,7 +42,8 @@
 typedef struct SNDFILE_ SNDFILE;
 
 // Access modes
-#define SFM_READ   1
+#define SFM_READ    1
+#define SFM_WRITE   2
 
 // Format
 #define SF_FORMAT_TYPEMASK  1
@@ -60,6 +61,9 @@
 // Read interleaved frames and return actual number of frames read
 sf_count_t sf_readf_short(SNDFILE *handle, short *ptr, sf_count_t desired);
 
+// Write interleaved frames (SFM_WRITE handles only); return frames actually written, 0 if args invalid
+sf_count_t sf_writef_short(SNDFILE *handle, const short *ptr, sf_count_t desired);
+
 __END_DECLS
 
 #endif /* __AUDIO_UTIL_SNDFILE_H */
diff --git a/audio_utils/tinysndfile.c b/audio_utils/tinysndfile.c
index f028685..efdb3ce 100644
--- a/audio_utils/tinysndfile.c
+++ b/audio_utils/tinysndfile.c
@@ -20,9 +20,11 @@
 #include <string.h>
 
 struct SNDFILE_ {
+    int mode;       // SFM_READ or SFM_WRITE; read/write calls reject a handle of the other mode
+    uint8_t *temp;  // scratch buffer (grown by realloc) for 16->8 bit and byte-swap on write; sf_close frees
     FILE *stream;
     size_t bytesPerFrame;
-    size_t remaining;
+    size_t remaining;   // frames unread for SFM_READ, frames written for SFM_WRITE
     SF_INFO info;
 };
 
@@ -51,10 +53,8 @@
     }
 }
 
-SNDFILE *sf_open(const char *path, int mode, SF_INFO *info)
+static SNDFILE *sf_open_read(const char *path, SF_INFO *info)
 {
-    if (path == NULL || mode != SFM_READ || info == NULL)
-        return NULL;
     FILE *stream = fopen(path, "rb");
     if (stream == NULL)
         return NULL;
@@ -68,7 +68,7 @@
         if (memcmp(wav, "RIFF", 4))
             break;
         unsigned riffSize = little4u(&wav[4]);
-        if (riffSize < 44)
+        if (riffSize < 36)
             break;
         if (memcmp(&wav[8], "WAVEfmt ", 8))
             break;
@@ -94,6 +94,8 @@
             break;
         unsigned dataSize = little4u(&wav[40]);
         SNDFILE *handle = (SNDFILE *) malloc(sizeof(SNDFILE));
+        handle->mode = SFM_READ;
+        handle->temp = NULL;
         handle->stream = stream;
         handle->bytesPerFrame = bytesPerFrame;
         handle->remaining = dataSize / bytesPerFrame;
@@ -111,19 +113,94 @@
     return NULL;
 }
 
+static void write4u(unsigned char *ptr, unsigned u)
+{
+    // Store the low 32 bits of u at ptr[0..3], least-significant byte first (little-endian).
+    for (unsigned shift = 0; shift < 32; shift += 8) {
+        *ptr++ = (unsigned char) (u >> shift);
+    }
+}
+
+static SNDFILE *sf_open_write(const char *path, SF_INFO *info)
+{
+    int subformat = info->format & SF_FORMAT_SUBMASK;
+    if (info->samplerate <= 0 || (info->channels != 1 && info->channels != 2) ||
+            (info->format & SF_FORMAT_TYPEMASK) != SF_FORMAT_WAV ||
+            (subformat != SF_FORMAT_PCM_16 && subformat != SF_FORMAT_PCM_U8)) {
+        return NULL;    // only mono/stereo WAV with 16-bit or unsigned 8-bit PCM is supported
+    }
+    FILE *stream = fopen(path, "w+b");  // update mode: sf_close reads back and patches the header
+    if (stream == NULL)
+        return NULL;
+    unsigned char wav[44] = {0};    // dataSize (wav[40..43]) stays zero until sf_close
+    memcpy(wav, "RIFF", 4);
+    wav[4] = 36;    // riffSize
+    memcpy(&wav[8], "WAVEfmt ", 8);
+    wav[16] = 16;   // fmtsize
+    wav[20] = 1;    // format = PCM
+    wav[22] = info->channels;
+    write4u(&wav[24], info->samplerate);
+    unsigned bitsPerSample = subformat == SF_FORMAT_PCM_16 ? 16 : 8;
+    unsigned blockAlignment = (bitsPerSample >> 3) * info->channels;
+    write4u(&wav[28], info->samplerate * blockAlignment);   // byteRate
+    wav[32] = blockAlignment;
+    wav[34] = bitsPerSample;
+    memcpy(&wav[36], "data", 4);
+    SNDFILE *handle = NULL;
+    if (fwrite(wav, sizeof(wav), 1, stream) != 1 ||
+            (handle = (SNDFILE *) malloc(sizeof(SNDFILE))) == NULL) {
+        (void) fclose(stream);  // header write or allocation failed: don't leak the stream
+        return NULL;
+    }
+    handle->mode = SFM_WRITE;
+    handle->temp = NULL;
+    handle->stream = stream;
+    handle->bytesPerFrame = blockAlignment;
+    handle->remaining = 0;
+    handle->info = *info;
+    return handle;
+}
+
+SNDFILE *sf_open(const char *path, int mode, SF_INFO *info)
+{
+    // Validate the arguments common to every mode, then hand off to the mode-specific opener.
+    if (path == NULL || info == NULL) {
+        return NULL;
+    }
+    if (mode == SFM_READ) {
+        return sf_open_read(path, info);
+    } else if (mode == SFM_WRITE) {
+        return sf_open_write(path, info);
+    }
+    return NULL;    // unknown access mode
+}
+
 void sf_close(SNDFILE *handle)
 {
     if (handle == NULL)
         return;
+    free(handle->temp);
+    if (handle->mode == SFM_WRITE) {    // patch riffSize/dataSize now that the length is known
+        unsigned char wav[44];
+        rewind(handle->stream);     // positioning call also flushes pending output before reading
+        if (fread(wav, sizeof(wav), 1, handle->stream) == 1) {  // never patch garbage bytes
+            unsigned dataSize = handle->remaining * handle->bytesPerFrame;
+            write4u(&wav[4], dataSize + 36);    // riffSize
+            write4u(&wav[40], dataSize);        // dataSize
+            rewind(handle->stream);
+            (void) fwrite(wav, sizeof(wav), 1, handle->stream);
+        }
+    }
     (void) fclose(handle->stream);
-    handle->stream = NULL;
-    handle->remaining = 0;
+    free(handle);
 }
 
 sf_count_t sf_readf_short(SNDFILE *handle, short *ptr, sf_count_t desiredFrames)
 {
-    if (handle == NULL || ptr == NULL || !handle->remaining || desiredFrames <= 0)
+    if (handle == NULL || handle->mode != SFM_READ || ptr == NULL || !handle->remaining ||
+            desiredFrames <= 0) {
         return 0;
+    }
     if (handle->remaining < (size_t) desiredFrames)
         desiredFrames = handle->remaining;
     size_t desiredBytes = desiredFrames * handle->bytesPerFrame;
@@ -142,3 +219,32 @@
     }
     return actualFrames;
 }
+
+sf_count_t sf_writef_short(SNDFILE *handle, const short *ptr, sf_count_t desiredFrames)
+{
+    if (handle == NULL || handle->mode != SFM_WRITE || ptr == NULL || desiredFrames <= 0)
+        return 0;
+    size_t desiredBytes = desiredFrames * handle->bytesPerFrame;
+    int subformat = handle->info.format & SF_FORMAT_SUBMASK;
+    const void *out = ptr;  // little-endian PCM_16 is written straight from the caller's buffer
+    if (subformat == SF_FORMAT_PCM_U8 || (subformat == SF_FORMAT_PCM_16 && !isLittleEndian())) {
+        // Convert into the scratch buffer; on realloc failure keep the old one for sf_close.
+        uint8_t *temp = realloc(handle->temp, desiredBytes);
+        if (temp == NULL)
+            return 0;
+        handle->temp = temp;
+        if (subformat == SF_FORMAT_PCM_U8) {
+            memcpy_to_u8_from_i16(temp, ptr, desiredBytes);     // 1 byte per sample
+        } else {
+            memcpy(temp, ptr, desiredBytes);
+            swab((short *) temp, desiredFrames * handle->info.channels);
+        }
+        out = temp;
+    } else if (subformat != SF_FORMAT_PCM_16) {
+        return 0;   // unsupported subformat: sf_open_write never creates such a handle
+    }
+    size_t actualFrames = fwrite(out, sizeof(char), desiredBytes, handle->stream) /
+            handle->bytesPerFrame;
+    handle->remaining += actualFrames;
+    return actualFrames;
+}