Skip to content

Commit

Permalink
Refine SegmentFetcherFactory (apache#12936)
Browse files Browse the repository at this point in the history
  • Loading branch information
Jackie-Jiang authored Apr 16, 2024
1 parent d4cb93d commit 67cb52c
Show file tree
Hide file tree
Showing 3 changed files with 51 additions and 114 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -18,61 +18,45 @@
*/
package org.apache.pinot.common.utils.fetcher;

import com.google.common.base.Preconditions;
import java.io.File;
import java.net.URI;
import java.util.Collections;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Random;
import java.util.concurrent.atomic.AtomicInteger;
import java.util.function.Supplier;
import javax.annotation.Nullable;
import org.apache.pinot.common.auth.AuthConfig;
import org.apache.pinot.common.auth.AuthProviderUtils;
import org.apache.pinot.spi.crypt.PinotCrypter;
import org.apache.pinot.spi.crypt.PinotCrypterFactory;
import org.apache.pinot.spi.env.PinotConfiguration;
import org.apache.pinot.spi.utils.CommonConstants;
import org.checkerframework.checker.nullness.qual.NonNull;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;


public class SegmentFetcherFactory {
private final static SegmentFetcherFactory INSTANCE = new SegmentFetcherFactory();

static final String SEGMENT_FETCHER_CLASS_KEY_SUFFIX = ".class";
private static final String PROTOCOLS_KEY = "protocols";
private static final String ENCODED_SUFFIX = ".enc";
private static final String AUTH_KEY = CommonConstants.KEY_OF_AUTH;

private static final Logger LOGGER = LoggerFactory.getLogger(SegmentFetcherFactory.class);
private static final Random RANDOM = new Random();

private final Map<String, SegmentFetcher> _segmentFetcherMap = new HashMap<>();
private final SegmentFetcher _httpSegmentFetcher = new HttpSegmentFetcher();
private final SegmentFetcher _pinotFSSegmentFetcher = new PinotFSSegmentFetcher();

private SegmentFetcherFactory() {
// left blank
}

public static SegmentFetcherFactory getInstance() {
return INSTANCE;
}
public static final String SEGMENT_FETCHER_CLASS_KEY_SUFFIX = ".class";
public static final String PROTOCOLS_KEY = "protocols";
public static final String ENCODED_SUFFIX = ".enc";

private static final Logger LOGGER = LoggerFactory.getLogger(SegmentFetcherFactory.class);
private static final Map<String, SegmentFetcher> SEGMENT_FETCHER_MAP = new HashMap<>();
private static final SegmentFetcher HTTP_SEGMENT_FETCHER = new HttpSegmentFetcher();
private static final SegmentFetcher PINOT_FS_SEGMENT_FETCHER = new PinotFSSegmentFetcher();

/**
* Initializes the segment fetcher factory. This method should only be called once.
*/
public static void init(PinotConfiguration config)
throws Exception {
getInstance().initInternal(config);
}

private void initInternal(PinotConfiguration config)
throws Exception {
_httpSegmentFetcher.init(config); // directly, without sub-namespace
_pinotFSSegmentFetcher.init(config); // directly, without sub-namespace
HTTP_SEGMENT_FETCHER.init(config); // directly, without sub-namespace
PINOT_FS_SEGMENT_FETCHER.init(config); // directly, without sub-namespace

List<String> protocols = config.getProperty(PROTOCOLS_KEY, Collections.emptyList());
for (String protocol : protocols) {
Expand All @@ -93,22 +77,22 @@ private void initInternal(PinotConfiguration config)
}
} else {
LOGGER.info("Creating segment fetcher for protocol: {} with class: {}", protocol, segmentFetcherClassName);
segmentFetcher = (SegmentFetcher) Class.forName(segmentFetcherClassName).newInstance();
segmentFetcher = (SegmentFetcher) Class.forName(segmentFetcherClassName).getConstructor().newInstance();
}

AuthConfig authConfig = AuthProviderUtils.extractAuthConfig(config, AUTH_KEY);

PinotConfiguration subConfig = config.subset(protocol);
AuthConfig subAuthConfig = AuthProviderUtils.extractAuthConfig(subConfig, AUTH_KEY);
Map<String, Object> subConfigMap = subConfig.toMap();

Map<String, Object> subConfigMap = config.subset(protocol).toMap();
// Put global auth properties into sub-config if sub-config does not have auth properties
AuthConfig authConfig = AuthProviderUtils.extractAuthConfig(config, CommonConstants.KEY_OF_AUTH);
AuthConfig subAuthConfig = AuthProviderUtils.extractAuthConfig(subConfig, CommonConstants.KEY_OF_AUTH);
if (subAuthConfig.getProperties().isEmpty() && !authConfig.getProperties().isEmpty()) {
authConfig.getProperties().forEach((key, value) -> subConfigMap.put(AUTH_KEY + "." + key, value));
authConfig.getProperties()
.forEach((key, value) -> subConfigMap.put(CommonConstants.KEY_OF_AUTH + "." + key, value));
}

segmentFetcher.init(new PinotConfiguration(subConfigMap));

_segmentFetcherMap.put(protocol, segmentFetcher);
SEGMENT_FETCHER_MAP.put(protocol, segmentFetcher);
}
}

Expand All @@ -117,21 +101,17 @@ private void initInternal(PinotConfiguration config)
* ({@link HttpSegmentFetcher} for "http" and "https", {@link PinotFSSegmentFetcher} for other protocols).
*/
public static SegmentFetcher getSegmentFetcher(String protocol) {
return getInstance().getSegmentFetcherInternal(protocol);
}

private SegmentFetcher getSegmentFetcherInternal(String protocol) {
SegmentFetcher segmentFetcher = _segmentFetcherMap.get(protocol);
SegmentFetcher segmentFetcher = SEGMENT_FETCHER_MAP.get(protocol);
if (segmentFetcher != null) {
return segmentFetcher;
} else {
LOGGER.info("Segment fetcher is not configured for protocol: {}, using default", protocol);
switch (protocol) {
case CommonConstants.HTTP_PROTOCOL:
case CommonConstants.HTTPS_PROTOCOL:
return _httpSegmentFetcher;
return HTTP_SEGMENT_FETCHER;
default:
return _pinotFSSegmentFetcher;
return PINOT_FS_SEGMENT_FETCHER;
}
}
}
Expand All @@ -141,21 +121,15 @@ private SegmentFetcher getSegmentFetcherInternal(String protocol) {
*/
public static void fetchSegmentToLocal(URI uri, File dest)
throws Exception {
getInstance().fetchSegmentToLocalInternal(uri, dest);
getSegmentFetcher(uri.getScheme()).fetchSegmentToLocal(uri, dest);
}

/**
* Fetches a segment from URI location to local.
*/
public static void fetchSegmentToLocal(String uri, File dest)
throws Exception {
getInstance().fetchSegmentToLocalInternal(new URI(uri), dest);
}

private void fetchSegmentToLocalInternal(URI uri, File dest)
throws Exception {
// caller untars
getSegmentFetcher(uri.getScheme()).fetchSegmentToLocal(uri, dest);
fetchSegmentToLocal(new URI(uri), dest);
}

/**
Expand All @@ -167,36 +141,25 @@ private void fetchSegmentToLocalInternal(URI uri, File dest)
* @return the untared directory
* @throws Exception
*/
public static File fetchAndStreamUntarToLocal(String uri, File tempRootDir,
long maxStreamRateInByte, AtomicInteger attempts)
public static File fetchAndStreamUntarToLocal(URI uri, File tempRootDir, long maxStreamRateInByte,
AtomicInteger attempts)
throws Exception {
return getInstance().fetchAndStreamUntarToLocalInternal(new URI(uri), tempRootDir, maxStreamRateInByte, attempts);
return getSegmentFetcher(uri.getScheme()).fetchUntarSegmentToLocalStreamed(uri, tempRootDir, maxStreamRateInByte,
attempts);
}

private File fetchAndStreamUntarToLocalInternal(URI uri, File tempRootDir,
long maxStreamRateInByte, AtomicInteger attempts)
public static File fetchAndStreamUntarToLocal(String uri, File tempRootDir, long maxStreamRateInByte,
AtomicInteger attempts)
throws Exception {
return getSegmentFetcher(uri.getScheme()).fetchUntarSegmentToLocalStreamed(uri, tempRootDir, maxStreamRateInByte,
attempts);
return fetchAndStreamUntarToLocal(new URI(uri), tempRootDir, maxStreamRateInByte, attempts);
}

/**
* Fetches a segment from a URI location to a local file and decrypts it if needed
* @param uri remote segment location
* @param dest local file
*/
public static void fetchAndDecryptSegmentToLocal(String uri, File dest, String crypterName)
throws Exception {
getInstance().fetchAndDecryptSegmentToLocalInternal(uri, dest, crypterName);
}

// uris have equal weight to be selected for segment download
public static void fetchAndDecryptSegmentToLocal(List<URI> uris, File dest, String crypterName)
throws Exception {
getInstance().fetchAndDecryptSegmentToLocalInternal(uris, dest, crypterName);
}

private void fetchAndDecryptSegmentToLocalInternal(String uri, File dest, String crypterName)
public static void fetchAndDecryptSegmentToLocal(String uri, File dest, @Nullable String crypterName)
throws Exception {
if (crypterName == null) {
fetchSegmentToLocal(uri, dest);
Expand All @@ -211,16 +174,16 @@ private void fetchAndDecryptSegmentToLocalInternal(String uri, File dest, String
}
}

private void fetchAndDecryptSegmentToLocalInternal(@NonNull List<URI> uris, File dest, String crypterName)
throws Exception {
Preconditions.checkArgument(!uris.isEmpty(), "empty uris passed into the fetchAndDecryptSegmentToLocalInternal");
URI uri = uris.get(RANDOM.nextInt(uris.size()));
public static void fetchAndDecryptSegmentToLocal(String segmentName, String scheme, Supplier<List<URI>> uriSupplier,
File dest, @Nullable String crypterName)
throws Exception {
SegmentFetcher segmentFetcher = getSegmentFetcher(scheme);
if (crypterName == null) {
fetchSegmentToLocal(uri, dest);
segmentFetcher.fetchSegmentToLocal(segmentName, uriSupplier, dest);
} else {
// download
File tempDownloadedFile = new File(dest.getPath() + ENCODED_SUFFIX);
fetchSegmentToLocal(uri, tempDownloadedFile);
segmentFetcher.fetchSegmentToLocal(segmentName, uriSupplier, tempDownloadedFile);

// decrypt
PinotCrypter crypter = PinotCrypterFactory.create(crypterName);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -693,27 +693,22 @@ File downloadAndDecrypt(String segmentName, SegmentZKMetadata zkMetadata, File t
}
}

// not thread safe. Caller should invoke it with safe concurrency control.
protected void downloadFromPeersWithoutStreaming(String segmentName, SegmentZKMetadata zkMetadata, File destTarFile)
throws Exception {
Preconditions.checkState(_peerDownloadScheme != null, "Download peers require non null peer download scheme");
List<URI> peerSegmentURIs =
PeerServerSegmentFinder.getPeerServerURIs(_helixManager, _tableNameWithType, segmentName, _peerDownloadScheme);
if (peerSegmentURIs.isEmpty()) {
String msg = String.format("segment %s doesn't have any peers", segmentName);
LOGGER.warn(msg);
// HelixStateTransitionHandler would catch the runtime exception and mark the segment state as Error
throw new RuntimeException(msg);
}
Preconditions.checkState(_peerDownloadScheme != null, "Peer download is not enabled for table: %s",
_tableNameWithType);
try {
// Next download the segment from a randomly chosen server using configured scheme.
SegmentFetcherFactory.fetchAndDecryptSegmentToLocal(peerSegmentURIs, destTarFile, zkMetadata.getCrypterName());
LOGGER.info("Fetched segment {} from peers: {} to: {} of size: {}", segmentName, peerSegmentURIs, destTarFile,
SegmentFetcherFactory.fetchAndDecryptSegmentToLocal(segmentName, _peerDownloadScheme, () -> {
List<URI> peerServerURIs =
PeerServerSegmentFinder.getPeerServerURIs(_helixManager, _tableNameWithType, segmentName,
_peerDownloadScheme);
Collections.shuffle(peerServerURIs);
return peerServerURIs;
}, destTarFile, zkMetadata.getCrypterName());
_logger.info("Downloaded tarred segment: {} from peers to: {}, file length: {}", segmentName, destTarFile,
destTarFile.length());
} catch (AttemptsExceededException e) {
LOGGER.error("Attempts exceeded when downloading segment: {} for table: {} from peers {} to: {}", segmentName,
_tableNameWithType, peerSegmentURIs, destTarFile);
_serverMetrics.addMeteredTableValue(_tableNameWithType, ServerMeter.SEGMENT_DOWNLOAD_FROM_PEERS_FAILURES, 1L);
} catch (Exception e) {
_serverMetrics.addMeteredTableValue(_tableNameWithType, ServerMeter.SEGMENT_DOWNLOAD_FROM_PEERS_FAILURES, 1);
throw e;
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,6 @@
import org.apache.pinot.common.utils.fetcher.BaseSegmentFetcher;
import org.apache.pinot.common.utils.fetcher.SegmentFetcherFactory;
import org.apache.pinot.core.data.manager.offline.OfflineTableDataManager;
import org.apache.pinot.core.util.PeerServerSegmentFinder;
import org.apache.pinot.segment.local.segment.creator.impl.SegmentIndexCreationDriverImpl;
import org.apache.pinot.segment.local.segment.index.loader.IndexLoadingConfig;
import org.apache.pinot.segment.local.segment.readers.GenericRowRecordReader;
Expand Down Expand Up @@ -647,26 +646,6 @@ public void testDownloadAndDecryptPeerDownload()
verify(tmgr, times(1)).downloadFromPeersWithoutStreaming("seg01", zkmd, destFile);
}

// happy case: download from peers
@Test
public void testDownloadFromPeersWithoutStreaming()
throws Exception {
URI uri = mockRemoteCopy();
InstanceDataManagerConfig config = createDefaultInstanceDataManagerConfig();
when(config.getSegmentPeerDownloadScheme()).thenReturn("http");
HelixManager helixManager = mock(HelixManager.class);
BaseTableDataManager tmgr = createTableManager(config, helixManager);
File tempRootDir = tmgr.getTmpSegmentDataDir("test-download-peer-without-streaming");
File destFile = new File(tempRootDir, "seg01" + TarGzCompressionUtils.TAR_GZ_FILE_EXTENSION);
try (MockedStatic<PeerServerSegmentFinder> mockPeerSegFinder = mockStatic(PeerServerSegmentFinder.class)) {
mockPeerSegFinder.when(
() -> PeerServerSegmentFinder.getPeerServerURIs(helixManager, TABLE_NAME_WITH_TYPE, "seg01",
CommonConstants.HTTP_PROTOCOL)).thenReturn(List.of(uri));
tmgr.downloadFromPeersWithoutStreaming("seg01", mock(SegmentZKMetadata.class), destFile);
}
assertEquals(FileUtils.readFileToString(destFile), "this is from somewhere remote");
}

@Test
public void testUntarAndMoveSegment()
throws IOException {
Expand Down

0 comments on commit 67cb52c

Please sign in to comment.