script-astra/Android/Sdk/sources/android-35/android/adservices/ondevicepersonalization/ModelManager.java

133 lines
6.2 KiB
Java
Raw Permalink Normal View History

2025-01-20 15:15:20 +00:00
/*
* Copyright (C) 2024 The Android Open Source Project
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package android.adservices.ondevicepersonalization;
import android.adservices.ondevicepersonalization.aidl.IDataAccessService;
import android.adservices.ondevicepersonalization.aidl.IIsolatedModelService;
import android.adservices.ondevicepersonalization.aidl.IIsolatedModelServiceCallback;
import android.annotation.CallbackExecutor;
import android.annotation.FlaggedApi;
import android.annotation.NonNull;
import android.annotation.WorkerThread;
import android.os.Bundle;
import android.os.OutcomeReceiver;
import android.os.RemoteException;
import com.android.adservices.ondevicepersonalization.flags.Flags;
import com.android.ondevicepersonalization.internal.util.LoggerFactory;
import java.util.Objects;
import java.util.concurrent.Executor;
/**
 * Handles model inference and only supports TFLite model inference now. See {@link
 * IsolatedService#getModelManager}.
 */
@FlaggedApi(Flags.FLAG_ON_DEVICE_PERSONALIZATION_APIS_ENABLED)
public class ModelManager {
    private static final LoggerFactory.Logger sLogger = LoggerFactory.getLogger();
    private static final String TAG = ModelManager.class.getSimpleName();
    @NonNull private final IDataAccessService mDataService;
    @NonNull private final IIsolatedModelService mModelService;

    /** @hide */
    public ModelManager(
            @NonNull IDataAccessService dataService, @NonNull IIsolatedModelService modelService) {
        mDataService = dataService;
        mModelService = modelService;
    }

    /**
     * Run a single model inference. Only supports TFLite model inference now.
     *
     * <p>Every outcome, including a failure to reach the model service, is delivered to
     * {@code receiver} on {@code executor}. The receiver gets exactly one of {@code onResult} or
     * {@code onError} per call.
     *
     * @param input contains all the information needed for a run of model inference.
     * @param executor the {@link Executor} on which to invoke the callback.
     * @param receiver this returns a {@link InferenceOutput} which contains model inference result
     *     or {@link Exception} on failure.
     * @throws NullPointerException if {@code input}, {@code executor} or {@code receiver} is null.
     */
    @WorkerThread
    public void run(
            @NonNull InferenceInput input,
            @NonNull @CallbackExecutor Executor executor,
            @NonNull OutcomeReceiver<InferenceOutput, Exception> receiver) {
        final long startTimeMillis = System.currentTimeMillis();
        Objects.requireNonNull(input);
        // Fail fast at the call site instead of on the binder callback thread later.
        Objects.requireNonNull(executor);
        Objects.requireNonNull(receiver);
        Bundle bundle = new Bundle();
        bundle.putBinder(Constants.EXTRA_DATA_ACCESS_SERVICE_BINDER, mDataService.asBinder());
        bundle.putParcelable(Constants.EXTRA_INFERENCE_INPUT, new InferenceInputParcel(input));
        try {
            mModelService.runInference(
                    bundle,
                    new IIsolatedModelServiceCallback.Stub() {
                        @Override
                        public void onSuccess(Bundle result) {
                            executor.execute(
                                    () -> deliverResult(result, receiver, startTimeMillis));
                        }

                        @Override
                        public void onError(int errorCode) {
                            executor.execute(
                                    () ->
                                            deliverError(
                                                    new IllegalStateException(
                                                            "Error: " + errorCode),
                                                    receiver,
                                                    startTimeMillis));
                        }
                    });
        } catch (RemoteException e) {
            // Honor the callback-executor contract on this path too, rather than invoking the
            // receiver inline on the caller's thread.
            executor.execute(
                    () -> deliverError(new IllegalStateException(e), receiver, startTimeMillis));
        }
    }

    /**
     * Unpacks the inference result from {@code result} and hands it to {@code receiver}.
     *
     * <p>Unpacking failures are reported through {@code receiver.onError}. Exceptions thrown by
     * {@code receiver.onResult} itself are not caught here, so the receiver never sees both
     * {@code onResult} and {@code onError} for the same request.
     */
    private void deliverResult(
            Bundle result,
            OutcomeReceiver<InferenceOutput, Exception> receiver,
            long startTimeMillis) {
        final InferenceOutput output;
        try {
            InferenceOutputParcel outputParcel =
                    Objects.requireNonNull(
                            result.getParcelable(
                                    Constants.EXTRA_RESULT, InferenceOutputParcel.class));
            output = new InferenceOutput(outputParcel.getData());
        } catch (Exception e) {
            deliverError(e, receiver, startTimeMillis);
            return;
        }
        // Capture the end time before invoking the receiver so its work is not counted.
        final long endTimeMillis = System.currentTimeMillis();
        try {
            receiver.onResult(output);
        } finally {
            logApiCallStats(
                    Constants.API_NAME_MODEL_MANAGER_RUN,
                    endTimeMillis - startTimeMillis,
                    Constants.STATUS_SUCCESS);
        }
    }

    /** Reports {@code error} to {@code receiver} and logs the call as an internal error. */
    private void deliverError(
            Exception error,
            OutcomeReceiver<InferenceOutput, Exception> receiver,
            long startTimeMillis) {
        final long endTimeMillis = System.currentTimeMillis();
        try {
            receiver.onError(error);
        } finally {
            // Log even if the receiver throws, so the failed call is always counted.
            logApiCallStats(
                    Constants.API_NAME_MODEL_MANAGER_RUN,
                    endTimeMillis - startTimeMillis,
                    Constants.STATUS_INTERNAL_ERROR);
        }
    }

    /** Best-effort metrics logging; failures are logged at debug level and otherwise ignored. */
    private void logApiCallStats(int apiName, long duration, int responseCode) {
        try {
            mDataService.logApiCallStats(apiName, duration, responseCode);
        } catch (Exception e) {
            sLogger.d(e, TAG + ": failed to log metrics");
        }
    }
}