Skip to content

Commit

Permalink
[OTE-130] implement new geo-blocking strategy for socks (#1049)
Browse files Browse the repository at this point in the history
  • Loading branch information
dydxwill authored Feb 26, 2024
1 parent 0de0096 commit 48e582b
Show file tree
Hide file tree
Showing 6 changed files with 29 additions and 124 deletions.
50 changes: 9 additions & 41 deletions indexer/services/socks/__tests__/lib/subscriptions.test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,6 @@ import { btcTicker, invalidChannel, invalidTicker } from '../constants';
import { axiosRequest } from '../../src/lib/axios';
import { AxiosSafeServerError, makeAxiosSafeServerError } from '@dydxprotocol-indexer/base';
import { BlockedError } from '../../src/lib/errors';
import { isRestrictedCountry } from '@dydxprotocol-indexer/compliance';

jest.mock('ws');
jest.mock('../../src/helpers/wss');
Expand Down Expand Up @@ -58,8 +57,7 @@ describe('Subscriptions', () => {
[Channel.V4_TRADES]: ['/v4/trades/perpetualMarket/.+'],
};
const initialMessage: Object = { a: 'b' };
const restrictedCountry: string = 'US';
const nonRestrictedCountry: string = 'AR';
const country: string = 'AR';

beforeAll(async () => {
await dbHelpers.migrate();
Expand All @@ -83,9 +81,6 @@ describe('Subscriptions', () => {
axiosRequestMock = (axiosRequest as jest.Mock);
axiosRequestMock.mockClear();
axiosRequestMock.mockImplementation(() => (JSON.stringify(initialMessage)));
(isRestrictedCountry as jest.Mock).mockImplementation((country: string): boolean => {
return country === restrictedCountry;
});
});

describe('subscribe', () => {
Expand All @@ -106,7 +101,7 @@ describe('Subscriptions', () => {
initialMsgId,
id,
false,
nonRestrictedCountry,
country,
);

expect(sendMessageStringMock).toHaveBeenCalledTimes(1);
Expand All @@ -126,6 +121,9 @@ describe('Subscriptions', () => {
for (const urlPattern of urlPatterns) {
expect(axiosRequestMock).toHaveBeenCalledWith(expect.objectContaining({
url: expect.stringMatching(RegExp(urlPattern)),
headers: {
'cf-ipcountry': country,
},
}));
}
} else {
Expand All @@ -150,7 +148,6 @@ describe('Subscriptions', () => {
initialMsgId,
id,
false,
nonRestrictedCountry,
);

expect(sendMessageMock).toHaveBeenCalledTimes(1);
Expand Down Expand Up @@ -179,7 +176,6 @@ describe('Subscriptions', () => {
initialMsgId,
defaultId,
false,
nonRestrictedCountry,
);
},
).rejects.toEqual(new Error(`Invalid channel: ${invalidChannel}`));
Expand All @@ -194,7 +190,6 @@ describe('Subscriptions', () => {
initialMsgId,
mockSubaccountId,
false,
nonRestrictedCountry,
);

expect(sendMessageMock).toHaveBeenCalledTimes(1);
Expand All @@ -217,7 +212,6 @@ describe('Subscriptions', () => {
initialMsgId,
mockSubaccountId,
false,
nonRestrictedCountry,
);

expect(sendMessageMock).toHaveBeenCalledTimes(1);
Expand Down Expand Up @@ -253,32 +247,7 @@ describe('Subscriptions', () => {
initialMsgId,
mockSubaccountId,
false,
nonRestrictedCountry,
);

expect(sendMessageMock).toHaveBeenCalledTimes(1);
expect(sendMessageMock).toHaveBeenCalledWith(
mockWs,
connectionId,
expect.objectContaining({
connection_id: connectionId,
type: 'error',
message: expectedError.message,
}));
expect(subscriptions.subscriptions[Channel.V4_ACCOUNTS]).toBeUndefined();
expect(subscriptions.subscriptionLists[connectionId]).toBeUndefined();
});

it('sends blocked error if subscribing to subaccount from restricted country', async () => {
const expectedError: BlockedError = new BlockedError();
await subscriptions.subscribe(
mockWs,
Channel.V4_ACCOUNTS,
connectionId,
initialMsgId,
mockSubaccountId,
false,
restrictedCountry,
country,
);

expect(sendMessageMock).toHaveBeenCalledTimes(1);
Expand All @@ -305,7 +274,7 @@ describe('Subscriptions', () => {
initialMsgId,
mockSubaccountId,
false,
nonRestrictedCountry,
country,
);

expect(sendMessageStringMock).toHaveBeenCalledTimes(1);
Expand Down Expand Up @@ -342,7 +311,7 @@ describe('Subscriptions', () => {
initialMsgId,
id,
false,
nonRestrictedCountry,
country,
);
subscriptions.unsubscribe(
connectionId,
Expand All @@ -362,7 +331,7 @@ describe('Subscriptions', () => {
initialMsgId,
mockSubaccountId,
false,
nonRestrictedCountry,
country,
);
subscriptions.unsubscribe(
connectionId,
Expand All @@ -386,7 +355,6 @@ describe('Subscriptions', () => {
initialMsgId,
validIds[channel],
false,
nonRestrictedCountry,
);
}));

Expand Down
47 changes: 1 addition & 46 deletions indexer/services/socks/__tests__/websocket/index.test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -14,15 +14,13 @@ import {
} from '../../src/types';
import { InvalidMessageHandler } from '../../src/lib/invalid-message';
import { PingHandler } from '../../src/lib/ping';
import config from '../../src/config';
import { isRestrictedCountryHeaders, COUNTRY_HEADER_KEY } from '@dydxprotocol-indexer/compliance';
import { COUNTRY_HEADER_KEY } from '@dydxprotocol-indexer/compliance';

jest.mock('uuid');
jest.mock('../../src/helpers/wss');
jest.mock('../../src/lib/subscription');
jest.mock('../../src/lib/invalid-message');
jest.mock('../../src/lib/ping');
jest.mock('@dydxprotocol-indexer/compliance');

describe('Index', () => {
let index: Index;
Expand All @@ -32,12 +30,10 @@ describe('Index', () => {
let mockConnect: (ws: WebSocket, req: IncomingMessage) => void;
let wsOnSpy: jest.SpyInstance;
let wsPingSpy: jest.SpyInstance;
let wsTerminateSpy: jest.SpyInstance;
let invalidMsgHandlerSpy: jest.SpyInstance;
let pingHandlerSpy: jest.SpyInstance;

const connectionId: string = 'conId';
const defaultGeoblockingEnabled: boolean = config.INDEXER_LEVEL_GEOBLOCKING_ENABLED;
const countryCode: string = 'AR';

beforeAll(() => {
Expand All @@ -58,7 +54,6 @@ describe('Index', () => {
websocket = new WebSocket(null);
wsOnSpy = jest.spyOn(websocket, 'on');
wsPingSpy = jest.spyOn(websocket, 'ping').mockImplementation(jest.fn());
wsTerminateSpy = jest.spyOn(websocket, 'terminate').mockImplementation(jest.fn());
mockWss.onConnection = jest.fn().mockImplementation(
(cb: (ws: WebSocket, req: IncomingMessage) => void) => {
mockConnect = cb;
Expand Down Expand Up @@ -97,46 +92,6 @@ describe('Index', () => {
}),
);
});

describe('geoblocking', () => {
const isRestrictedCountrySpy: jest.Mock = isRestrictedCountryHeaders as unknown as jest.Mock;

beforeAll(() => {
config.INDEXER_LEVEL_GEOBLOCKING_ENABLED = true;
});

afterAll(() => {
config.INDEXER_LEVEL_GEOBLOCKING_ENABLED = defaultGeoblockingEnabled;
});

it('rejects connection if from restricted country', () => {
jest.spyOn(websocket, 'terminate').mockImplementation(jest.fn());
// restricted country headers
isRestrictedCountrySpy.mockReturnValue(true);

const message: IncomingMessage = new IncomingMessage(new Socket());
mockConnect(websocket, message);
expect(websocket.terminate).toHaveBeenCalled();
expect(Object.keys(index.connections)).toHaveLength(0);
expect(wsOnSpy).not.toHaveBeenCalled();
expect(wsTerminateSpy).toHaveBeenCalled();
expect(sendMessage).not.toHaveBeenCalled();
});

it('does not reject connection if from restricted country', () => {
(v4 as unknown as jest.Mock).mockReturnValueOnce(connectionId);
// non-restricted country headers
isRestrictedCountrySpy.mockReturnValue(false);

const message: IncomingMessage = new IncomingMessage(new Socket());
mockConnect(websocket, message);

// Test that the connection is tracked.
expect(index.connections[connectionId]).not.toBeUndefined();
expect(index.connections[connectionId].ws).toEqual(websocket);
expect(index.connections[connectionId].messageId).toEqual(0);
});
});
});

describe('handlers', () => {
Expand Down
8 changes: 8 additions & 0 deletions indexer/services/socks/src/helpers/header-utils.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
import { CountryHeaders } from '@dydxprotocol-indexer/compliance';

import { IncomingMessage } from '../types';

export function getCountry(req: IncomingMessage): string | undefined {
const countryHeaders: CountryHeaders = req.headers as CountryHeaders;
return countryHeaders['cf-ipcountry'];
}
17 changes: 9 additions & 8 deletions indexer/services/socks/src/lib/subscription.ts
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@ import {
logger,
stats,
} from '@dydxprotocol-indexer/base';
import { isRestrictedCountry } from '@dydxprotocol-indexer/compliance';
import { CandleResolution, perpetualMarketRefresher } from '@dydxprotocol-indexer/postgres';
import WebSocket from 'ws';

Expand Down Expand Up @@ -491,13 +490,6 @@ export class Subscriptions {
throw new Error('Invalid undefined id');
}

// TODO(IND-508): Change this to match technical spec for persistent geo-blocking. This may
// either have to replicate any blocking logic added on comlink, or re-direct to comlink to
// determine if subscribing to a specific subaccount is blocked.
if (country !== undefined && isRestrictedCountry(country)) {
throw new BlockedError();
}

try {
const {
address,
Expand All @@ -518,13 +510,19 @@ export class Subscriptions {
method: RequestMethod.GET,
url: `${COMLINK_URL}/v4/addresses/${address}/subaccountNumber/${subaccountNumber}`,
timeout: config.INITIAL_GET_TIMEOUT_MS,
headers: {
'cf-ipcountry': country,
},
transformResponse: (res) => res,
}),
// TODO(DEC-1462): Use the /active-orders endpoint once it's added.
axiosRequest({
method: RequestMethod.GET,
url: `${COMLINK_URL}/v4/orders?address=${address}&subaccountNumber=${subaccountNumber}&status=OPEN,UNTRIGGERED,BEST_EFFORT_OPENED`,
timeout: config.INITIAL_GET_TIMEOUT_MS,
headers: {
'cf-ipcountry': country,
},
transformResponse: (res) => res,
}),
]);
Expand Down Expand Up @@ -597,6 +595,9 @@ export class Subscriptions {
method: RequestMethod.GET,
url: endpoint,
timeout: config.INITIAL_GET_TIMEOUT_MS,
headers: {
'cf-ipcountry': country,
},
transformResponse: (res) => res, // Disables JSON parsing
});
}
Expand Down
10 changes: 2 additions & 8 deletions indexer/services/socks/src/websocket/index.ts
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ import { v4 as uuidv4 } from 'uuid';
import WebSocket from 'ws';

import config from '../config';
import { getCountry } from '../helpers/header-utils';
import {
createErrorMessage,
createConnectedMessage,
Expand All @@ -26,7 +27,6 @@ import {
ALL_CHANNELS,
WebsocketEvents,
} from '../types';
import { CountryRestrictor } from './restrict-countries';

const HEARTBEAT_INTERVAL_MS: number = config.WS_HEARTBEAT_INTERVAL_MS;
const HEARTBEAT_TIMEOUT_MS: number = config.WS_HEARTBEAT_TIMEOUT_MS;
Expand All @@ -42,15 +42,13 @@ export class Index {
// Handlers for pings and invalid messages.
private pingHandler: PingHandler;
private invalidMessageHandler: InvalidMessageHandler;
private countryRestrictor: CountryRestrictor;

constructor(wss: Wss, subscriptions: Subscriptions) {
this.wss = wss;
this.connections = {};
this.subscriptions = subscriptions;
this.pingHandler = new PingHandler();
this.invalidMessageHandler = new InvalidMessageHandler();
this.countryRestrictor = new CountryRestrictor();

// Attach the new connection handler to the websocket server.
this.wss.onConnection((ws: WebSocket, req: IncomingMessage) => this.onConnection(ws, req));
Expand Down Expand Up @@ -99,17 +97,13 @@ export class Index {
* @param req HTTP request accompanying new connection request.
*/
private onConnection(ws: WebSocket, req: IncomingMessage): void {
// Terminate the connection if the connection requestion originated from a restricted country
if (this.countryRestrictor.isRestrictedCountry(req)) {
return ws.terminate();
}

const connectionId: string = uuidv4();

this.connections[connectionId] = {
ws,
messageId: 0,
countryCode: this.countryRestrictor.getCountry(req),
countryCode: getCountry(req),
};

const numConcurrentConnections: number = Object.keys(this.connections).length;
Expand Down
21 changes: 0 additions & 21 deletions indexer/services/socks/src/websocket/restrict-countries.ts

This file was deleted.

0 comments on commit 48e582b

Please sign in to comment.