Skip to content

Commit 5dcc192

Browse files
committed
[LOG4J2-1863] Add class filtering to AbstractSocketServer
This allows a whitelist of class names to be specified to configure which classes are allowed to be deserialized in both TcpSocketServer and UdpSocketServer.
1 parent 5aff929 commit 5dcc192

File tree

5 files changed

+140
-6
lines changed

5 files changed

+140
-6
lines changed

log4j-core/src/main/java/org/apache/logging/log4j/core/net/server/AbstractSocketServer.java

+13
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,8 @@
2626
import java.net.InetAddress;
2727
import java.net.URI;
2828
import java.net.URL;
29+
import java.util.Collections;
30+
import java.util.List;
2931
import java.util.Objects;
3032

3133
import com.beust.jcommander.Parameter;
@@ -70,6 +72,9 @@ protected static class CommandLineArguments extends BasicCommandLineArguments {
7072
"-a" }, converter = InetAddressConverter.class, description = "Server socket local bind address.")
7173
private InetAddress localBindAddress;
7274

75+
@Parameter(names = {"--classes", "-C"}, description = "Additional classes to allow deserialization")
76+
private List<String> allowedClasses;
77+
7378
String getConfigLocation() {
7479
return configLocation;
7580
}
@@ -101,6 +106,14 @@ InetAddress getLocalBindAddress() {
101106
void setLocalBindAddress(final InetAddress localBindAddress) {
102107
this.localBindAddress = localBindAddress;
103108
}
109+
110+
List<String> getAllowedClasses() {
111+
return allowedClasses == null ? Collections.<String>emptyList() : allowedClasses;
112+
}
113+
114+
void setAllowedClasses(final List<String> allowedClasses) {
115+
this.allowedClasses = allowedClasses;
116+
}
104117
}
105118

106119
/**

log4j-core/src/main/java/org/apache/logging/log4j/core/net/server/ObjectInputStreamLogEventBridge.java

+21-2
Original file line numberDiff line numberDiff line change
@@ -19,18 +19,37 @@
1919
import java.io.IOException;
2020
import java.io.InputStream;
2121
import java.io.ObjectInputStream;
22+
import java.util.Collections;
23+
import java.util.List;
2224

2325
import org.apache.logging.log4j.core.LogEvent;
2426
import org.apache.logging.log4j.core.LogEventListener;
27+
import org.apache.logging.log4j.core.util.FilteredObjectInputStream;
2528

2629
/**
2730
* Reads and logs serialized {@link LogEvent} objects from an {@link ObjectInputStream}.
2831
*/
2932
public class ObjectInputStreamLogEventBridge extends AbstractLogEventBridge<ObjectInputStream> {
3033

34+
private final List<String> allowedClasses;
35+
36+
public ObjectInputStreamLogEventBridge() {
37+
this(Collections.<String>emptyList());
38+
}
39+
40+
/**
41+
* Constructs an ObjectInputStreamLogEventBridge with additional allowed classes to deserialize.
42+
*
43+
* @param allowedClasses class names to also allow for deserialization
44+
* @since 2.8.2
45+
*/
46+
public ObjectInputStreamLogEventBridge(final List<String> allowedClasses) {
47+
this.allowedClasses = allowedClasses;
48+
}
49+
3150
@Override
3251
public void logEvents(final ObjectInputStream inputStream, final LogEventListener logEventListener)
33-
throws IOException {
52+
throws IOException {
3453
try {
3554
logEventListener.log((LogEvent) inputStream.readObject());
3655
} catch (final ClassNotFoundException e) {
@@ -40,6 +59,6 @@ public void logEvents(final ObjectInputStream inputStream, final LogEventListene
4059

4160
@Override
4261
public ObjectInputStream wrapStream(final InputStream inputStream) throws IOException {
43-
return new ObjectInputStream(inputStream);
62+
return new FilteredObjectInputStream(inputStream, allowedClasses);
4463
}
4564
}

log4j-core/src/main/java/org/apache/logging/log4j/core/net/server/TcpSocketServer.java

+22-3
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,8 @@
2424
import java.net.InetAddress;
2525
import java.net.ServerSocket;
2626
import java.net.Socket;
27+
import java.util.Collections;
28+
import java.util.List;
2729
import java.util.Map;
2830
import java.util.concurrent.ConcurrentHashMap;
2931
import java.util.concurrent.ConcurrentMap;
@@ -148,9 +150,26 @@ public static TcpSocketServer<ObjectInputStream> createSerializedSocketServer(fi
148150
*/
149151
public static TcpSocketServer<ObjectInputStream> createSerializedSocketServer(final int port, final int backlog,
150152
final InetAddress localBindAddress) throws IOException {
153+
return createSerializedSocketServer(port, backlog, localBindAddress, Collections.<String>emptyList());
154+
}
155+
156+
/**
157+
* Creates a socket server that reads serialized log events.
158+
*
159+
* @param port the port to listen
160+
* @param localBindAddress The server socket's local bin address
161+
* @param allowedClasses additional class names to allow for deserialization
162+
* @return a new a socket server
163+
* @throws IOException
164+
* if an I/O error occurs when opening the socket.
165+
* @since 2.8.2
166+
*/
167+
public static TcpSocketServer<ObjectInputStream> createSerializedSocketServer(
168+
final int port, final int backlog, final InetAddress localBindAddress, final List<String> allowedClasses
169+
) throws IOException {
151170
LOGGER.entry(port);
152171
final TcpSocketServer<ObjectInputStream> socketServer = new TcpSocketServer<>(port, backlog, localBindAddress,
153-
new ObjectInputStreamLogEventBridge());
172+
new ObjectInputStreamLogEventBridge(allowedClasses));
154173
return LOGGER.exit(socketServer);
155174
}
156175

@@ -185,8 +204,8 @@ public static void main(final String[] args) throws Exception {
185204
if (cla.getConfigLocation() != null) {
186205
ConfigurationFactory.setConfigurationFactory(new ServerConfigurationFactory(cla.getConfigLocation()));
187206
}
188-
final TcpSocketServer<ObjectInputStream> socketServer = TcpSocketServer
189-
.createSerializedSocketServer(cla.getPort(), cla.getBacklog(), cla.getLocalBindAddress());
207+
final TcpSocketServer<ObjectInputStream> socketServer = TcpSocketServer.createSerializedSocketServer(
208+
cla.getPort(), cla.getBacklog(), cla.getLocalBindAddress(), cla.getAllowedClasses());
190209
final Thread serverThread = socketServer.startNewThread();
191210
if (cla.isInteractive()) {
192211
socketServer.awaitTermination(serverThread);

log4j-core/src/main/java/org/apache/logging/log4j/core/net/server/UdpSocketServer.java

+17-1
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@
2424
import java.io.OptionalDataException;
2525
import java.net.DatagramPacket;
2626
import java.net.DatagramSocket;
27+
import java.util.List;
2728

2829
import org.apache.logging.log4j.core.config.ConfigurationFactory;
2930
import org.apache.logging.log4j.core.util.BasicCommandLineArguments;
@@ -63,6 +64,21 @@ public static UdpSocketServer<ObjectInputStream> createSerializedSocketServer(fi
6364
return new UdpSocketServer<>(port, new ObjectInputStreamLogEventBridge());
6465
}
6566

67+
/**
68+
* Creates a socket server that reads serialized log events.
69+
*
70+
* @param port the port to listen
71+
* @param allowedClasses additional classes to allow for deserialization
72+
* @return a new a socket server
73+
* @throws IOException if an I/O error occurs when opening the socket.
74+
* @since 2.8.2
75+
*/
76+
public static UdpSocketServer<ObjectInputStream> createSerializedSocketServer(final int port,
77+
final List<String> allowedClasses)
78+
throws IOException {
79+
return new UdpSocketServer<>(port, new ObjectInputStreamLogEventBridge(allowedClasses));
80+
}
81+
6682
/**
6783
* Creates a socket server that reads XML log events.
6884
*
@@ -93,7 +109,7 @@ public static void main(final String[] args) throws Exception {
93109
ConfigurationFactory.setConfigurationFactory(new ServerConfigurationFactory(cla.getConfigLocation()));
94110
}
95111
final UdpSocketServer<ObjectInputStream> socketServer = UdpSocketServer
96-
.createSerializedSocketServer(cla.getPort());
112+
.createSerializedSocketServer(cla.getPort(), cla.getAllowedClasses());
97113
final Thread serverThread = socketServer.startNewThread();
98114
if (cla.isInteractive()) {
99115
socketServer.awaitTermination(serverThread);
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,67 @@
1+
/*
2+
* Licensed to the Apache Software Foundation (ASF) under one or more
3+
* contributor license agreements. See the NOTICE file distributed with
4+
* this work for additional information regarding copyright ownership.
5+
* The ASF licenses this file to You under the Apache license, Version 2.0
6+
* (the "License"); you may not use this file except in compliance with
7+
* the License. You may obtain a copy of the License at
8+
*
9+
* http://www.apache.org/licenses/LICENSE-2.0
10+
*
11+
* Unless required by applicable law or agreed to in writing, software
12+
* distributed under the License is distributed on an "AS IS" BASIS,
13+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14+
* See the license for the specific language governing permissions and
15+
* limitations under the license.
16+
*/
17+
package org.apache.logging.log4j.core.util;
18+
19+
import java.io.IOException;
20+
import java.io.InputStream;
21+
import java.io.InvalidObjectException;
22+
import java.io.ObjectInputStream;
23+
import java.io.ObjectStreamClass;
24+
import java.util.Arrays;
25+
import java.util.Collection;
26+
import java.util.List;
27+
28+
/**
29+
* Extended ObjectInputStream that only allows certain classes to be deserialized.
30+
*
31+
* @since 2.8.2
32+
*/
33+
public class FilteredObjectInputStream extends ObjectInputStream {
34+
35+
private static final List<String> REQUIRED_JAVA_CLASSES = Arrays.asList(
36+
// for StandardLevel
37+
"java.lang.Enum",
38+
// for location information
39+
"java.lang.StackTraceElement",
40+
// for Message delegate
41+
"java.rmi.MarshalledObject",
42+
"[B"
43+
);
44+
45+
private final Collection<String> allowedClasses;
46+
47+
public FilteredObjectInputStream(final InputStream in, final Collection<String> allowedClasses) throws IOException {
48+
super(in);
49+
this.allowedClasses = allowedClasses;
50+
}
51+
52+
@Override
53+
protected Class<?> resolveClass(final ObjectStreamClass desc) throws IOException, ClassNotFoundException {
54+
String name = desc.getName();
55+
if (!(isAllowedByDefault(name) || allowedClasses.contains(name))) {
56+
throw new InvalidObjectException("Class is not allowed for deserialization: " + name);
57+
}
58+
return super.resolveClass(desc);
59+
}
60+
61+
private static boolean isAllowedByDefault(final String name) {
62+
return name.startsWith("org.apache.logging.log4j.") ||
63+
name.startsWith("[Lorg.apache.logging.log4j.") ||
64+
REQUIRED_JAVA_CLASSES.contains(name);
65+
}
66+
67+
}

0 commit comments

Comments
 (0)