#include "radix.h"

#include <stdio.h>
#include <stdlib.h>
#include <string.h>

/* Number of bits in an IPv6 key (8 groups of 16 bits). */
enum { RADIX_KEY_BITS = 128 };

/*
 * Trie layout used throughout this file:
 *  - a node's prefix is the first `_ip_end_bit` bits of `_ip`;
 *  - at bit position `_ip_end_bit` the key's bit selects the child:
 *    bit value 1 descends into `_l`, bit value 0 into `_r`;
 *  - the child's own prefix therefore implicitly agrees with the parent on
 *    that branching bit, and comparison resumes at the following bit.
 */

/*
 * Return bit `bit` of `key` as 0 or 1, where bit 0 is the most significant
 * bit of key.s[0].  (Replaces the former GET_BIT macro, whose unparenthesised
 * argument made GET_BIT(k, i + 1) index k.s[i] and shift by a negative count.)
 */
static int key_bit(ipv6_t key, unsigned bit)
{
    return (key.s[bit / 16] & (1u << (15 - bit % 16))) != 0;
}

/*
 * Allocate a node for `key` whose prefix is the first `end_bit` bits and
 * which carries the valid rule `rule`.  Both children start out NULL.
 * Returns NULL if the allocation fails.
 */
radix_node_t create_node(ipv6_t key, unsigned end_bit, unsigned rule)
{
    radix_node_t ret = calloc(1, sizeof(struct radix_node));

    if (!ret)
        return NULL;
    ret->_ip = key;
    ret->_ip_end_bit = end_bit;
    ret->_rule_valid = VALID_RULE;
    ret->_rule = rule;
    return ret;
}

/* Return a shallow copy of `n` (children are shared, not copied), or NULL on OOM. */
static radix_node_t clone_node(radix_node_t n)
{
    radix_node_t ret = malloc(sizeof(struct radix_node));

    if (ret)
        memcpy(ret, n, sizeof(struct radix_node));
    return ret;
}

/*
 * Cut node `n` at position `bit`: a copy of `n` (keeping its original prefix
 * length, rule and children) becomes a child of `n`, and `n` itself is
 * shortened to a rule-less prefix of `bit` bits.
 * Both child pointers of `n` are left pointing at the copy; the caller MUST
 * overwrite the side that does not match the copy's bit at position `bit`.
 * Returns 1 on success, 0 (with `n` untouched) if allocation fails.
 */
static int split_node(radix_node_t n, unsigned bit)
{
    radix_node_t tail = clone_node(n);

    if (!tail)
        return 0;
    n->_ip_end_bit = bit;
    n->_l = n->_r = tail;
    n->_rule_valid = INVALID_RULE;
    return 1;
}

/* Allocate an empty trie holder (no root, size 0); NULL on OOM. */
radix_holder_t create_holder(void)
{
    radix_holder_t ret = calloc(1, sizeof(struct radix_holder));
    return ret;
}

/*
 * Recursively free `n` and its whole subtree.
 * Returns the number of nodes freed (0 for a NULL subtree).
 */
int destory_node(struct radix_node* n)
{
    int ret;

    if (!n)
        return 0;
    ret = 1 + destory_node(n->_l) + destory_node(n->_r);
    free(n);
    return ret;
}

/* Free the holder `t` and every node of its trie. Returns 1, or 0 if `t` is NULL. */
int destory_holder(struct radix_holder* t)
{
    if (!t)
        return 0;
    destory_node(t->_root);
    free(t);
    return 1;
}

/*
 * Insert the rule `val` for the prefix key/mask.
 * Returns 1 if the rule was stored, 0 if it conflicts with an existing rule
 * for exactly the same prefix (a warning is printed), if `mask` is outside
 * 0..128, or if an allocation failed (the trie is left consistent).
 */
int radix_insert(radix_holder_t t, ipv6_t key, int val, int mask)
{
    int ibit;
    int b;
    radix_node_t node;
    radix_node_t* next;
    radix_node_t leaf;

    if (!t || mask < 0 || mask > RADIX_KEY_BITS)
        return 0;

    if (!t->_root) {
        t->_root = create_node(key, mask, val);
        if (!t->_root)
            return 0;
        t->_size = 1;
        return 1;
    }

    node = t->_root;
    for (ibit = 0; ibit < mask; ++ibit) {
        b = key_bit(key, ibit);

        if (ibit >= node->_ip_end_bit) {
            /* node's prefix is exhausted: bit `ibit` selects the child */
            next = b ? &node->_l : &node->_r;
            if (!*next) {
                *next = create_node(key, mask, val);
                if (!*next)
                    return 0;
                ++t->_size;
                return 1;
            }
            node = *next;
            continue;
        }

        if (b == key_bit(node->_ip, ibit)) /* bit match */
            continue;

        /* The key diverges inside node's prefix: split node at the first
         * differing bit and hang the new rule on side `b`; the copy of the
         * old node stays on the opposite side. */
        leaf = create_node(key, mask, val);
        if (!leaf)
            return 0;
        if (!split_node(node, ibit)) {
            free(leaf);
            return 0;
        }
        *(b ? &node->_l : &node->_r) = leaf;
        ++t->_size;
        return 1;
    }

    /* Here ibit == mask: the first `mask` bits of key match node's prefix,
     * and node's prefix is at least `mask` bits long. */

    /* exact match */
    if (node->_rule_valid && node->_ip_end_bit == mask) {
        printf("WARNING: conflicting rule %x:%x:%x:%x:%x:%x:%x:%x",
               (unsigned)key.s[0], (unsigned)key.s[1], (unsigned)key.s[2], (unsigned)key.s[3],
               (unsigned)key.s[4], (unsigned)key.s[5], (unsigned)key.s[6], (unsigned)key.s[7]);
        printf("/%d %d\n", mask, val);
        return 0;
    }

    /* Node's prefix is longer than the new rule (whether or not it carries
     * a rule itself): shorten it to `mask` bits so the new rule covers the
     * whole /mask range.  The copy of the old node belongs on the side of
     * its own bit at position `mask`; the other side is empty. */
    if (node->_ip_end_bit > mask) {
        if (!split_node(node, mask))
            return 0;
        b = key_bit(node->_ip, mask);
        *(b ? &node->_r : &node->_l) = NULL;
    }
    node->_rule_valid = VALID_RULE;
    node->_rule = val;
    return 1;
}

/*
 * Longest-prefix match: return the rule of the longest stored prefix that
 * matches `key`, or INVALID_RULE if none does.
 */
int radix_search(radix_holder_t t, ipv6_t key)
{
    int ret = INVALID_RULE;
    radix_node_t iter = t ? t->_root : NULL;
    unsigned ibit = 0;

    while (iter) {
        /* every bit of the node's prefix must match; a mismatch means no
         * deeper node can match either */
        for (; ibit < iter->_ip_end_bit && ibit < RADIX_KEY_BITS; ++ibit) {
            if (key_bit(key, ibit) != key_bit(iter->_ip, ibit))
                return ret;
        }
        /* full match of this node's prefix */
        if (iter->_rule_valid)
            ret = iter->_rule;
        if (ibit >= RADIX_KEY_BITS) /* a /128 node has no children */
            break;
        iter = key_bit(key, ibit) ? iter->_l : iter->_r;
        ++ibit; /* the branching bit is implied by the child taken */
    }
    return ret;
}

/*--------------------------------------------------------------------------------*/
/*                                    TESTING                                     */
/*--------------------------------------------------------------------------------*/

#define RULES_FILE "../data/routing-data"
#define TEST_FILE "../data/test-data"
#define TEST_ROWS 5527
#define READ_MODE "r"

/*
 * Read one "a:b:c:d:e:f:g:h/mask rule" line from `f`.
 * Returns 1 when a line was read into key/mask/val, 0 at end of input.
 */
int parse_line(ipv6_t* key, int* val, int* mask, FILE* f)
{
    int c;

    /*! assumes only good input (fully expanded 8-group addresses) */
    memset(key, 0, sizeof(ipv6_t));
    fscanf(f, "%hx:%hx:%hx:%hx:%hx:%hx:%hx:%hx",
           &key->s[0], &key->s[1], &key->s[2], &key->s[3],
           &key->s[4], &key->s[5], &key->s[6], &key->s[7]);
    if (feof(f))
        return 0;
    /* skip to the '/' before the mask; stop (instead of spinning) at EOF */
    while ((c = fgetc(f)) != '/' && c != EOF) {
        /* skip */
    }
    if (c == EOF)
        return 0;
    /* succeed on a full parse, even if the final line lacks a newline */
    return fscanf(f, "%d %d", mask, val) == 2;
}

/* Insert every rule from RULES_FILE into `t`; returns the number stored. */
int load_input(radix_holder_t t)
{
    int val;
    int mask;
    ipv6_t key;
    FILE* f;
    int cnt = 0;

    if (!(f = fopen(RULES_FILE, READ_MODE))) {
        puts("cant open input file");
        exit(1);
    }
    while (parse_line(&key, &val, &mask, f)) {
        cnt += radix_insert(t, key, val, mask);
    }
    fclose(f);
    return cnt;
}

/*
 * Read up to TEST_ROWS "address expected-rule" rows from TEST_FILE into
 * newly allocated arrays returned through *rips / *rref (caller owns them;
 * fire_tests frees them).  Returns the number of rows actually loaded.
 */
int load_tests(ipv6_t** rips, int** rref)
{
    ipv6_t* ips;
    int* ref;
    int row;
    int c;
    FILE* f;

    if (!(f = fopen(TEST_FILE, READ_MODE))) {
        puts("cant open test file");
        exit(1);
    }
    ips = calloc(TEST_ROWS, sizeof(ipv6_t));
    ref = calloc(TEST_ROWS, sizeof(int));
    if (!ips || !ref) {
        free(ips);
        free(ref);
        fclose(f);
        puts("out of memory");
        exit(1);
    }
    for (row = 0; row < TEST_ROWS; ++row) {
        if (fscanf(f, "%hx:%hx:%hx:%hx:%hx:%hx:%hx:%hx",
                   &ips[row].s[0], &ips[row].s[1], &ips[row].s[2], &ips[row].s[3],
                   &ips[row].s[4], &ips[row].s[5], &ips[row].s[6], &ips[row].s[7]) == EOF)
            break;
        /* skip to the space before the expected rule; stop at EOF */
        while ((c = fgetc(f)) != ' ' && c != EOF) {
            /* skip */
        }
        if (c == EOF || fscanf(f, "%d", &ref[row]) != 1)
            break;
    }
    fclose(f);
    *rips = ips;
    *rref = ref;
    return row;
}

/*
 * Look up each of the first `num_tests` addresses and compare with the
 * expected rule, printing every mismatch.  Frees `ips` and `ref`.
 * Returns the number of wrong answers.
 */
int fire_tests(radix_holder_t t, ipv6_t* ips, int* ref, int num_tests)
{
    int row;
    int got;
    int wrong = 0;

    for (row = 0; row < num_tests; ++row) {
        got = radix_search(t, ips[row]);
        if (got != ref[row]) {
            printf("ref: %d got: %d @ %d\n", ref[row], got, row);
            ++wrong;
        }
    }
    free(ips);
    free(ref);
    return wrong;
}

int main(void)
{
    ipv6_t* ips;
    int* ref;
    int num_tests;
    radix_holder_t t = create_holder();

    if (!t) {
        puts("out of memory");
        return 1;
    }
    printf("Loaded %d rules\n", load_input(t));
    num_tests = load_tests(&ips, &ref);
    printf("Loaded %d tests\n", num_tests);
    printf("%d tests wrong\n", fire_tests(t, ips, ref, num_tests));
    destory_holder(t);
    return 0;
}