Merge "Fix MediaRouter leaks." into jb-mr2-dev
diff --git a/v7/mediarouter/src/android/support/v7/media/MediaRouter.java b/v7/mediarouter/src/android/support/v7/media/MediaRouter.java
index 881ccde..c7df8e0 100644
--- a/v7/mediarouter/src/android/support/v7/media/MediaRouter.java
+++ b/v7/mediarouter/src/android/support/v7/media/MediaRouter.java
@@ -31,11 +31,10 @@
 import android.util.Log;
 import android.view.Display;
 
+import java.lang.ref.WeakReference;
 import java.util.ArrayList;
 import java.util.Collections;
 import java.util.List;
-import java.util.WeakHashMap;
-import java.util.concurrent.CopyOnWriteArrayList;
 
 /**
  * MediaRouter allows applications to control the routing of media channels
@@ -67,8 +66,7 @@
 
     // Context-bound state of the media router.
     final Context mContext;
-    final CopyOnWriteArrayList<CallbackRecord> mCallbackRecords =
-            new CopyOnWriteArrayList<CallbackRecord>();
+    final ArrayList<CallbackRecord> mCallbackRecords = new ArrayList<CallbackRecord>();
 
     /**
      * Flag for {@link #addCallback}: Actively scan for routes while this callback
@@ -130,7 +128,21 @@
     }
 
     /**
-     * Gets an instance of the media router service from the context.
+     * Gets an instance of the media router service associated with the context.
+     * <p>
+     * The application is responsible for holding a strong reference to the returned
+     * {@link MediaRouter} instance, such as by storing the instance in a field of
+     * the {@link android.app.Activity}, to ensure that the media router remains alive
+     * as long as the application is using its features.
+     * </p><p>
+     * In other words, the support library only holds a {@link WeakReference weak reference}
+     * to each media router instance.  When there are no remaining strong references to the
+     * media router instance, all of its callbacks will be removed and route discovery
+     * will no longer be performed on its behalf.
+     * </p>
+     *
+     * @return The media router instance for the context.  The application must hold
+     * a strong reference to this object as long as it is in use.
      */
     public static MediaRouter getInstance(Context context) {
         if (context == null) {
@@ -403,7 +415,7 @@
         CallbackRecord record;
         int index = findCallbackRecord(callback);
         if (index < 0) {
-            record = new CallbackRecord(callback);
+            record = new CallbackRecord(this, callback);
             mCallbackRecords.add(record);
         } else {
             record = mCallbackRecords.get(index);
@@ -1274,11 +1286,13 @@
     }
 
     private static final class CallbackRecord {
+        public final MediaRouter mRouter;
         public final Callback mCallback;
         public MediaRouteSelector mSelector;
         public int mFlags;
 
-        public CallbackRecord(Callback callback) {
+        public CallbackRecord(MediaRouter router, Callback callback) {
+            mRouter = router;
             mCallback = callback;
             mSelector = MediaRouteSelector.EMPTY;
         }
@@ -1299,8 +1313,8 @@
     private static final class GlobalMediaRouter implements SystemMediaRouteProvider.SyncCallback {
         private final Context mApplicationContext;
         private final MediaRouter mApplicationRouter;
-        private final WeakHashMap<Context, MediaRouter> mRouters =
-                new WeakHashMap<Context, MediaRouter>();
+        private final ArrayList<WeakReference<MediaRouter>> mRouters =
+                new ArrayList<WeakReference<MediaRouter>>();
         private final ArrayList<RouteInfo> mRoutes = new ArrayList<RouteInfo>();
         private final ArrayList<ProviderInfo> mProviders =
                 new ArrayList<ProviderInfo>();
@@ -1336,11 +1350,17 @@
         }
 
         public MediaRouter getRouter(Context context) {
-            MediaRouter router = mRouters.get(context);
-            if (router == null) {
-                router = new MediaRouter(context);
-                mRouters.put(context, router);
+            MediaRouter router;
+            for (int i = mRouters.size(); --i >= 0; ) {
+                router = mRouters.get(i).get();
+                if (router == null) {
+                    mRouters.remove(i);
+                } else if (router.mContext == context) {
+                    return router;
+                }
             }
+            router = new MediaRouter(context);
+            mRouters.add(new WeakReference<MediaRouter>(router));
             return router;
         }
 
@@ -1466,13 +1486,18 @@
             // Combine all of the callback selectors and active scan flags.
             boolean activeScan = false;
             MediaRouteSelector.Builder builder = new MediaRouteSelector.Builder();
-            for (MediaRouter router : mRouters.values()) {
-                final int count = router.mCallbackRecords.size();
-                for (int i = 0; i < count; i++) {
-                    CallbackRecord callback = router.mCallbackRecords.get(i);
-                    builder.addSelector(callback.mSelector);
-                    if ((callback.mFlags & CALLBACK_FLAG_ACTIVE_SCAN) != 0) {
-                        activeScan = true;
+            for (int i = mRouters.size(); --i >= 0; ) {
+                MediaRouter router = mRouters.get(i).get();
+                if (router == null) {
+                    mRouters.remove(i);
+                } else {
+                    final int count = router.mCallbackRecords.size();
+                    for (int j = 0; j < count; j++) {
+                        CallbackRecord callback = router.mCallbackRecords.get(j);
+                        builder.addSelector(callback.mSelector);
+                        if ((callback.mFlags & CALLBACK_FLAG_ACTIVE_SCAN) != 0) {
+                            activeScan = true;
+                        }
                     }
                 }
             }
@@ -1751,8 +1776,8 @@
         }
 
         private final class CallbackHandler extends Handler {
-            private final ArrayList<MediaRouter> mTempMediaRouters =
-                    new ArrayList<MediaRouter>();
+            private final ArrayList<CallbackRecord> mTempCallbackRecords =
+                    new ArrayList<CallbackRecord>();
 
             private static final int MSG_TYPE_MASK = 0xff00;
             private static final int MSG_TYPE_ROUTE = 0x0100;
@@ -1783,19 +1808,24 @@
                 syncWithSystemProvider(what, obj);
 
                 // Invoke all registered callbacks.
-                mTempMediaRouters.addAll(mRouters.values());
+                // Build a list of callbacks before invoking them in case callbacks
+                // are added or removed during dispatch.
                 try {
-                    final int routerCount = mTempMediaRouters.size();
-                    for (int i = 0; i < routerCount; i++) {
-                        final MediaRouter router = mTempMediaRouters.get(i);
-                        if (!router.mCallbackRecords.isEmpty()) {
-                            for (CallbackRecord record : router.mCallbackRecords) {
-                                invokeCallback(router, record, what, obj);
-                            }
+                    for (int i = mRouters.size(); --i >= 0; ) {
+                        MediaRouter router = mRouters.get(i).get();
+                        if (router == null) {
+                            mRouters.remove(i);
+                        } else {
+                            mTempCallbackRecords.addAll(router.mCallbackRecords);
                         }
                     }
+
+                    final int callbackCount = mTempCallbackRecords.size();
+                    for (int i = 0; i < callbackCount; i++) {
+                        invokeCallback(mTempCallbackRecords.get(i), what, obj);
+                    }
                 } finally {
-                    mTempMediaRouters.clear();
+                    mTempCallbackRecords.clear();
                 }
             }
 
@@ -1816,8 +1846,8 @@
                 }
             }
 
-            private void invokeCallback(MediaRouter router, CallbackRecord record,
-                    int what, Object obj) {
+            private void invokeCallback(CallbackRecord record, int what, Object obj) {
+                final MediaRouter router = record.mRouter;
                 final MediaRouter.Callback callback = record.mCallback;
                 switch (what & MSG_TYPE_MASK) {
                     case MSG_TYPE_ROUTE: {