from decimal import Decimal


| 3 | +class ArithmeticEncoding: |
| 4 | + """ |
| 5 | + ArithmeticEncoding is a class for building the arithmetic encoding. |
| 6 | + """ |
| 7 | + |
| 8 | + def __init__(self, frequency_table): |
| 9 | + |
| 10 | + frequency_table = frequency_table |
| 11 | + self.probability_table = self.get_probability_table(frequency_table) |
| 12 | + |
| 13 | + self.probability_table |
| 14 | + |
| 15 | + def get_probability_table(self, frequency_table): |
| 16 | + """ |
| 17 | + Calculates the probability table out of the frequency table. |
| 18 | + """ |
| 19 | + total_frequency = sum(list(frequency_table.values())) |
| 20 | + |
| 21 | + probability_table = {} |
| 22 | + for key, value in frequency_table.items(): |
| 23 | + probability_table[key] = value/total_frequency |
| 24 | + |
| 25 | + return probability_table |
| 26 | + |
| 27 | + def get_encoded_value(self, encoder): |
| 28 | + """ |
| 29 | + After encoding the entire message, this method returns the single value that represents the entire message. |
| 30 | + """ |
| 31 | + last_stage = list(encoder[-1].values()) |
| 32 | + last_stage_values = [] |
| 33 | + for sublist in last_stage: |
| 34 | + for element in sublist: |
| 35 | + last_stage_values.append(element) |
| 36 | + |
| 37 | + last_stage_min = min(last_stage_values) |
| 38 | + last_stage_max = max(last_stage_values) |
| 39 | + |
| 40 | + return (last_stage_min + last_stage_max)/2 |
| 41 | + |
| 42 | + def process_stage(self, probability_table, stage_min, stage_max): |
| 43 | + """ |
| 44 | + Processing a stage in the encoding/decoding process. |
| 45 | + """ |
| 46 | + stage_probs = {} |
| 47 | + stage_domain = stage_max - stage_min |
| 48 | + for term_idx in range(len(probability_table.items())): |
| 49 | + term = list(probability_table.keys())[term_idx] |
| 50 | + term_prob = Decimal(probability_table[term]) |
| 51 | + cum_prob = term_prob * stage_domain + stage_min |
| 52 | + stage_probs[term] = [stage_min, cum_prob] |
| 53 | + stage_min = cum_prob |
| 54 | + return stage_probs |
| 55 | + |
| 56 | + def ae_encoder(self, msg, probability_table): |
| 57 | + """ |
| 58 | + Encodes a message. |
| 59 | + """ |
| 60 | + |
| 61 | + encoder = [] |
| 62 | + |
| 63 | + stage_min = Decimal(0.0) |
| 64 | + stage_max = Decimal(1.0) |
| 65 | + |
| 66 | + for msg_term_idx in range(len(msg)): |
| 67 | + stage_probs = self.process_stage(probability_table, stage_min, stage_max) |
| 68 | + |
| 69 | + msg_term = msg[msg_term_idx] |
| 70 | + stage_min = stage_probs[msg_term][0] |
| 71 | + stage_max = stage_probs[msg_term][1] |
| 72 | + |
| 73 | + encoder.append(stage_probs) |
| 74 | + |
| 75 | + stage_probs = self.process_stage(probability_table, stage_min, stage_max) |
| 76 | + encoder.append(stage_probs) |
| 77 | + |
| 78 | + encoded_msg = self.get_encoded_value(encoder) |
| 79 | + |
| 80 | + return encoder, encoded_msg |
| 81 | + |
| 82 | + def ae_decoder(self, encoded_msg, msg_length, probability_table): |
| 83 | + """ |
| 84 | + Decodes a message. |
| 85 | + """ |
| 86 | + |
| 87 | + decoder = [] |
| 88 | + decoded_msg = "" |
| 89 | + |
| 90 | + stage_min = Decimal(0.0) |
| 91 | + stage_max = Decimal(1.0) |
| 92 | + |
| 93 | + for idx in range(msg_length): |
| 94 | + stage_probs = self.process_stage(probability_table, stage_min, stage_max) |
| 95 | + |
| 96 | + for msg_term, value in stage_probs.items(): |
| 97 | + if encoded_msg >= value[0] and encoded_msg <= value[1]: |
| 98 | + break |
| 99 | + |
| 100 | + decoded_msg = decoded_msg + msg_term |
| 101 | + stage_min = stage_probs[msg_term][0] |
| 102 | + stage_max = stage_probs[msg_term][1] |
| 103 | + |
| 104 | + decoder.append(stage_probs) |
| 105 | + |
| 106 | + stage_probs = self.process_stage(probability_table, stage_min, stage_max) |
| 107 | + decoder.append(stage_probs) |
| 108 | + |
| 109 | + return decoder, decoded_msg |
0 commit comments