feat: add SSE support to request-client

This commit is contained in:
zhongming4762 2025-09-11 11:22:47 +08:00
parent eb4f1f8164
commit 66822a5f95
3 changed files with 73 additions and 22 deletions

View File

@ -89,7 +89,8 @@ describe('sSE', () => {
expect(onMessage).toHaveBeenCalledTimes(2); expect(onMessage).toHaveBeenCalledTimes(2);
expect(messages.join('')).toBe('hello world'); expect(messages.join('')).toBe('hello world');
expect(onEnd).toHaveBeenCalledWith('hello world'); // onEnd 不再带参数
expect(onEnd).toHaveBeenCalled();
}); });
it('should apply request interceptors', async () => { it('should apply request interceptors', async () => {
@ -101,20 +102,30 @@ describe('sSE', () => {
fulfilled: interceptor, fulfilled: interceptor,
}); });
vi.stubGlobal('fetch', createFetchMock(['data']));
// 创建 fetch mock并挂到全局 // 创建 fetch mock并挂到全局
const fetchMock = createFetchMock(['data']); const fetchMock = createFetchMock(['data']);
vi.stubGlobal('fetch', fetchMock); vi.stubGlobal('fetch', fetchMock);
await sse.requestSSE('/sse', undefined, {}); await sse.requestSSE('/sse', undefined, {});
expect(interceptor).toHaveBeenCalled(); expect(interceptor).toHaveBeenCalled();
expect(fetchMock).toHaveBeenCalledWith( expect(fetchMock).toHaveBeenCalledWith(
'http://localhost//sse', 'http://localhost/sse',
expect.objectContaining({ expect.objectContaining({
headers: expect.objectContaining({ 'x-test': 'intercepted' }), headers: expect.any(Headers),
}), }),
); );
const calls = fetchMock.mock?.calls;
expect(calls).toBeDefined();
expect(calls?.length).toBeGreaterThan(0);
const init = calls?.[0]?.[1] as RequestInit;
expect(init).toBeDefined();
const headers = init?.headers as Headers;
expect(headers?.get('x-test')).toBe('intercepted');
expect(headers?.get('accept')).toBe('text/event-stream');
}); });
it('should throw error when no reader', async () => { it('should throw error when no reader', async () => {

View File

@ -36,9 +36,10 @@ class SSE {
requestOptions?: SseRequestOptions, requestOptions?: SseRequestOptions,
) { ) {
const baseUrl = this.client.getBaseUrl() || ''; const baseUrl = this.client.getBaseUrl() || '';
const hasUrlSplit = baseUrl.endsWith('/') && url.startsWith('/');
const axiosConfig: InternalAxiosRequestConfig = { let axiosConfig: InternalAxiosRequestConfig<any> = {
url,
method: (requestOptions?.method as any) ?? 'GET',
headers: {} as AxiosRequestHeaders, headers: {} as AxiosRequestHeaders,
}; };
const requestInterceptors = this.client.instance.interceptors const requestInterceptors = this.client.instance.interceptors
@ -48,25 +49,45 @@ class SSE {
requestInterceptors.handlers.length > 0 requestInterceptors.handlers.length > 0
) { ) {
for (const handler of requestInterceptors.handlers) { for (const handler of requestInterceptors.handlers) {
if (handler.fulfilled) { if (typeof handler?.fulfilled === 'function') {
await handler.fulfilled(axiosConfig); const next = await handler.fulfilled(axiosConfig as any);
if (next) axiosConfig = next as InternalAxiosRequestConfig<any>;
} }
} }
} }
const merged = new Headers();
Object.entries(
(axiosConfig.headers ?? {}) as Record<string, string>,
).forEach(([k, v]) => merged.set(k, String(v)));
if (requestOptions?.headers) {
new Headers(requestOptions.headers).forEach((v, k) => merged.set(k, v));
}
if (!merged.has('accept')) {
merged.set('accept', 'text/event-stream');
}
let bodyInit = requestOptions?.body ?? data;
const ct = (merged.get('content-type') || '').toLowerCase();
if (
bodyInit &&
typeof bodyInit === 'object' &&
!ArrayBuffer.isView(bodyInit as any) &&
!(bodyInit instanceof ArrayBuffer) &&
!(bodyInit instanceof Blob) &&
!(bodyInit instanceof FormData) &&
ct.includes('application/json')
) {
bodyInit = JSON.stringify(bodyInit);
}
const requestInit: RequestInit = { const requestInit: RequestInit = {
...requestOptions, ...requestOptions,
body: data, method: axiosConfig.method,
headers: { headers: merged,
...(axiosConfig.headers as Record<string, string>), body: bodyInit,
...requestOptions?.headers,
},
}; };
const response = await fetch( const response = await fetch(safeJoinUrl(baseUrl, url), requestInit);
`${baseUrl}${hasUrlSplit ? '' : '/'}${url}`,
requestInit,
);
if (!response.ok) { if (!response.ok) {
throw new Error(`HTTP error! status: ${response.status}`); throw new Error(`HTTP error! status: ${response.status}`);
} }
@ -78,19 +99,38 @@ class SSE {
throw new Error('No reader'); throw new Error('No reader');
} }
let isEnd = false; let isEnd = false;
let allMessage = '';
while (!isEnd) { while (!isEnd) {
const { done, value } = await reader.read(); const { done, value } = await reader.read();
if (done) { if (done) {
isEnd = true; isEnd = true;
requestOptions?.onEnd?.(allMessage); decoder.decode(new Uint8Array(0), { stream: false });
requestOptions?.onEnd?.();
reader.releaseLock?.();
break; break;
} }
const content = decoder.decode(value, { stream: true }); const content = decoder.decode(value, { stream: true });
requestOptions?.onMessage?.(content); requestOptions?.onMessage?.(content);
allMessage += content;
} }
} }
} }
function safeJoinUrl(baseUrl: string | undefined, url: string): string {
if (!baseUrl) {
return url; // 没有 baseUrl直接返回 url
}
// 如果 url 本身就是绝对地址,直接返回
if (/^https?:\/\//i.test(url)) {
return url;
}
// 如果 baseUrl 是完整 URL就用 new URL
if (/^https?:\/\//i.test(baseUrl)) {
return new URL(url, baseUrl).toString();
}
// 否则,当作路径拼接
return `${baseUrl.replace(/\/+$/, '')}/${url.replace(/^\/+/, '')}`;
}
export { SSE }; export { SSE };

View File

@ -46,7 +46,7 @@ type RequestClientOptions = CreateAxiosDefaults & ExtendOptions;
*/ */
interface SseRequestOptions extends RequestInit { interface SseRequestOptions extends RequestInit {
onMessage?: (message: string) => void; onMessage?: (message: string) => void;
onEnd?: (message: string) => void; onEnd?: () => void;
} }
interface RequestInterceptorConfig { interface RequestInterceptorConfig {