Improve checksum calculation and address checking

1. Add a function that calculates the checksum of all the packet
   components starting from the specified position. This
   simplifies the code a bit and makes it easier to translate
   nested packets like ICMP error messages.

2. Don't hardcode IP source and destination addresses. This is
   required to translate ICMP error messages.

Bug: 8276725
Change-Id: I2cae45683ae3943e508608fd0a140180dbc60823
diff --git a/clatd.c b/clatd.c
index 8dddc96..a914f02 100644
--- a/clatd.c
+++ b/clatd.c
@@ -333,11 +333,11 @@
   if(ntohs(tun_header->proto) == ETH_P_IP) {
     fd = tunnel->fd6;
     fill_tun_header(&tun_targ, ETH_P_IPV6);
-    iov_len = ipv4_packet(out, POS_IPHDR, packet, packetsize);
+    iov_len = ipv4_packet(out, CLAT_POS_IPHDR, packet, packetsize);
   } else if(ntohs(tun_header->proto) == ETH_P_IPV6) {
     fd = tunnel->fd4;
     fill_tun_header(&tun_targ, ETH_P_IP);
-    iov_len = ipv6_packet(out, POS_IPHDR, packet, packetsize);
+    iov_len = ipv6_packet(out, CLAT_POS_IPHDR, packet, packetsize);
   } else {
     logmsg(ANDROID_LOG_WARN,"packet_handler: unknown packet type = %x",tun_header->proto);
   }
diff --git a/ipv4.c b/ipv4.c
index 1d34e1e..89e47e4 100644
--- a/ipv4.c
+++ b/ipv4.c
@@ -139,6 +139,6 @@
   }
 
   // Set the length.
-  ip6_targ->ip6_plen = htons(payload_length(out, pos));
+  ip6_targ->ip6_plen = htons(packet_length(out, pos));
   return iov_len;
 }
diff --git a/ipv6.c b/ipv6.c
index 8011ce9..bb1dc24 100644
--- a/ipv6.c
+++ b/ipv6.c
@@ -100,12 +100,11 @@
     return 0; // silently ignore
   }
 
-  for(i = 0; i < 3; i++) {
-    if(ip6->ip6_src.s6_addr32[i] != Global_Clatd_Config.plat_subnet.s6_addr32[i]) {
-      log_bad_address("ipv6_packet/wrong source address: %s", &ip6->ip6_src);
-      return 0;
-    }
+  if (!is_in_plat_subnet(&ip6->ip6_src) && ip6->ip6_nxt) {
+    log_bad_address("ipv6_packet/wrong source address: %s", &ip6->ip6_src);
+    return 0;
   }
+
   if(!IN6_ARE_ADDR_EQUAL(&ip6->ip6_dst, &Global_Clatd_Config.ipv6_local_subnet)) {
     log_bad_address("ipv6_packet/wrong destination address: %s", &ip6->ip6_dst);
     return 0;
@@ -149,7 +148,7 @@
   }
 
   // Set the length and calculate the checksum.
-  ip_targ->tot_len = htons(ntohs(ip_targ->tot_len) + payload_length(out, pos));
+  ip_targ->tot_len = htons(ntohs(ip_targ->tot_len) + packet_length(out, pos));
   ip_targ->check = ip_checksum(ip_targ, sizeof(struct iphdr));
   return iov_len;
 }
diff --git a/translate.c b/translate.c
index c0bd59a..4092bcc 100644
--- a/translate.c
+++ b/translate.c
@@ -33,21 +33,81 @@
 #include "logging.h"
 #include "debug.h"
 
-/* function: payload_length
- * calculates the total length of the packet components after pos
+/* function: packet_checksum
+ * calculates the checksum over all the packet components starting from pos
+ * checksum - checksum of packet components before pos
+ * packet   - packet to calculate the checksum of
+ * pos      - position to start counting from
+ * returns  - the completed 16-bit checksum, ready to write into a checksum header field
+ */
+uint16_t packet_checksum(uint32_t checksum, clat_packet packet, int pos) {
+  int i;
+  for (i = pos; i < CLAT_POS_MAX; i++) {
+    if (packet[i].iov_len > 0) {
+      checksum = ip_checksum_add(checksum, packet[i].iov_base, packet[i].iov_len);
+    }
+  }
+  return ip_checksum_finish(checksum);
+}
+
+/* function: packet_length
+ * returns the total length of all the packet components after pos
  * packet - packet to calculate the length of
  * pos    - position to start counting from
  * returns: the total length of the packet components after pos
  */
-uint16_t payload_length(clat_packet packet, int pos) {
+uint16_t packet_length(clat_packet packet, int pos) {
   size_t len = 0;
   int i;
-  for (i = pos + 1; i < POS_MAX; i++) {
+  for (i = pos + 1; i < CLAT_POS_MAX; i++) {
     len += packet[i].iov_len;
   }
   return len;
 }
 
+/* function: is_in_plat_subnet
+ * returns true iff the given IPv6 address is in the plat subnet.
+ * addr - IPv6 address
+ */
+int is_in_plat_subnet(const struct in6_addr *addr6) {
+  // Assumes a /96 plat subnet.
+  return (addr6 != NULL) && (memcmp(addr6, &Global_Clatd_Config.plat_subnet, 12) == 0);
+}
+
+/* function: ipv6_addr_to_ipv4_addr
+ * return the corresponding ipv4 address for the given ipv6 address
+ * addr6 - ipv6 address
+ * returns: the IPv4 address
+ */
+uint32_t ipv6_addr_to_ipv4_addr(const struct in6_addr *addr6) {
+
+  if (is_in_plat_subnet(addr6)) {
+    // Assumes a /96 plat subnet.
+    return addr6->s6_addr32[3];
+  } else {
+    // Currently this can only be our own address; other packets are dropped by ipv6_packet.
+    return Global_Clatd_Config.ipv4_local_subnet.s_addr;
+  }
+}
+
+/* function: ipv4_addr_to_ipv6_addr
+ * return the corresponding ipv6 address for the given ipv4 address
+ * addr4 - ipv4 address
+ */
+struct in6_addr ipv4_addr_to_ipv6_addr(uint32_t addr4) {
+  struct in6_addr addr6;
+  // Both addresses are in network byte order (addr4 comes from a network packet, and the config
+  // file entry is read using inet_ntop).
+  if (addr4 == Global_Clatd_Config.ipv4_local_subnet.s_addr) {
+    return Global_Clatd_Config.ipv6_local_subnet;
+  } else {
+    // Assumes a /96 plat subnet.
+    addr6 = Global_Clatd_Config.plat_subnet;
+    addr6.s6_addr32[3] = addr4;
+    return addr6;
+  }
+}
+
 /* function: fill_tun_header
  * fill in the header for the tun fd
  * tun_header - tunnel header, already allocated
@@ -58,16 +118,6 @@
   tun_header->proto = htons(proto);
 }
 
-/* function: ipv6_src_to_ipv4_src
- * return the corresponding ipv4 address for the given ipv6 address
- * sourceaddr - ipv6 source address
- * returns: the IPv4 address
- */
-uint32_t ipv6_src_to_ipv4_src(const struct in6_addr *sourceaddr) {
-  // assumes a /96 plat subnet
-  return sourceaddr->s6_addr32[3];
-}
-
 /* function: fill_ip_header
  * generate an ipv4 header from an ipv6 header
  * ip_targ     - (ipv4) target packet header, source: original ipv4 addr, dest: local subnet addr
@@ -89,22 +139,8 @@
   ip->protocol = protocol;
   ip->check = 0;
 
-  ip->saddr = ipv6_src_to_ipv4_src(&old_header->ip6_src);
-  ip->daddr = Global_Clatd_Config.ipv4_local_subnet.s_addr;
-}
-
-/* function: ipv4_dst_to_ipv6_dst
- * return the corresponding ipv6 address for the given ipv4 address
- * destination - ipv4 destination address (network byte order)
- */
-struct in6_addr ipv4_dst_to_ipv6_dst(uint32_t destination) {
-  struct in6_addr v6_destination;
-
-  // assumes a /96 plat subnet
-  v6_destination = Global_Clatd_Config.plat_subnet;
-  v6_destination.s6_addr32[3] = destination;
-
-  return v6_destination;
+  ip->saddr = ipv6_addr_to_ipv4_addr(&old_header->ip6_src);
+  ip->daddr = ipv6_addr_to_ipv4_addr(&old_header->ip6_dst);
 }
 
 /* function: fill_ip6_header
@@ -123,8 +159,8 @@
   ip6->ip6_nxt = protocol;
   ip6->ip6_hlim = old_header->ttl;
 
-  ip6->ip6_src = Global_Clatd_Config.ipv6_local_subnet;
-  ip6->ip6_dst = ipv4_dst_to_ipv6_dst(old_header->daddr);
+  ip6->ip6_src = ipv4_addr_to_ipv6_addr(old_header->saddr);
+  ip6->ip6_dst = ipv4_addr_to_ipv6_addr(old_header->daddr);
 }
 
 /* function: icmp_to_icmp6
@@ -152,16 +188,14 @@
   icmp6_targ->icmp6_id = icmp->un.echo.id;
   icmp6_targ->icmp6_seq = icmp->un.echo.sequence;
 
-  icmp6_targ->icmp6_cksum = 0;
-  checksum = ip_checksum_add(checksum, icmp6_targ, sizeof(struct icmp6_hdr));
-  checksum = ip_checksum_add(checksum, payload, payload_size);
-  icmp6_targ->icmp6_cksum = ip_checksum_finish(checksum);
-
   out[pos].iov_len = sizeof(struct icmp6_hdr);
-  out[POS_PAYLOAD].iov_base = (char *) payload;
-  out[POS_PAYLOAD].iov_len = payload_size;
+  out[CLAT_POS_PAYLOAD].iov_base = (char *) payload;
+  out[CLAT_POS_PAYLOAD].iov_len = payload_size;
 
-  return POS_PAYLOAD + 1;
+  icmp6_targ->icmp6_cksum = 0;  // Checksum field must be 0 when calculating checksum.
+  icmp6_targ->icmp6_cksum = packet_checksum(checksum, out, pos);
+
+  return CLAT_POS_PAYLOAD + 1;
 }
 
 /* function: icmp6_to_icmp
@@ -189,16 +223,14 @@
   icmp_targ->un.echo.id = icmp6->icmp6_id;
   icmp_targ->un.echo.sequence = icmp6->icmp6_seq;
 
-  icmp_targ->checksum = 0;
-  checksum = ip_checksum_add(0, icmp_targ, sizeof(struct icmphdr));
-  checksum = ip_checksum_add(checksum, (void *)payload, payload_size);
-  icmp_targ->checksum = ip_checksum_finish(checksum);
-
   out[pos].iov_len = sizeof(struct icmphdr);
-  out[POS_PAYLOAD].iov_base = (char *) payload;
-  out[POS_PAYLOAD].iov_len = payload_size;
+  out[CLAT_POS_PAYLOAD].iov_base = (char *) payload;
+  out[CLAT_POS_PAYLOAD].iov_len = payload_size;
 
-  return POS_PAYLOAD + 1;
+  icmp_targ->checksum = 0;  // Checksum field must be 0 when calculating checksum.
+  icmp_targ->checksum = packet_checksum(0, out, pos);
+
+  return CLAT_POS_PAYLOAD + 1;
 }
 
 /* function: udp_packet
@@ -271,17 +303,15 @@
   struct udphdr *udp_targ = out[pos].iov_base;
 
   memcpy(udp_targ, udp, sizeof(struct udphdr));
-  udp_targ->check = 0; // reset checksum, to be calculated
-
-  checksum = ip_checksum_add(checksum, udp_targ, sizeof(struct udphdr));
-  checksum = ip_checksum_add(checksum, payload, payload_size);
-  udp_targ->check = ip_checksum_finish(checksum);
 
   out[pos].iov_len = sizeof(struct udphdr);
-  out[POS_PAYLOAD].iov_base = (char *) payload;
-  out[POS_PAYLOAD].iov_len = payload_size;
+  out[CLAT_POS_PAYLOAD].iov_base = (char *) payload;
+  out[CLAT_POS_PAYLOAD].iov_len = payload_size;
 
-  return POS_PAYLOAD + 1;
+  udp_targ->check = 0;  // Checksum field must be 0 when calculating checksum.
+  udp_targ->check = packet_checksum(checksum, out, pos);
+
+  return CLAT_POS_PAYLOAD + 1;
 }
 
 /* function: tcp_translate
@@ -312,13 +342,11 @@
 
   memcpy(tcp_targ, tcp, header_size);
 
-  tcp_targ->check = 0;
-  checksum = ip_checksum_add(checksum, tcp_targ, header_size);
-  checksum = ip_checksum_add(checksum, payload, payload_size);
-  tcp_targ->check = ip_checksum_finish(checksum);
+  out[CLAT_POS_PAYLOAD].iov_base = (char *)payload;
+  out[CLAT_POS_PAYLOAD].iov_len = payload_size;
 
-  out[POS_PAYLOAD].iov_base = (char *)payload;
-  out[POS_PAYLOAD].iov_len = payload_size;
+  tcp_targ->check = 0;  // Checksum field must be 0 when calculating checksum.
+  tcp_targ->check = packet_checksum(checksum, out, pos);
 
-  return POS_PAYLOAD + 1;
+  return CLAT_POS_PAYLOAD + 1;
 }
diff --git a/translate.h b/translate.h
index 07db023..120fecf 100644
--- a/translate.h
+++ b/translate.h
@@ -23,14 +23,21 @@
 #define MAX_TCP_HDR (15 * 4)   // Data offset field is 4 bits and counts in 32-bit words.
 
 // A clat_packet is an array of iovec structures representing a packet that we are translating.
-// The POS_XXX constants represent the array indices within the clat_packet that contain specific
-// parts of the packet.
-enum clat_packet_index { POS_TUNHDR, POS_IPHDR, POS_TRANSPORTHDR, POS_ICMPIPHDR,
-                         POS_PAYLOAD, POS_MAX };
-typedef struct iovec clat_packet[POS_MAX];
+// The CLAT_POS_XXX constants represent the array indices within the clat_packet that contain
+// specific parts of the packet. The packet_* functions operate on all the packet segments past a
+// given position.
+enum clat_packet_index { CLAT_POS_TUNHDR, CLAT_POS_IPHDR, CLAT_POS_TRANSPORTHDR,
+                         CLAT_POS_PAYLOAD, CLAT_POS_MAX };
+typedef struct iovec clat_packet[CLAT_POS_MAX];
 
-// Returns the total length of the packet components after index.
-uint16_t payload_length(clat_packet packet, int index);
+// Calculates the checksum over all the packet components starting from pos.
+uint16_t packet_checksum(uint32_t checksum, clat_packet packet, int pos);
+
+// Returns the total length of the packet components after pos.
+uint16_t packet_length(clat_packet packet, int pos);
+
+// Returns true iff the given IPv6 address is in the plat subnet.
+int is_in_plat_subnet(const struct in6_addr *addr6);
 
 // Functions to create tun, IPv4, and IPv6 headers.
 void fill_tun_header(struct tun_pi *tun_header, uint16_t proto);