Skip to content

Commit 78d1a03

Browse files
authored
Addition of optional parameter save_stages
1) An optional parameter named save_stages is added to the constructor of the ArithmeticEncoding class. If True, then the intervals of each stage are saved in a list. Note that setting save_stages=True may cause memory overflow if the message is large 2) The decoded message returned by the decode() method is always a list. 3) The order of the returned values by the the encode() and decode() method changed. 4) Added an example in the example_image.py script to encode and decode an image.
1 parent 26440ae commit 78d1a03

File tree

4 files changed

+123
-24
lines changed

4 files changed

+123
-24
lines changed

example.py

+6-8
Original file line numberDiff line numberDiff line change
@@ -1,27 +1,25 @@
11
import pyae
2-
# from decimal import getcontext
32

43
# Example for encoding a simple text message using the PyAE module.
54

65
frequency_table = {"a": 2,
76
"b": 7,
87
"c": 1}
98

10-
AE = pyae.ArithmeticEncoding(frequency_table)
11-
12-
# Default precision is 28. Change it to do arithmetic operations with larger/smaller numbers.
13-
# getcontext().prec = 28
9+
AE = pyae.ArithmeticEncoding(frequency_table=frequency_table,
10+
save_stages=True)
1411

1512
original_msg = "abc"
1613
print("Original Message: {msg}".format(msg=original_msg))
1714

18-
encoder, encoded_msg = AE.encode(msg=original_msg,
15+
encoded_msg, encoder = AE.encode(msg=original_msg,
1916
probability_table=AE.probability_table)
2017
print("Encoded Message: {msg}".format(msg=encoded_msg))
2118

22-
decoder, decoded_msg = AE.decode(encoded_msg=encoded_msg,
19+
decoded_msg, decoder = AE.decode(encoded_msg=encoded_msg,
2320
msg_length=len(original_msg),
2421
probability_table=AE.probability_table)
2522
print("Decoded Message: {msg}".format(msg=decoded_msg))
2623

27-
print("Message Decoded Successfully? {result}".format(result=original_msg == decoded_msg))
24+
decoded_msg = "".join(decoded_msg)
25+
print("Message Decoded Successfully? {result}".format(result=original_msg == decoded_msg))

example2.py

+34
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,34 @@
1+
import pyae
2+
# from decimal import getcontext
3+
4+
# Example for encoding a simple text message using the PyAE module.
5+
6+
# Create the frequency table.
7+
frequency_table = {"a": 2,
8+
"b": 3,
9+
"c": 1,
10+
"d": 4}
11+
12+
# Create an instance of the ArithmeticEncoding class.
13+
AE = pyae.ArithmeticEncoding(frequency_table,
14+
save_stages=True)
15+
16+
# Default precision is 28. Change it to do arithmetic operations with larger/smaller numbers.
17+
# getcontext().prec = 28
18+
19+
original_msg = "bdab"
20+
print("Original Message: {msg}".format(msg=original_msg))
21+
22+
# Encode the message
23+
encoded_msg, encoder = AE.encode(msg=original_msg,
24+
probability_table=AE.probability_table)
25+
print("Encoded Message: {msg}".format(msg=encoded_msg))
26+
27+
# Decode the message
28+
decoded_msg, decoder = AE.decode(encoded_msg=encoded_msg,
29+
msg_length=len(original_msg),
30+
probability_table=AE.probability_table)
31+
print("Decoded Message: {msg}".format(msg=decoded_msg))
32+
33+
decoded_msg = "".join(decoded_msg)
34+
print("Message Decoded Successfully? {result}".format(result=original_msg == decoded_msg))

example_image.py

+48
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,48 @@
1+
import scipy.io
2+
import pyae
3+
import numpy
4+
import matplotlib.pyplot
5+
6+
# Change the precision to a bigger value
7+
from decimal import getcontext
8+
getcontext().prec = 10000
9+
10+
# Read an image.
11+
im = scipy.misc.face(gray=True)
12+
13+
# Just work on a small part to save time. The larger the image, the more time consumed.
14+
im = im[:50, :50]
15+
16+
# Convert the image into a 1D vector.
17+
msg = im.flatten()
18+
19+
# Create the frequency table based on its hitogram.
20+
hist, bin_edges = numpy.histogram(a=im,
21+
bins=range(0, 257))
22+
frequency_table = {key: value for key, value in zip(bin_edges[0:256], hist)}
23+
24+
# Create an instance of the ArithmeticEncoding class.
25+
AE = pyae.ArithmeticEncoding(frequency_table=frequency_table)
26+
27+
# Encode the message
28+
encoded_msg, _ = AE.encode(msg=msg,
29+
probability_table=AE.probability_table)
30+
31+
# Decode the message
32+
decoded_msg, _ = AE.decode(encoded_msg=encoded_msg,
33+
msg_length=len(msg),
34+
probability_table=AE.probability_table)
35+
36+
# Reshape the image to its original shape.
37+
decoded_msg = numpy.reshape(decoded_msg, im.shape)
38+
39+
# Show the original and decoded images.
40+
fig, ax = matplotlib.pyplot.subplots(1, 2)
41+
ax[0].imshow(im, cmap="gray")
42+
ax[0].set_title("Original Image")
43+
ax[0].set_xticks([])
44+
ax[0].set_yticks([])
45+
ax[1].imshow(decoded_msg, cmap="gray")
46+
ax[1].set_title("Reconstructed Image")
47+
ax[1].set_xticks([])
48+
ax[1].set_yticks([])

pyae.py

+35-16
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,16 @@ class ArithmeticEncoding:
55
ArithmeticEncoding is a class for building the arithmetic encoding.
66
"""
77

8-
def __init__(self, frequency_table):
8+
def __init__(self, frequency_table, save_stages=False):
9+
"""
10+
frequency_table: Frequency table as a dictionary where key is the symbol and value is the frequency.
11+
save_stages: If True, then the intervals of each stage are saved in a list. Note that setting save_stages=True may cause memory overflow if the message is large
12+
"""
13+
14+
self.save_stages = save_stages
15+
if(save_stages == True):
16+
print("WARNING: Setting save_stages=True may cause memory overflow if the message is large.")
17+
918
self.probability_table = self.get_probability_table(frequency_table)
1019

1120
def get_probability_table(self, frequency_table):
@@ -20,13 +29,13 @@ def get_probability_table(self, frequency_table):
2029

2130
return probability_table
2231

23-
def get_encoded_value(self, encoder):
32+
def get_encoded_value(self, last_stage_probs):
2433
"""
2534
After encoding the entire message, this method returns the single value that represents the entire message.
2635
"""
27-
last_stage = list(encoder[-1].values())
36+
last_stage_probs = list(last_stage_probs.values())
2837
last_stage_values = []
29-
for sublist in last_stage:
38+
for sublist in last_stage_probs:
3039
for element in sublist:
3140
last_stage_values.append(element)
3241

@@ -53,9 +62,12 @@ def encode(self, msg, probability_table):
5362
"""
5463
Encodes a message.
5564
"""
65+
66+
# Make sure
67+
msg = list(msg)
5668

5769
encoder = []
58-
70+
5971
stage_min = Decimal(0.0)
6072
stage_max = Decimal(1.0)
6173

@@ -66,22 +78,26 @@ def encode(self, msg, probability_table):
6678
stage_min = stage_probs[msg_term][0]
6779
stage_max = stage_probs[msg_term][1]
6880

69-
encoder.append(stage_probs)
81+
if self.save_stages:
82+
encoder.append(stage_probs)
7083

71-
stage_probs = self.process_stage(probability_table, stage_min, stage_max)
72-
encoder.append(stage_probs)
84+
last_stage_probs = self.process_stage(probability_table, stage_min, stage_max)
85+
86+
if self.save_stages:
87+
encoder.append(last_stage_probs)
7388

74-
encoded_msg = self.get_encoded_value(encoder)
89+
encoded_msg = self.get_encoded_value(last_stage_probs)
7590

76-
return encoder, encoded_msg
91+
return encoded_msg, encoder
7792

7893
def decode(self, encoded_msg, msg_length, probability_table):
7994
"""
8095
Decodes a message.
8196
"""
8297

8398
decoder = []
84-
decoded_msg = ""
99+
100+
decoded_msg = []
85101

86102
stage_min = Decimal(0.0)
87103
stage_max = Decimal(1.0)
@@ -93,13 +109,16 @@ def decode(self, encoded_msg, msg_length, probability_table):
93109
if encoded_msg >= value[0] and encoded_msg <= value[1]:
94110
break
95111

96-
decoded_msg = decoded_msg + msg_term
112+
decoded_msg.append(msg_term)
113+
97114
stage_min = stage_probs[msg_term][0]
98115
stage_max = stage_probs[msg_term][1]
99116

100-
decoder.append(stage_probs)
117+
if self.save_stages:
118+
decoder.append(stage_probs)
101119

102-
stage_probs = self.process_stage(probability_table, stage_min, stage_max)
103-
decoder.append(stage_probs)
120+
if self.save_stages:
121+
last_stage_probs = self.process_stage(probability_table, stage_min, stage_max)
122+
decoder.append(last_stage_probs)
104123

105-
return decoder, decoded_msg
124+
return decoded_msg, decoder

0 commit comments

Comments
 (0)