feat: add SSE support to request-client

This commit is contained in:
zhongming4762 2025-09-11 10:50:19 +08:00
parent 6a85b3ab84
commit eb4f1f8164
4 changed files with 251 additions and 1 deletions

View File

@ -0,0 +1,131 @@
import type { RequestClient } from '../request-client';
import { beforeEach, describe, expect, it, vi } from 'vitest';
import { SSE } from './sse';
// 模拟 TextDecoder
const OriginalTextDecoder = globalThis.TextDecoder;
beforeEach(() => {
vi.stubGlobal(
'TextDecoder',
class {
private decoder = new OriginalTextDecoder();
decode(value: Uint8Array, opts?: any) {
return this.decoder.decode(value, opts);
}
},
);
});
// 创建 fetch mock
const createFetchMock = (chunks: string[], ok = true) => {
const encoder = new TextEncoder();
let index = 0;
return vi.fn().mockResolvedValue({
ok,
status: ok ? 200 : 500,
body: {
getReader: () => ({
read: async () => {
if (index < chunks.length) {
return { done: false, value: encoder.encode(chunks[index++]) };
}
return { done: true, value: undefined };
},
}),
},
});
};
describe('sSE', () => {
let client: RequestClient;
let sse: SSE;
beforeEach(() => {
vi.restoreAllMocks();
client = {
getBaseUrl: () => 'http://localhost',
instance: {
interceptors: {
request: {
handlers: [],
},
},
},
} as unknown as RequestClient;
sse = new SSE(client);
});
it('should call requestSSE when postSSE is used', async () => {
const spy = vi.spyOn(sse, 'requestSSE').mockResolvedValue(undefined);
await sse.postSSE('/test', { foo: 'bar' }, { headers: { a: '1' } });
expect(spy).toHaveBeenCalledWith(
'/test',
{ foo: 'bar' },
{
headers: { a: '1' },
method: 'POST',
},
);
});
it('should throw error if fetch response not ok', async () => {
vi.stubGlobal('fetch', createFetchMock([], false));
await expect(sse.requestSSE('/bad')).rejects.toThrow(
'HTTP error! status: 500',
);
});
it('should trigger onMessage and onEnd callbacks', async () => {
const messages: string[] = [];
const onMessage = vi.fn((msg: string) => messages.push(msg));
const onEnd = vi.fn();
vi.stubGlobal('fetch', createFetchMock(['hello', ' world']));
await sse.requestSSE('/sse', undefined, { onMessage, onEnd });
expect(onMessage).toHaveBeenCalledTimes(2);
expect(messages.join('')).toBe('hello world');
expect(onEnd).toHaveBeenCalledWith('hello world');
});
it('should apply request interceptors', async () => {
const interceptor = vi.fn(async (config) => {
config.headers['x-test'] = 'intercepted';
return config;
});
(client.instance.interceptors.request as any).handlers.push({
fulfilled: interceptor,
});
vi.stubGlobal('fetch', createFetchMock(['data']));
// 创建 fetch mock并挂到全局
const fetchMock = createFetchMock(['data']);
vi.stubGlobal('fetch', fetchMock);
await sse.requestSSE('/sse', undefined, {});
expect(interceptor).toHaveBeenCalled();
expect(fetchMock).toHaveBeenCalledWith(
'http://localhost//sse',
expect.objectContaining({
headers: expect.objectContaining({ 'x-test': 'intercepted' }),
}),
);
});
it('should throw error when no reader', async () => {
vi.stubGlobal(
'fetch',
vi.fn().mockResolvedValue({
ok: true,
status: 200,
body: null,
}),
);
await expect(sse.requestSSE('/sse')).rejects.toThrow('No reader');
});
});

View File

@ -0,0 +1,96 @@
import type { AxiosRequestHeaders, InternalAxiosRequestConfig } from 'axios';
import type { RequestClient } from '../request-client';
import type { SseRequestOptions } from '../types';
/**
* SSE模块
*/
class SSE {
private client: RequestClient;
constructor(client: RequestClient) {
this.client = client;
}
public async postSSE(
url: string,
data?: any,
requestOptions?: SseRequestOptions,
) {
return this.requestSSE(url, data, {
...requestOptions,
method: 'POST',
});
}
/**
* SSE请求方法
* @param url - URL
* @param data -
* @param requestOptions - SSE请求选项
*/
public async requestSSE(
url: string,
data?: any,
requestOptions?: SseRequestOptions,
) {
const baseUrl = this.client.getBaseUrl() || '';
const hasUrlSplit = baseUrl.endsWith('/') && url.startsWith('/');
const axiosConfig: InternalAxiosRequestConfig = {
headers: {} as AxiosRequestHeaders,
};
const requestInterceptors = this.client.instance.interceptors
.request as any;
if (
requestInterceptors.handlers &&
requestInterceptors.handlers.length > 0
) {
for (const handler of requestInterceptors.handlers) {
if (handler.fulfilled) {
await handler.fulfilled(axiosConfig);
}
}
}
const requestInit: RequestInit = {
...requestOptions,
body: data,
headers: {
...(axiosConfig.headers as Record<string, string>),
...requestOptions?.headers,
},
};
const response = await fetch(
`${baseUrl}${hasUrlSplit ? '' : '/'}${url}`,
requestInit,
);
if (!response.ok) {
throw new Error(`HTTP error! status: ${response.status}`);
}
const reader = response.body?.getReader();
const decoder = new TextDecoder();
if (!reader) {
throw new Error('No reader');
}
let isEnd = false;
let allMessage = '';
while (!isEnd) {
const { done, value } = await reader.read();
if (done) {
isEnd = true;
requestOptions?.onEnd?.(allMessage);
break;
}
const content = decoder.decode(value, { stream: true });
requestOptions?.onMessage?.(content);
allMessage += content;
}
}
}
export { SSE };

View File

@ -9,6 +9,7 @@ import qs from 'qs';
import { FileDownloader } from './modules/downloader';
import { InterceptorManager } from './modules/interceptor';
import { SSE } from './modules/sse';
import { FileUploader } from './modules/uploader';
function getParamsSerializer(
@ -41,12 +42,14 @@ class RequestClient {
public addResponseInterceptor: InterceptorManager['addResponseInterceptor'];
public download: FileDownloader['download'];
public readonly instance: AxiosInstance;
// 是否正在刷新token
public isRefreshing = false;
public postSSE: SSE['postSSE'];
// 刷新token队列
public refreshTokenQueue: ((token: string) => void)[] = [];
public requestSSE: SSE['requestSSE'];
public upload: FileUploader['upload'];
private readonly instance: AxiosInstance;
/**
* Axios实例
@ -84,6 +87,10 @@ class RequestClient {
// 实例化文件下载器
const fileDownloader = new FileDownloader(this);
this.download = fileDownloader.download.bind(fileDownloader);
// 实例化SSE模块
const sse = new SSE(this);
this.postSSE = sse.postSSE.bind(sse);
this.requestSSE = sse.requestSSE.bind(sse);
}
/**
@ -103,6 +110,13 @@ class RequestClient {
return this.request<T>(url, { ...config, method: 'GET' });
}
/**
* URL
*/
public getBaseUrl() {
return this.instance.defaults.baseURL;
}
/**
* POST请求方法
*/

View File

@ -41,6 +41,14 @@ type RequestContentType =
type RequestClientOptions = CreateAxiosDefaults & ExtendOptions;
/**
* SSE
*/
interface SseRequestOptions extends RequestInit {
onMessage?: (message: string) => void;
onEnd?: (message: string) => void;
}
interface RequestInterceptorConfig {
fulfilled?: (
config: ExtendOptions & InternalAxiosRequestConfig,
@ -78,4 +86,5 @@ export type {
RequestInterceptorConfig,
RequestResponse,
ResponseInterceptorConfig,
SseRequestOptions,
};