Skip to content

Commit 459f86b

Browse files
authored
Add dims argument to LogSumExpAtom (#692)
1 parent 2d863c7 commit 459f86b

File tree

3 files changed

+111
-13
lines changed

3 files changed

+111
-13
lines changed

src/atoms/LogSumExpAtom.jl

+24-8
Original file line numberDiff line numberDiff line change
@@ -3,17 +3,26 @@
33
# Use of this source code is governed by a BSD-style license that can be found
44
# in the LICENSE file or at https://opensource.org/license/bsd-2-clause
55

6+
"""
7+
LogSumExpAtom(x::AbstractExpr, dims::Union{Colon,Int} = :)
8+
9+
Represents the expression `log.(sum(exp.(x); dims))`.
10+
"""
611
mutable struct LogSumExpAtom <: AbstractExpr
712
children::Tuple{AbstractExpr}
813
size::Tuple{Int,Int}
14+
dims::Union{Colon,Int}
915

10-
function LogSumExpAtom(x::AbstractExpr)
16+
function LogSumExpAtom(x::AbstractExpr, dims::Union{Colon,Int} = :)
17+
@assert dims == Colon() || 1 <= dims <= 2
1118
if sign(x) == ComplexSign()
1219
error(
1320
"[LogSumExpAtom] the argument should be real but it's instead complex",
1421
)
1522
end
16-
return new((x,), (1, 1))
23+
m = dims == 2 ? size(x, 1) : 1
24+
n = dims == 1 ? size(x, 2) : 1
25+
return new((x,), (m, n), dims)
1726
end
1827
end
1928

@@ -28,22 +37,29 @@ curvature(::LogSumExpAtom) = ConvexVexity()
2837
function evaluate(x::LogSumExpAtom)
2938
_x = evaluate(x.children[1])
3039
max_x = maximum(_x)
31-
return max_x + log(sum(exp.(_x .- max_x)))
40+
return max_x .+ log.(sum(exp.(_x .- max_x); x.dims))
3241
end
3342

34-
logsumexp(x::AbstractExpr) = LogSumExpAtom(x)
43+
logsumexp(x::AbstractExpr; dims = Colon()) = LogSumExpAtom(x, dims)
3544

3645
function new_conic_form!(context::Context, e::LogSumExpAtom)
3746
# log(sum(exp(x))) <= t <=> sum(exp(x)) <= exp(t) <=> sum(exp(x - t)) <= 1
38-
t = Variable()
39-
z = sum(exp(e.children[1] - t * ones(size(e.children[1]))))
40-
add_constraint!(context, 1 >= z)
47+
x = only(e.children)
48+
t = Variable(size(e))
49+
y = if e.dims == 1 # t is a row-vector
50+
ones(size(x, 1), 1) * t
51+
elseif e.dims == 2 # t is a col-vector
52+
t * ones(1, size(x, 2))
53+
else
54+
t * ones(size(x))
55+
end
56+
add_constraint!(context, 1 >= sum(exp(x - y); dims = e.dims))
4157
return conic_form!(context, t)
4258
end
4359

4460
function logisticloss(e::AbstractExpr)
4561
if length(e) == 1
4662
return logsumexp([e; 0])
4763
end
48-
return sum(logsumexp([e[i]; 0]) for i in 1:length(e))
64+
return sum(logsumexp(hcat(vec(e), zeros(length(e))); dims = 2))
4965
end

src/problem_depot/problems/exp.jl

+41-1
Original file line numberDiff line numberDiff line change
@@ -121,14 +121,54 @@ end
121121
) where {T,test}
122122
y = Variable(5)
123123
p = minimize(logsumexp(y), y >= 1; numeric_type = T)
124-
125124
if test
126125
@test problem_vexity(p) == ConvexVexity()
127126
end
128127
handle_problem!(p)
129128
if test
130129
@test p.optval log(exp(1) * 5) atol = atol rtol = rtol
131130
end
131+
y = Variable(5, 2)
132+
p = minimize(
133+
sum(Convex.logsumexp(y; dims = 1)),
134+
y[:, 1] >= 1,
135+
y[:, 2] >= 2;
136+
numeric_type = T,
137+
)
138+
handle_problem!(p)
139+
if test
140+
@test evaluate(y[:, 1]) ones(5) atol = atol rtol = rtol
141+
@test evaluate(y[:, 2]) 2 * ones(5) atol = atol rtol = rtol
142+
@test (
143+
p.optval,
144+
log(exp(1) * 5) + log(exp(2) * 5);
145+
atol = atol,
146+
rtol = rtol,
147+
)
148+
end
149+
p = minimize(logsumexp(y), y[:, 1] >= 1, y[:, 2] >= 2; numeric_type = T)
150+
handle_problem!(p)
151+
if test
152+
@test evaluate(y[:, 1]) ones(5) atol = atol rtol = rtol
153+
@test evaluate(y[:, 2]) 2 * ones(5) atol = atol rtol = rtol
154+
@test p.optval log(exp(1) * 5 + exp(2) * 5) atol = atol rtol = rtol
155+
end
156+
157+
x = Variable(2, 3)
158+
v = Convex.logsumexp(x; dims = 1)
159+
p = minimize(sum(v), x >= [1 2 3; 4 5 6]; numeric_type = T)
160+
handle_problem!(p)
161+
if test
162+
@test evaluate(x) [1 2 3; 4 5 6] atol = atol rtol = rtol
163+
@test (
164+
evaluate(v),
165+
log.(sum(exp, evaluate(x); dims = 1));
166+
atol = atol,
167+
rtol = rtol,
168+
)
169+
@test vexity(v) == Convex.ConvexVexity()
170+
end
171+
return
132172
end
133173

134174
@add_problem exp function exp_logistic_loss_atom(

test/test_atoms.jl

+46-4
Original file line numberDiff line numberDiff line change
@@ -1115,14 +1115,13 @@ function test_LogSumExpAtom()
11151115
return logisticloss(Variable())
11161116
end
11171117
target = """
1118-
variables: x1, x1_, t, z1, z2, t_, z1_, z2_
1118+
variables: x1, x1_, t, t_, z1, z1_, z2, z2_
11191119
minobjective: 1.0 * t + 1.0 * t_
1120+
[1.0 + -1.0*z1 + -1.0*z2, 1.0 + -1.0*z1_ + -1.0*z2_] in Nonnegatives(2)
11201121
[1.0 * x1 + -1.0 * t, 1.0, 1.0 * z1] in ExponentialCone()
1121-
[-1.0 * t, 1.0, 1.0 * z2] in ExponentialCone()
11221122
[1.0 * x1_ + -1.0 * t_, 1.0, 1.0 * z1_] in ExponentialCone()
1123+
[-1.0 * t, 1.0, 1.0 * z2] in ExponentialCone()
11231124
[-1.0 * t_, 1.0, 1.0 * z2_] in ExponentialCone()
1124-
[1.0 + -1.0*z1 + -1.0*z2] in Nonnegatives(1)
1125-
[1.0 + -1.0*z1_ + -1.0*z2_] in Nonnegatives(1)
11261125
"""
11271126
_test_atom(target) do context
11281127
return logisticloss(Variable(2))
@@ -1137,6 +1136,49 @@ function test_LogSumExpAtom()
11371136
atom = logsumexp(x)
11381137
x.value = [1.0 1_000.0]
11391138
@test evaluate(atom) 1_000.0
1139+
x = Variable(2, 3)
1140+
x.value = [1 2 3; 4 5 6]
1141+
@test evaluate(logsumexp(x; dims = :)) 6.456193316018123
1142+
@test (
1143+
evaluate(logsumexp(x; dims = 1)),
1144+
[4.04859 5.04859 6.04859],
1145+
atol = 1e-5,
1146+
)
1147+
@test (
1148+
evaluate(logsumexp(x; dims = 2)),
1149+
[3.40760596444438, 6.407605964444381],
1150+
atol = 1e-5,
1151+
)
1152+
target = """
1153+
variables: x11, x12, x21, x22, t1, t2, y11, y12, y21, y22
1154+
minobjective: [1.0 * t1, 1.0 * t2]
1155+
[-1.0 + x11, -2 + x12, -3 + x21, -4 + x22] in Nonnegatives(4)
1156+
[1.0 + -1.0 * y11 + -1.0 * y21, 1.0 + -1.0 * y12 + -1.0 * y22] in Nonnegatives(2)
1157+
[1.0 * x11 + -1.0 * t1, 1.0, y11] in ExponentialCone()
1158+
[1.0 * x12 + -1.0 * t2, 1.0, y12] in ExponentialCone()
1159+
[1.0 * x21 + -1.0 * t1, 1.0, y21] in ExponentialCone()
1160+
[1.0 * x22 + -1.0 * t2, 1.0, y22] in ExponentialCone()
1161+
"""
1162+
_test_atom(target) do context
1163+
x = Variable(2, 2)
1164+
add_constraint!(context, x >= [1 3; 2 4])
1165+
return logsumexp(x; dims = 2)
1166+
end
1167+
target = """
1168+
variables: x11, x12, x21, x22, t1, t2, y11, y12, y21, y22
1169+
minobjective: [1.0 * t1, 1.0 * t2]
1170+
[-1.0 + x11, -2 + x12, -3 + x21, -4 + x22] in Nonnegatives(4)
1171+
[1.0 + -1.0 * y11 + -1.0 * y12, 1.0 + -1.0 * y21 + -1.0 * y22] in Nonnegatives(2)
1172+
[1.0 * x11 + -1.0 * t1, 1.0, y11] in ExponentialCone()
1173+
[1.0 * x12 + -1.0 * t1, 1.0, y12] in ExponentialCone()
1174+
[1.0 * x21 + -1.0 * t2, 1.0, y21] in ExponentialCone()
1175+
[1.0 * x22 + -1.0 * t2, 1.0, y22] in ExponentialCone()
1176+
"""
1177+
_test_atom(target) do context
1178+
x = Variable(2, 2)
1179+
add_constraint!(context, x >= [1 3; 2 4])
1180+
return logsumexp(x; dims = 1)
1181+
end
11401182
return
11411183
end
11421184

0 commit comments

Comments
 (0)